diff --git a/.bazelrc b/.bazelrc index e765c302c28..396b84f70b3 100644 --- a/.bazelrc +++ b/.bazelrc @@ -5,6 +5,7 @@ # Android options: # android: # android_arm: +# android_arm64: # android_x86: # android_x86_64: # @@ -46,10 +47,6 @@ # using_cuda: CUDA is available to build system. # cuda: Build with full cuda support. # rocm: Build with AMD GPU support (rocm). -# sycl: Build with SYCL support. -# sycl_nodouble: -# sycl_asan: -# sycl_trisycl: # mkl: Enable full mkl support. # tensorrt: Enable Tensorrt support. # ngraph: Enable ngraph support. @@ -89,6 +86,7 @@ # release_cpu_linux: Toolchain and CUDA options for Linux CPU builds. # release_cpu_macos: Toolchain and CUDA options for MacOS CPU builds. # release_gpu_linux: Toolchain and CUDA options for Linux GPU builds. +# release_gpu_linux_cuda_10_1: Toolchain and CUDA options for CUDA 10.1 Linux GPU builds. # release_cpu_windows: Toolchain and CUDA options for Windows CPU builds. # release_gpu_windows: Toolchain and CUDA options for Windows GPU builds. @@ -161,13 +159,11 @@ build --host_java_toolchain=//third_party/toolchains/java:tf_java_toolchain # environment variable "TF_MKL_ROOT" every time before build. build:mkl --define=build_with_mkl=true --define=enable_mkl=true build:mkl --define=tensorflow_mkldnn_contraction_kernel=0 -build:mkl --define=build_with_mkl_dnn_v1_only=true build:mkl -c opt # config to build OneDNN backend with a user specified threadpool. build:mkl_threadpool --define=build_with_mkl=true --define=enable_mkl=true build:mkl_threadpool --define=tensorflow_mkldnn_contraction_kernel=0 -build:mkl_threadpool --define=build_with_mkl_dnn_v1_only=true build:mkl_threadpool --define=build_with_mkl_opensource=true build:mkl_threadpool --define=build_with_mkldnn_threadpool=true build:mkl_threadpool -c opt @@ -175,10 +171,15 @@ build:mkl_threadpool -c opt # Config setting to build with oneDNN and without the binary blob build:mkl_opensource_only --define=build_with_mkl=true --define=enable_mkl=true build:mkl_opensource_only --define=tensorflow_mkldnn_contraction_kernel=0 -build:mkl_opensource_only --define=build_with_mkl_dnn_v1_only=true build:mkl_opensource_only --define=build_with_mkl_opensource=true build:mkl_opensource_only -c opt +# Config setting to build with oneDNN for Arm. +build:mkl_aarch64 --define=build_with_mkl_aarch64=true --define=enable_mkl=true +build:mkl_aarch64 --define=tensorflow_mkldnn_contraction_kernel=0 +build:mkl_aarch64 --define=build_with_mkl_opensource=true +build:mkl_aarch64 -c opt + # This config refers to building with CUDA available. It does not necessarily # mean that we build CUDA op kernels. 
build:using_cuda --define=using_cuda=true @@ -216,19 +217,6 @@ build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true build:rocm --action_env TF_NEED_ROCM=1 -build:sycl --crosstool_top=@local_config_sycl//crosstool:toolchain -build:sycl --define=using_sycl=true -build:sycl --action_env TF_NEED_OPENCL_SYCL=1 - -build:sycl_nodouble --config=sycl -build:sycl_nodouble --cxxopt -DTENSORFLOW_SYCL_NO_DOUBLE - -build:sycl_nodouble --config=sycl -build:sycl_asan --copt -fno-omit-frame-pointer --copt -fsanitize-coverage=3 --copt -DGPR_NO_DIRECT_SYSCALLS --linkopt -fPIC --linkopt -fsanitize=address - -build:sycl_nodouble --config=sycl -build:sycl_trisycl --define=using_trisycl=true - # Options extracted from configure script build:ngraph --define=with_ngraph_support=true build:numa --define=with_numa_support=true @@ -293,6 +281,7 @@ build:ios --noenable_platform_specific_config build:android --copt=-w build:ios --copt=-w build:linux --copt=-w +build:linux --host_copt=-w build:macos --copt=-w build:windows --copt=/w @@ -334,6 +323,11 @@ build:windows --host_copt=-DWIN32_LEAN_AND_MEAN build:windows --copt=-DNOGDI build:windows --host_copt=-DNOGDI +# MSVC (Windows): Standards-conformant preprocessor mode +# See https://docs.microsoft.com/en-us/cpp/preprocessor/preprocessor-experimental-overview +build:windows --copt=/experimental:preprocessor +build:windows --host_copt=/experimental:preprocessor + # Misc build options we need for windows. build:windows --linkopt=/DEBUG build:windows --host_linkopt=/DEBUG @@ -358,6 +352,7 @@ build --config=short_logs # TODO(gunan): Create a feature in toolchains for avx/avx2 to # avoid having to define linux/win separately. build:avx_linux --copt=-mavx +build:avx_linux --host_copt=-mavx build:avx2_linux --copt=-mavx2 build:native_arch_linux --copt=-march=native build:avx_win --copt=/arch=AVX @@ -411,9 +406,12 @@ build:rbe_linux --config=avx_linux build:rbe_linux --config=short_logs # TODO(gunan): Check why we need this specified in rbe, but not in other builds. 
build:rbe_linux --linkopt=-lrt +build:rbe_linux --host_linkopt=-lrt build:rbe_linux --linkopt=-lm +build:rbe_linux --host_linkopt=-lm build:rbe_cpu_linux --config=rbe_linux +build:rbe_cpu_linux --host_crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain" build:rbe_cpu_linux --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain" build:rbe_cpu_linux --extra_toolchains="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:cc-toolchain-k8" build:rbe_cpu_linux --extra_execution_platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform" @@ -431,6 +429,7 @@ test:rbe_linux_cuda_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/ build:rbe_linux_cuda10.1_nvcc_base --config=rbe_linux_cuda_base build:rbe_linux_cuda10.1_nvcc_base --define=using_cuda_nvcc=true +build:rbe_linux_cuda10.1_nvcc_base --host_crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain" build:rbe_linux_cuda10.1_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain" build:rbe_linux_cuda10.1_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64" build:rbe_linux_cuda10.1_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform" @@ -447,6 +446,7 @@ build:rbe_linux_cuda10.1_nvcc_py3.8 --config=rbe_linux_cuda10.1_nvcc_base --repo build:rbe_linux_cuda11.0_nvcc_base --config=rbe_linux_cuda_base build:rbe_linux_cuda11.0_nvcc_base --define=using_cuda_nvcc=true +build:rbe_linux_cuda11.0_nvcc_base --host_crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain" build:rbe_linux_cuda11.0_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain" build:rbe_linux_cuda11.0_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain-linux-x86_64" build:rbe_linux_cuda11.0_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform" @@ -587,7 +587,7 @@ build:release_gpu_common --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-11.0" build:release_gpu_common --action_env=TF_CUDA_VERSION="11" build:release_gpu_common --action_env=TF_CUDNN_VERSION="8" build:release_gpu_common --action_env=TF_NEED_TENSORRT="1" -build:release_gpu_common --action_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_37,sm_52,sm_60,sm_61,compute_70" +build:release_gpu_common --action_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_50,sm_60,sm_70,sm_75,compute_80" build:release_gpu_common --action_env=TENSORRT_INSTALL_PATH="/usr/local/tensorrt" build:release_gpu_common --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib" build:release_gpu_common --action_env=GCC_HOST_COMPILER_PATH="/usr/bin/gcc-5" @@ -603,3 +603,8 @@ build:release_windows_common --announce_rc build:release_cpu_windows --config=release_windows_common build:release_gpu_windows --config=release_windows_common + +build:release_gpu_linux_cuda_10_1 --config=release_gpu_linux +build:release_gpu_linux_cuda_10_1 --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-10.1" +build:release_gpu_linux_cuda_10_1 --action_env=TF_CUDA_VERSION="10" 
+build:release_gpu_linux_cuda_10_1 --action_env=TF_CUDNN_VERSION="7" diff --git a/.github/bot_config.yml b/.github/bot_config.yml index d0e7256aec0..c6dc0ec9c85 100644 --- a/.github/bot_config.yml +++ b/.github/bot_config.yml @@ -12,12 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -# -# THIS IS A GENERATED DOCKERFILE. -# -# This file was assembled from multiple pieces, whose use is documented -# throughout. Please refer to the TensorFlow dockerfiles documentation -# for more information. # A list of assignees assignees: @@ -40,6 +34,22 @@ segfault_memory: # assignees filesystem_security_assignee: - mihaimaruseac + +tflite_micro_path: + - tensorflow/lite/micro + +tflite_micro_comment: > + Thanks for contributing to TensorFlow Lite Micro. + + + To keep this process moving along, we'd like to make sure that you have completed the items on this list: + * Read the [contributing guidelines for TensorFlow Lite Micro](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/CONTRIBUTING.md) + * Created a [TF Lite Micro Github issue](https://github.com/tensorflow/tensorflow/issues/new?labels=comp%3Amicro&template=70-tflite-micro-issue.md) + * Linked to the issue from the PR description + + + We would like to have a discussion on the Github issue first to determine the best path forward, and then proceed to the PR review. + # Cuda Comment cuda_comment: > From the template it looks like you are installing **TensorFlow** (TF) prebuilt binaries: diff --git a/tensorflow/tools/ci_build/release/windows/cpu_py35_full/release_pip_rename.sh b/.github/workflows/update-nightly.yml similarity index 64% rename from tensorflow/tools/ci_build/release/windows/cpu_py35_full/release_pip_rename.sh rename to .github/workflows/update-nightly.yml index 43982623109..01b5147d053 100644 --- a/tensorflow/tools/ci_build/release/windows/cpu_py35_full/release_pip_rename.sh +++ b/.github/workflows/update-nightly.yml @@ -1,4 +1,3 @@ -#!/bin/bash # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,14 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================== -set -e -set -x +# ============================================================================ -source tensorflow/tools/ci_build/release/common.sh - -# Rename to tensorflow_cpu -for f in $(ls py_test_dir/tensorflow-*cp3*-cp3*m-win_amd64.whl); do - copy_to_new_project_name "${f}" tensorflow_cpu - rm "${f}" -done +on: + workflow_dispatch: # Allow manual triggers + schedule: + - cron: 0 4 * * * # 4am UTC is 9pm PDT and 8pm PST +name: Set nightly branch to master HEAD +jobs: + master-to-nightly: + runs-on: ubuntu-latest + steps: + - uses: zofrex/mirror-branch@v1 + name: Set nightly branch to master HEAD + with: + target-branch: 'nightly' diff --git a/ADOPTERS.md b/ADOPTERS.md deleted file mode 100644 index c0be567dc14..00000000000 --- a/ADOPTERS.md +++ /dev/null @@ -1,10 +0,0 @@ -# TensorFlow Adopters - -This page contains a list of people and organizations who are using TensorFlow. If you'd like to be included -here, please send a pull request which modifies this file. 
- -We intend to use this list to contact you for surveys, and to find good candidates for invite-only events. -We will also point to this list if we are asked who uses TensorFlow. - -We will not use any of the information here for promotions or to send other regular communications. You -should subscribe to discuss@tensorflow.org for such announcements. diff --git a/CODEOWNERS b/CODEOWNERS index 3ef02ffd68c..9de1922a262 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,16 +1,15 @@ # Where component owners are known, add them here. -/tensorflow/c/eager @jaingaurav @alextp -/tensorflow/core/common_runtime/eager @jaingaurav @alextp +/tensorflow/c/eager @qqfish @kkimdev +/tensorflow/core/common_runtime/eager @qqfish @kkimdev /tenosrflow/core/debug @caisq /tensorflow/core/nccl/ @azaks2 @chsigg -/tensorflow/core/platform/windows/ @mrry +/tensorflow/core/platform/windows/ @mihaimaruseac /tensorflow/lite/experimental/micro @petewarden @advaitjain /tensorflow/python/autograph/ @mdanatg @kkimdev /tensorflow/python/debug @caisq -/tensorflow/python/eager @jaingaurav @alextp +/tensorflow/python/eager @rohan100jain @kkimdev /tensorflow/python/tools/api/generator/ @annarev -/tensorflow/tensorboard/ @jart /tensorflow/tools/docs/ @markdaoust /third_party/systemlibs/ @perfinion diff --git a/README.md b/README.md index 6398e8e27a1..63d85ce2df4 100644 --- a/README.md +++ b/README.md @@ -103,23 +103,22 @@ open-source software development: ### Official Builds -Build Type | Status | Artifacts ------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------- -**Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/) -**Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/) -**Linux XLA** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA -**macOS** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/) -**Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/) -**Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/) -**Android** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) -**Raspberry Pi 0 and 1** | 
[![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl) -**Raspberry Pi 2 and 3** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl) -**Libtensorflow MacOS CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly) -**Libtensorflow Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly) -**Libtensorflow Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly) -**Libtensorflow Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly) -**Libtensorflow Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly) - +Build Type | Status | Artifacts +----------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------- +**Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/) +**Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/) +**Linux XLA** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA +**macOS** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/) +**Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/) +**Windows GPU** | 
[![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/) +**Android** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) +**Raspberry Pi 0 and 1** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl) +**Raspberry Pi 2 and 3** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl) +**Libtensorflow MacOS CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/) +**Libtensorflow Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/) +**Libtensorflow Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/) +**Libtensorflow Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/) +**Libtensorflow Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/) ### Community Supported Builds @@ -145,19 +144,20 @@ Build Type * [TensorFlow Tutorials](https://www.tensorflow.org/tutorials/) * [TensorFlow Official Models](https://github.com/tensorflow/models/tree/master/official) * [TensorFlow Examples](https://github.com/tensorflow/examples) -* [TensorFlow in Practice from Coursera](https://www.coursera.org/specializations/tensorflow-in-practice) +* [DeepLearning.AI TensorFlow Developer Professional Certificate](https://www.coursera.org/specializations/tensorflow-in-practice) * [TensorFlow: Data and Deployment from 
Coursera](https://www.coursera.org/specializations/tensorflow-data-and-deployment) * [Getting Started with TensorFlow 2 from Coursera](https://www.coursera.org/learn/getting-started-with-tensor-flow2) * [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187) * [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190) * [Machine Learning with TensorFlow on GCP](https://www.coursera.org/specializations/machine-learning-tensorflow-gcp) +* [TensorFlow Codelabs](https://codelabs.developers.google.com/?cat=TensorFlow) * [TensorFlow Chat Room on StackOverflow (not actively monitored by the TensorFlow team)](https://chat.stackoverflow.com/rooms/216694/tensorflow) * [TensorFlow Blog](https://blog.tensorflow.org) * [Learn ML with TensorFlow](https://www.tensorflow.org/resources/learn-ml) * [TensorFlow Twitter](https://twitter.com/tensorflow) * [TensorFlow YouTube](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ) -* [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap) +* [TensorFlow Roadmap](https://www.tensorflow.org/model_optimization/guide/roadmap) * [TensorFlow White Papers](https://www.tensorflow.org/about/bib) * [TensorBoard Visualization Toolkit](https://github.com/tensorflow/tensorboard) diff --git a/RELEASE.md b/RELEASE.md index 7057657c340..5aac986a135 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -34,9 +34,33 @@ shape assumptions (note that you can pass shapes with `None` entries for axes that are meant to be dynamic). You can also disable the input checking entirely by setting `model.input_spec = None`. +* TF pip packages now use CUDA11 and cuDNN 8.0.2. * XLA:CPU and XLA:GPU devices are no longer registered by default. Use `TF_XLA_FLAGS=--tf_xla_enable_xla_devices` if you really need them (to be removed). +* `tf.raw_ops.Max` and `tf.raw_ops.Min` no longer accept inputs of type + `tf.complex64` or `tf.complex128`, because the behavior of these ops is not + well defined for complex types. +* `tf.data.experimental.service.DispatchServer` now takes a config tuple + instead of individual arguments. Usages should be updated to + `tf.data.experimental.service.DispatchServer(dispatcher_config)`. +* `tf.data.experimental.service.WorkerServer` now takes a config tuple + instead of individual arguments. Usages should be updated to + `tf.data.experimental.service.WorkerServer(worker_config)`. +* `tf.quantization.quantize_and_dequantize_v2` has been introduced, which + updates the gradient definition for quantization which is outside the range + to be 0. To simulate the V1 the behavior of + tf.quantization.quantize_and_dequantize(...) use + tf.grad_pass_through(tf.quantization.quantize_and_dequantize_v2)(...). +* `tf.distribute.Strategy.experimental_make_numpy_dataset` is removed. Please + use `tf.data.Dataset.from_tensor_slices` instead. +* `experimental_hints` in `tf.distribute.StrategyExtended.reduce_to`, + `tf.distribute.StrategyExtended.batch_reduce_to`, + `tf.distribute.ReplicaContext.all_reduce` are renamed to `options`. + `tf.distribute.experimental.CollectiveHints` is renamed + `tf.distribute.experimental.CommunicationOptions`. + `tf.distribute.experimental.CollectiveCommunication` is renamed + `tf.distribute.experimental.CommunicationImplementation`. ## Known Caveats @@ -46,89 +70,180 @@ * * -* A new module named `tf.experimental.numpy` is added, which is a NumPy-compatible API for writing TF programs. 
This module provides class `ndarray`, which mimics the `ndarray` class in NumPy, and wraps an immutable `tf.Tensor` under the hood. A subset of NumPy functions (e.g. `numpy.add`) are provided. Their inter-operation with TF facilities is seamless in most cases. See tensorflow/python/ops/numpy_ops/README.md for details of what are supported and what are the differences with NumPy. +* A new module named `tf.experimental.numpy` is added, which is a NumPy-compatible API for writing TF programs. This module provides class `ndarray`, which mimics the `ndarray` class in NumPy, and wraps an immutable `tf.Tensor` under the hood. A subset of NumPy functions (e.g. `numpy.add`) are provided. Their inter-operation with TF facilities is seamless in most cases. See [tensorflow/python/ops/numpy_ops/README.md](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/numpy_ops/README.md) for details of what operations are supported and what are the differences from NumPy. * A major refactoring of the internals of the Keras Functional API has been completed, that should improve the reliability, stability, and performance of constructing Functional models. +* `tf.distribute`: + * Deprecated `experimental_distribute_datasets_from_function` method and renamed it to `distribute_datasets_from_function` as it is no longer experimental. + ## Bug Fixes and Other Changes -* -* -* -* TF Core: - * `tf.types.experimental.TensorLike` is a new `Union` type that can be used as - type annotation for variables representing a Tensor or a value that can be - converted to Tensor by `tf.convert_to_tensor`. - * Calling ops with a python constants or numpy values is now consistent with - tf.convert_to_tensor behavior. This avoids operations like tf.reshape - truncating inputs such as from int64 to int32. - * Added `tf.sparse.map_values` to apply a function to the `.value`s of `SparseTensror` arguments. - * The Python bitwise operators for `Tensor` (`__and__`, `__or__`, `__xor__` - and `__invert__` now support non-`bool` arguments and apply the - corresponding bitwise ops. `bool` arguments continue to be supported and - dispatch to logical ops. This brings them more in line with Python and NumPy - benavior. - * Added `tf.SparseTensor.with_values`. This returns a new SparseTensor with - the same sparsity pattern, but with new provided values. It is similar to - the `with_values` function of `RaggedTensor`. - * Added `StatelessCase` op, and uses it if none of case branches has stateful ops. -* `tf.data`: - * Added new `tf.data.experimental.service.register_dataset` and - `tf.data.experimental.service.from_dataset_id` APIs to enable one process - to register a dataset with the tf.data service, and another process to - consume data from the dataset. - * Added support for tf.data service dispatcher fault tolerance. To enable - fault tolerance, configure a `work_dir` when running your dispatcher - server and set `dispatcher_fault_tolerance=True`. The dispatcher will - store its state to `work_dir`, so that on restart it can continue from its - previous state after restart. - * Added tf.data service support for sharing dataset graphs via shared - filesystem instead of over RPC. This reduces load on the dispatcher, - improving performance of distributing datasets. For this to work, the - dispatcher's `work_dir` must be accessible from workers. If the worker - fails to read from the `work_dir`, it falls back to using RPC for dataset - graph transfer. - * Added optional `exclude_cols` parameter to CsvDataset. 
This parameter is - the complement of `select_cols`; at most one of these should be specified. - * We have implemented an optimization which reorders data-discarding - transformations such as `take` and `shard` to happen earlier in the - dataset when it is safe to do so. The optimization can be disabled via - the `experimental_optimization.reorder_data_discarding_ops` dataset - option. - * `tf.data.Options` were previously immutable and can now be overriden. -* `tf.image`: - * Added deterministic `tf.image.stateless_random_*` functions for each - `tf.image.random_*` function. Added a new op - `stateless_sample_distorted_bounding_box` which is a determinstic - version of `sample_distorted_bounding_box` op. Given the same seed, these - stateless functions/ops produce the same results independent of how many - times the function is called, and independent of global seed settings. +* +* +* +* Security: + * Fixes an undefined behavior causing a segfault in `tf.raw_ops.Switch` + ([CVE-2020-15190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15190)) + * Fixes three vulnerabilities in conversion to DLPack format + ([CVE-2020-15191](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15191), + [CVE-2020-15192](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15192), + [CVE-2020-15193](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15193)) + * Fixes two vulnerabilities in `SparseFillEmptyRowsGrad` + ([CVE-2020-15194](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15194), + [CVE-2020-15195](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15195)) + * Fixes several vulnerabilities in `RaggedCountSparseOutput` and + `SparseCountSparseOutput` operations + ([CVE-2020-15196](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15196), + [CVE-2020-15197](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15197), + [CVE-2020-15198](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15198), + [CVE-2020-15199](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15199), + [CVE-2020-15200](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15200), + [CVE-2020-15201](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15201)) + * Fixes an integer truncation vulnerability in code using the work sharder + API + ([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202)) + * Fixes a format string vulnerability in `tf.strings.as_string` + ([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203)) + * Fixes segfault raised by calling session-only ops in eager mode + ([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204)) + * Fixes data leak and potential ASLR violation from + `tf.raw_ops.StringNGrams` + ([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205)) + * Fixes segfaults caused by incomplete `SavedModel` validation + ([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206)) + * Fixes a data corruption due to a bug in negative indexing support in + TFLite + ([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207)) + * Fixes a data corruption due to dimension mismatch in TFLite + ([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208)) + * Fixes several vulnerabilities in TFLite saved model format + ([CVE-2020-15209](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15209), + [CVE-2020-15210](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15210), + 
[CVE-2020-15211](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15211)) + * Fixes several vulnerabilities in TFLite implementation of segment sum + ([CVE-2020-15212](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15212), + [CVE-2020-15213](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15213), + [CVE-2020-15214](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15214)) +* TF Core: + * `tf.types.experimental.TensorLike` is a new `Union` type that can be + used as type annotation for variables representing a Tensor or a value + that can be converted to Tensor by `tf.convert_to_tensor`. + * Calling ops with a python constants or numpy values is now consistent + with tf.convert_to_tensor behavior. This avoids operations like + tf.reshape truncating inputs such as from int64 to int32. + * Added `tf.sparse.map_values` to apply a function to the `.value`s of + `SparseTensor` arguments. + * The Python bitwise operators for `Tensor` (`__and__`, `__or__`, + `__xor__` and `__invert__` now support non-`bool` arguments and apply + the corresponding bitwise ops. `bool` arguments continue to be supported + and dispatch to logical ops. This brings them more in line with Python + and NumPy behavior. + * Added `tf.SparseTensor.with_values`. This returns a new SparseTensor + with the same sparsity pattern, but with new provided values. It is + similar to the `with_values` function of `RaggedTensor`. + * Added `StatelessCase` op, and uses it if none of case branches has + stateful ops. + * Added `tf.config.experimental.get_memory_usage` to return total memory + usage of the device. + * Added gradients for `RaggedTensorToVariant` and `RaggedTensorFromVariant`. +* `tf.data`: + * tf.data service: + * Added new `tf.data.experimental.service.register_dataset` and + `tf.data.experimental.service.from_dataset_id` APIs to enable one + process to register a dataset with the tf.data service, and another + process to consume data from the dataset. + * Added support for dispatcher fault tolerance. To enable fault tolerance, + configure a `work_dir` when running your dispatcher server and set + `dispatcher_fault_tolerance=True`. The dispatcher will store its state + to `work_dir`, so that on restart it can continue from its previous + state after restart. + * Added support for sharing dataset graphs via shared filesystem instead + of over RPC. This reduces load on the dispatcher, improving performance + of distributing datasets. For this to work, the dispatcher's `work_dir` + must be accessible from workers. If the worker fails to read from the + `work_dir`, it falls back to using RPC for dataset graph transfer. + * Added support for a new "distributed_epoch" processing mode. This + processing mode distributes a dataset across all tf.data workers, + instead of having each worker process the full dataset. See + [the tf.data service docs](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#understand_processing_mode) + to learn more. + * Added optional `exclude_cols` parameter to CsvDataset. This parameter is + the complement of `select_cols`; at most one of these should be + specified. + * We have implemented an optimization which reorders data-discarding + transformations such as `take` and `shard` to happen earlier in the + dataset when it is safe to do so. The optimization can be disabled via + the `experimental_optimization.reorder_data_discarding_ops` dataset + option. + * `tf.data.Options` were previously immutable and can now be overridden. 
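As an illustration of the `tf.data.Options` item above, a minimal sketch that overrides the `reorder_data_discarding_ops` optimization option named in these notes (assumes a nightly build in which that option is exposed):

```python
import tensorflow as tf

# Minimal sketch: override tf.data.Options to disable the data-discarding
# reordering optimization described in the release notes above.
dataset = tf.data.Dataset.range(100).map(lambda x: x * 2).take(10)

options = tf.data.Options()
options.experimental_optimization.reorder_data_discarding_ops = False
dataset = dataset.with_options(options)

for element in dataset.take(3):
    print(element.numpy())  # 0, 2, 4
```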
+ * `tf.data.Dataset.from_generator` now supports Ragged and Sparse tensors + with a new `output_signature` argument, which allows `from_generator` to + produce any type describable by a `tf.TypeSpec`. + * `tf.data.experimental.AUTOTUNE` is now available in the core API as + `tf.data.AUTOTUNE`. +* `tf.image`: + * Added deterministic `tf.image.stateless_random_*` functions for each + `tf.image.random_*` function. Added a new op + `stateless_sample_distorted_bounding_box` which is a deterministic + version of `sample_distorted_bounding_box` op. Given the same seed, + these stateless functions/ops produce the same results independent of + how many times the function is called, and independent of global seed + settings. * `tf.distribute`: - * -* `tf.keras`: - * Improvements from the functional API refactoring: - * Functional model construction does not need to maintain a global workspace graph, removing memory leaks especially when building many models or very large models. - * Functional model construction should be ~8-10% faster on average. - * Functional models can now contain non-symbolic values in their call inputs inside of the first positional argument. - * Several classes of TF ops that were not reliably converted to Keras layers during functional API construction should now work, e.g. `tf.image.ssim_multiscale` - * Error messages when Functional API construction goes wrong (and when ops cannot be converted to Keras layers automatically) should be clearer and easier to understand. - * `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape` - as an alternative to accepting a `callable` loss. - * Added `beta` hyperparameter to FTRL optimizer classes (Keras and others) - to match FTRL paper (https://research.google.com/pubs/archive/41159.pdf). - * Added `mobilenet_v3` to keras application model. - * `Optimizer.__init__` now accepts a `gradient_aggregator` to allow for - customization of how gradients are aggregated across devices, as well as - `gradients_transformers` to allow for custom gradient transformations - (such as gradient clipping). -* `tf.function` / AutoGraph: - * Added `experimental_follow_type_hints` argument for `tf.function`. When - True, the function may use type annotations to optimize the tracing - performance. - * Added support for `iter(DistributedDataset)` in AutoGraph `for` loops. - * AutoGraph now allows creating new symbols inside a TensorFLow loop, if - the values of these symbols at an iteration does not depend on the previous - iteration. These types of loops must run at least one iteration, and will - raise a runtime error otherwise. + * +* `tf.keras`: + * Improvements from the functional API refactoring: + * Functional model construction does not need to maintain a global + workspace graph, removing memory leaks especially when building many + models or very large models. + * Functional model construction should be ~8-10% faster on average. + * Functional models can now contain non-symbolic values in their call + inputs inside of the first positional argument. + * Several classes of TF ops that were not reliably converted to Keras + layers during functional API construction should now work, e.g. + `tf.image.ssim_multiscale` + * Error messages when Functional API construction goes wrong (and when + ops cannot be converted to Keras layers automatically) should be + clearer and easier to understand. + * `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape` + as an alternative to accepting a `callable` loss. 
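As an illustration of the `Optimizer.minimize` item above, a minimal sketch that passes a loss `Tensor` together with a `GradientTape` instead of a callable loss (the toy variable and learning rate are placeholders):

```python
import tensorflow as tf

# Minimal sketch: minimize a loss Tensor using an explicit GradientTape.
var = tf.Variable(2.0)
opt = tf.keras.optimizers.SGD(learning_rate=0.1)

with tf.GradientTape() as tape:
    loss = (var - 5.0) ** 2  # ordinary Tensor loss computed under the tape

# The tape supplies the gradients; no callable loss is needed.
opt.minimize(loss, var_list=[var], tape=tape)
print(var.numpy())  # nudged from 2.0 toward 5.0
```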
+ * Added `beta` hyperparameter to FTRL optimizer classes (Keras and others) + to match FTRL paper + (https://research.google.com/pubs/archive/41159.pdf). + * Added `mobilenet_v3` to keras application model. + * `Optimizer.__init__` now accepts a `gradient_aggregator` to allow for + customization of how gradients are aggregated across devices, as well as + `gradients_transformers` to allow for custom gradient transformations + (such as gradient clipping). + * The `steps_per_execution` argument in `compile()` is no longer + experimental; if you were passing `experimental_steps_per_execution`, + rename it to `steps_per_execution` in your code. This argument controls + the number of batches to run during each `tf.function` call when calling + `fit()`. Running multiple batches inside a single `tf.function` call can + greatly improve performance on TPUs or small models with a large Python + overhead. + * Improvements to Keras preprocessing layers: + * TextVectorization can now accept a vocabulary list or file as an + init arg. + * Normalization can now accept mean and variance values as init args. + * In `Attention` and `AdditiveAttention` layers, the `call()` method now + accepts a `return_attention_scores` argument. When set to + True, the layer returns the attention scores as an additional output + argument. + * Added `tf.metrics.log_cosh` and `tf.metrics.logcosh` API entrypoints + with the same implementation as their `tf.losses` equivalent. + * For Keras model, the individual call of `Model.evaluate` uses no cached + data for evaluation, while `Model.fit` uses cached data when + `validation_data` arg is provided for better performance. +* `tf.function` / AutoGraph: + * Added `experimental_follow_type_hints` argument for `tf.function`. When + True, the function may use type annotations to optimize the tracing + performance. + * Added support for `iter(DistributedDataset)` in AutoGraph `for` loops. + * AutoGraph now allows creating new symbols inside a TensorFLow loop, if + the values of these symbols at an iteration does not depend on the + previous iteration. These types of loops must run at least one + iteration, and will raise a runtime error otherwise. Example: @@ -137,45 +252,103 @@ outputs = train_step(batch) tf.print('final outputs', outputs) ``` + See tensorflow/python/autograph/g3doc/reference/limitations.md for more info. + * `tf.lite`: - * `DynamicBuffer::AddJoinedString()` will now add a separator if the first - string to be joined is empty. - * `TFLiteConverter`: - * Support optional flags `inference_input_type` and `inference_output_type` for full integer quantized models. This allows users to modify the model input and output type to integer types (`tf.int8`, `tf.uint8`) instead of defaulting to float type (`tf.float32`). - * Deprecate `Interpreter::UseNNAPI(bool)` C++ API - * Prefer using `NnApiDelegate()` and related delegate configuration methods directly. - * Add NNAPI Delegation support for requantization use cases by converting the operation into a dequantize-quantize pair. - * + + * `TFLiteConverter`: + * Support optional flags `inference_input_type` and + `inference_output_type` for full integer quantized models. This + allows users to modify the model input and output type to integer + types (`tf.int8`, `tf.uint8`) instead of defaulting to float type + (`tf.float32`). + * TFLite Profiler for Android is available. See the detailed + [guide](https://www.tensorflow.org/lite/performance/measurement#trace_tensorflow_lite_internals_in_android). 
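As an illustration of the `TFLiteConverter` item above, a minimal sketch of a full-integer quantized conversion using the optional `inference_input_type` and `inference_output_type` flags (the toy model and representative dataset are placeholders):

```python
import tensorflow as tf

# Minimal sketch: full-integer quantization with integer model input/output types.
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(10),
])

def representative_dataset():
    # Placeholder calibration data for the quantizer.
    for _ in range(100):
        yield [tf.random.normal([1, 28, 28])]

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8   # new optional flag (default: tf.float32)
converter.inference_output_type = tf.int8  # new optional flag (default: tf.float32)
tflite_model = converter.convert()
```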
+ * NNAPI + * Added NNAPI Delegation support for requantization use cases by + converting the operation into a dequantize-quantize pair. + * Removed deprecated `Interpreter.setUseNNAPI(boolean)` Java API. + * Use `Interpreter.Options.setUseNNAPI` instead. + * Deprecate `Interpreter::UseNNAPI(bool)` C++ API. + * Use `NnApiDelegate()` and related delegate configuration methods + directly. + * Deprecate `Interpreter::SetAllowFp16PrecisionForFp32(bool)` C++ API + * Prefer controlling this via delegate options, e.g. + `tflite::StatefulNnApiDelegate::Options::allow_fp16' or + `TfLiteGpuDelegateOptionsV2::is_precision_loss_allowed`. + * `DynamicBuffer::AddJoinedString()` will now add a separator if the first + string to be joined is empty. + * + * `tf.random`: - * + + * + * Math and Linear Algebra: - * + + * Add `tf.math.erfcinv`, the inverse to `tf.math.erfc`. + * TPU Enhancements: - * Added support for the `beta` parameter of the FTRL optimizer for TPU - embeddings. Users of other TensorFlow platforms can implement equivalent - behavior by adjusting the `l2` parameter. - * + + * Added support for the `beta` parameter of the FTRL optimizer for TPU + embeddings. Users of other TensorFlow platforms can implement equivalent + behavior by adjusting the `l2` parameter. + * + * XLA Support: - * xla.experimental.compile is deprecated, use - `tf.function(experimental_compile=True)` instead - * + + * xla.experimental.compile is deprecated, use + `tf.function(experimental_compile=True)` instead + * Added `tf.function.experimental_get_compiler_ir` which returns compiler + IR (currently 'hlo' and 'optimized_hlo') for given input for given + function. + * + * Tracing and Debugging: - * + + * + * `tf.train.Checkpoint`: - * Now accepts a `root` argument in the initialization, which generates a - checkpoint with a root object. This allows users to create a `Checkpoint` - object that is compatible with Keras `model.save_weights()` and - `model.load_weights`. The checkpoint is also compatible with the - checkpoint saved in the `variables/` folder in the SavedModel. - * When restoring, `save_path` can be a path to a SavedModel. The function - will automatically find the checkpoint in the SavedModel. + + * Now accepts a `root` argument in the initialization, which generates a + checkpoint with a root object. This allows users to create a + `Checkpoint` object that is compatible with Keras `model.save_weights()` + and `model.load_weights`. The checkpoint is also compatible with the + checkpoint saved in the `variables/` folder in the SavedModel. + * When restoring, `save_path` can be a path to a SavedModel. The function + will automatically find the checkpoint in the SavedModel. + +* `tf.nn`: + + * `tf.nn.max_pool2d` now supports explicit padding. + +* `tf.debugging`: + + * `tf.debugging.assert_shapes()` now works on `SparseTensor`s (#36268). + +* `tf.print`: + + * Bug fix in `tf.print()` with `OrderedDict` where if an `OrderedDict` + didn't have the keys sorted, the keys and values were not being printed + in accordance with their correct mapping. + +* `TensorRT` + + * We now issue a warning when the `session_config` parameter for the TF1 + converter is used or the `rewrite_config_template` field in the TF2 + converter parameter object is used. + * Other: - * We have replaced uses of "whitelist" and "blacklist" with "allowlist" - and "denylist" where possible. Please see - https://developers.google.com/style/word-list#blacklist for more context. 
- * + + * We have replaced uses of "whitelist" and "blacklist" with "allowlist" + and "denylist" where possible. Please see + https://developers.google.com/style/word-list#blacklist for more + context. + * Add `tf.config.experimental.mlir_bridge_rollout` which will help us + rollout the new MLIR TPU bridge. + * ## Thanks to our Contributors @@ -183,45 +356,327 @@ This release contains contributions from many people at Google, as well as: stjohnso98, , , , , + +# Release 2.3.1 + +## Bug Fixes and Other Changes +* Fixes an undefined behavior causing a segfault in `tf.raw_ops.Switch` + ([CVE-2020-15190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15190)) +* Fixes three vulnerabilities in conversion to DLPack format + ([CVE-2020-15191](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15191), + [CVE-2020-15192](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15192), + [CVE-2020-15193](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15193)) +* Fixes two vulnerabilities in `SparseFillEmptyRowsGrad` + ([CVE-2020-15194](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15194), + [CVE-2020-15195](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15195)) +* Fixes several vulnerabilities in `RaggedCountSparseOutput` and + `SparseCountSparseOutput` operations + ([CVE-2020-15196](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15196), + [CVE-2020-15197](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15197), + [CVE-2020-15198](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15198), + [CVE-2020-15199](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15199), + [CVE-2020-15200](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15200), + [CVE-2020-15201](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15201)) +* Fixes an integer truncation vulnerability in code using the work sharder API + ([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202)) +* Fixes a format string vulnerability in `tf.strings.as_string` + ([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203)) +* Fixes segfault raised by calling session-only ops in eager mode + ([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204)) +* Fixes data leak and potential ASLR violation from `tf.raw_ops.StringNGrams` + ([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205)) +* Fixes segfaults caused by incomplete `SavedModel` validation + ([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206)) +* Fixes a data corruption due to a bug in negative indexing support in TFLite + ([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207)) +* Fixes a data corruption due to dimension mismatch in TFLite + ([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208)) +* Fixes several vulnerabilities in TFLite saved model format + ([CVE-2020-15209](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15209), + [CVE-2020-15210](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15210), + [CVE-2020-15211](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15211)) +* Fixes several vulnerabilities in TFLite implementation of segment sum + ([CVE-2020-15212](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15212), + [CVE-2020-15213](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15213), + [CVE-2020-15214](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15214)) +* Updates `sqlite3` 
to `3.33.00` to handle + [CVE-2020-15358](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15358). +* Fixes deprecated usage of `collections` API +* Removes `scipy` dependency from `setup.py` since TensorFlow does not need it + to install the pip package + + +# Release 2.2.1 + +## Bug Fixes and Other Changes +* Fixes an undefined behavior causing a segfault in `tf.raw_ops.Switch` + ([CVE-2020-15190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15190)) +* Fixes three vulnerabilities in conversion to DLPack format + ([CVE-2020-15191](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15191), + [CVE-2020-15192](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15192), + [CVE-2020-15193](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15193)) +* Fixes two vulnerabilities in `SparseFillEmptyRowsGrad` + ([CVE-2020-15194](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15194), + [CVE-2020-15195](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15195)) +* Fixes an integer truncation vulnerability in code using the work sharder API + ([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202)) +* Fixes a format string vulnerability in `tf.strings.as_string` + ([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203)) +* Fixes segfault raised by calling session-only ops in eager mode + ([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204)) +* Fixes data leak and potential ASLR violation from `tf.raw_ops.StringNGrams` + ([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205)) +* Fixes segfaults caused by incomplete `SavedModel` validation + ([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206)) +* Fixes a data corruption due to a bug in negative indexing support in TFLite + ([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207)) +* Fixes a data corruption due to dimension mismatch in TFLite + ([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208)) +* Fixes several vulnerabilities in TFLite saved model format + ([CVE-2020-15209](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15209), + [CVE-2020-15210](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15210), + [CVE-2020-15211](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15211)) +* Fixes several vulnerabilities in TFLite implementation of segment sum + ([CVE-2020-15212](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15212), + [CVE-2020-15213](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15213), + [CVE-2020-15214](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15214)) +* Updates `sqlite3` to `3.33.00` to handle + [CVE-2020-9327](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-9327), + [CVE-2020-11655](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-11655), + [CVE-2020-11656](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-11656), + [CVE-2020-13434](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13434), + [CVE-2020-13435](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13435), + [CVE-2020-13630](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13630), + [CVE-2020-13631](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13631), + [CVE-2020-13871](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13871), + and + [CVE-2020-15358](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15358). 
+* Fixes deprecated usage of `collections` API +* Removes `scipy` dependency from `setup.py` since TensorFlow does not need it + to install the pip package + + +# Release 2.1.2 + +## Bug Fixes and Other Changes +* Fixes an undefined behavior causing a segfault in `tf.raw_ops.Switch` + ([CVE-2020-15190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15190)) +* Fixes three vulnerabilities in conversion to DLPack format + ([CVE-2020-15191](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15191), + [CVE-2020-15192](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15192), + [CVE-2020-15193](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15193)) +* Fixes two vulnerabilities in `SparseFillEmptyRowsGrad` + ([CVE-2020-15194](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15194), + [CVE-2020-15195](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15195)) +* Fixes an integer truncation vulnerability in code using the work sharder API + ([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202)) +* Fixes a format string vulnerability in `tf.strings.as_string` + ([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203)) +* Fixes segfault raised by calling session-only ops in eager mode + ([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204)) +* Fixes data leak and potential ASLR violation from `tf.raw_ops.StringNGrams` + ([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205)) +* Fixes segfaults caused by incomplete `SavedModel` validation + ([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206)) +* Fixes a data corruption due to a bug in negative indexing support in TFLite + ([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207)) +* Fixes a data corruption due to dimension mismatch in TFLite + ([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208)) +* Fixes several vulnerabilities in TFLite saved model format + ([CVE-2020-15209](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15209), + [CVE-2020-15210](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15210), + [CVE-2020-15211](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15211)) +* Updates `sqlite3` to `3.33.00` to handle + [CVE-2020-9327](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-9327), + [CVE-2020-11655](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-11655), + [CVE-2020-11656](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-11656), + [CVE-2020-13434](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13434), + [CVE-2020-13435](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13435), + [CVE-2020-13630](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13630), + [CVE-2020-13631](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13631), + [CVE-2020-13871](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13871), + and + [CVE-2020-15358](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15358). 
+* Removes `scipy` dependency from `setup.py` since TensorFlow does not need it + to install the pip package +* Switches ROCM builds to use ROCM 3.7 + + +# Release 2.0.3 + +## Bug Fixes and Other Changes +* Fixes an undefined behavior causing a segfault in `tf.raw_ops.Switch` + ([CVE-2020-15190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15190)) +* Fixes three vulnerabilities in conversion to DLPack format + ([CVE-2020-15191](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15191), + [CVE-2020-15192](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15192), + [CVE-2020-15193](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15193)) +* Fixes two vulnerabilities in `SparseFillEmptyRowsGrad` + ([CVE-2020-15194](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15194), + [CVE-2020-15195](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15195)) +* Fixes an integer truncation vulnerability in code using the work sharder API + ([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202)) +* Fixes a format string vulnerability in `tf.strings.as_string` + ([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203)) +* Fixes segfault raised by calling session-only ops in eager mode + ([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204)) +* Fixes data leak and potential ASLR violation from `tf.raw_ops.StringNGrams` + ([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205)) +* Fixes segfaults caused by incomplete `SavedModel` validation + ([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206)) +* Fixes a data corruption due to a bug in negative indexing support in TFLite + ([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207)) +* Fixes a data corruption due to dimension mismatch in TFLite + ([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208)) +* Fixes several vulnerabilities in TFLite saved model format + ([CVE-2020-15209](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15209), + [CVE-2020-15210](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15210), + [CVE-2020-15211](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15211)) +* Updates `sqlite3` to `3.33.00` to handle + [CVE-2020-9327](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-9327), + [CVE-2020-11655](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-11655), + [CVE-2020-11656](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-11656), + [CVE-2020-13434](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13434), + [CVE-2020-13435](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13435), + [CVE-2020-13630](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13630), + [CVE-2020-13631](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13631), + [CVE-2020-13871](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13871), + and + [CVE-2020-15358](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15358). +* Pins `numpy` to 1.18.5 to prevent ABI breakage when compiling code that uses + both NumPy and TensorFlow headers. 
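For projects that compile custom ops against both sets of headers, a minimal sketch of checking the pin before building (the assertion and printouts are illustrative; only the `1.18.5` pin comes from the note above):

```python
# Minimal sketch: verify the NumPy pin before compiling code that includes
# both NumPy and TensorFlow headers (other versions risk ABI mismatches
# with these patch releases).
import numpy as np
import tensorflow as tf

assert np.__version__ == "1.18.5", f"expected NumPy 1.18.5, got {np.__version__}"

# Include paths and link flags a custom-op build would pass to the compiler.
print("TensorFlow headers:", tf.sysconfig.get_include())
print("NumPy headers:", np.get_include())
print("TF link flags:", tf.sysconfig.get_link_flags())
```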
+ + +# Release 1.15.4 + +## Bug Fixes and Other Changes +* Fixes an undefined behavior causing a segfault in `tf.raw_ops.Switch` + ([CVE-2020-15190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15190)) +* Fixes three vulnerabilities in conversion to DLPack format + ([CVE-2020-15191](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15191), + [CVE-2020-15192](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15192), + [CVE-2020-15193](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15193)) +* Fixes two vulnerabilities in `SparseFillEmptyRowsGrad` + ([CVE-2020-15194](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15194), + [CVE-2020-15195](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15195)) +* Fixes an integer truncation vulnerability in code using the work sharder API + ([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202)) +* Fixes a format string vulnerability in `tf.strings.as_string` + ([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203)) +* Fixes segfault raised by calling session-only ops in eager mode + ([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204)) +* Fixes data leak and potential ASLR violation from `tf.raw_ops.StringNGrams` + ([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205)) +* Fixes segfaults caused by incomplete `SavedModel` validation + ([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206)) +* Fixes a data corruption due to a bug in negative indexing support in TFLite + ([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207)) +* Fixes a data corruption due to dimension mismatch in TFLite + ([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208)) +* Fixes several vulnerabilities in TFLite saved model format + ([CVE-2020-15209](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15209), + [CVE-2020-15210](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15210), + [CVE-2020-15211](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15211)) +* Updates `sqlite3` to `3.33.00` to handle + [CVE-2020-9327](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-9327), + [CVE-2020-11655](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-11655), + [CVE-2020-11656](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-11656), + [CVE-2020-13434](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13434), + [CVE-2020-13435](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13435), + [CVE-2020-13630](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13630), + [CVE-2020-13631](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13631), + [CVE-2020-13871](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13871), + and + [CVE-2020-15358](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15358). +* Fixes #41630 by including `max_seq_length` in CuDNN descriptor cache key +* Pins `numpy` to 1.18.5 to prevent ABI breakage when compiling code that uses + both NumPy and TensorFlow headers. + + # Release 2.3.0 ## Major Features and Improvements - * `tf.data` adds two new mechanisms to solve input pipeline bottlenecks and save resources: - * [snapshot](https://www.tensorflow.org/api_docs/python/tf/data/experimental/snapshot) - * [tf.data service](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service). 
- In addition checkout the detailed [guide](https://www.tensorflow.org/guide/data_performance_analysis) for analyzing input pipeline performance with TF Profiler. +* `tf.data` adds two new mechanisms to solve input pipeline bottlenecks and + save resources: - * [`tf.distribute.TPUStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/TPUStrategy) is now a stable API and no longer considered experimental for TensorFlow. (earlier `tf.distribute.experimental.TPUStrategy`). + * [snapshot](https://www.tensorflow.org/api_docs/python/tf/data/experimental/snapshot) + * [tf.data service](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service). - * [TF Profiler](https://www.tensorflow.org/guide/profiler) introduces two new tools: a memory profiler to visualize your model’s memory usage over time and a [python tracer](https://www.tensorflow.org/guide/profiler#events) which allows you to trace python function calls in your model. Usability improvements include better diagnostic messages and [profile options](https://tensorflow.org/guide/profiler#collect_performance_data) to customize the host and device trace verbosity level. + In addition checkout the detailed + [guide](https://www.tensorflow.org/guide/data_performance_analysis) for + analyzing input pipeline performance with TF Profiler. - * Introduces experimental support for Keras Preprocessing Layers API ([`tf.keras.layers.experimental.preprocessing.*`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/experimental/preprocessing?version=nightly)) to handle data preprocessing operations, with support for composite tensor inputs. Please see below for additional details on these layers. +* [`tf.distribute.TPUStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/TPUStrategy) + is now a stable API and no longer considered experimental for TensorFlow. + (earlier `tf.distribute.experimental.TPUStrategy`). - * TFLite now properly supports dynamic shapes during conversion and inference. We’ve also added opt-in support on Android and iOS for [XNNPACK](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/delegates/xnnpack), a highly optimized set of CPU kernels, as well as opt-in support for [executing quantized models on the GPU](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/gpu_advanced.md#running-quantized-models-experimental). +* [TF Profiler](https://www.tensorflow.org/guide/profiler) introduces two new + tools: a memory profiler to visualize your model’s memory usage over time + and a [python tracer](https://www.tensorflow.org/guide/profiler#events) + which allows you to trace python function calls in your model. Usability + improvements include better diagnostic messages and + [profile options](https://tensorflow.org/guide/profiler#collect_performance_data) + to customize the host and device trace verbosity level. - * Libtensorflow packages are available in GCS starting this release. We have also started to [release a nightly version of these packages](https://github.com/tensorflow/tensorflow#official-builds). +* Introduces experimental support for Keras Preprocessing Layers API + ([`tf.keras.layers.experimental.preprocessing.*`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/experimental/preprocessing?version=nightly)) + to handle data preprocessing operations, with support for composite tensor + inputs. Please see below for additional details on these layers. 
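As a small illustration of these preprocessing layers (a minimal sketch, not taken from this patch; the sample data and the `Dense` head are arbitrary):

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers.experimental import preprocessing

# Learn per-feature mean/variance from sample data, then normalize inputs
# as part of the model itself.
data = np.array([[0.1, 0.2], [0.8, 0.9], [1.5, 1.6]], dtype="float32")
normalizer = preprocessing.Normalization()
normalizer.adapt(data)  # computes the statistics from the data

model = tf.keras.Sequential([
    normalizer,                 # preprocessing runs inside the model
    tf.keras.layers.Dense(1),
])
print(model(data))              # first call builds the model and applies normalization
```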
- * The experimental Python API [`tf.debugging.experimental.enable_dump_debug_info()`](https://www.tensorflow.org/api_docs/python/tf/debugging/experimental/enable_dump_debug_info) now allows you to instrument a TensorFlow program and dump debugging information to a directory on the file system. The directory can be read and visualized by a new interactive dashboard in TensorBoard 2.3 called [Debugger V2](https://www.tensorflow.org/tensorboard/debugger_v2), which reveals the details of the TensorFlow program including graph structures, history of op executions at the Python (eager) and intra-graph levels, the runtime dtype, shape, and numerical composistion of tensors, as well as their code locations. +* TFLite now properly supports dynamic shapes during conversion and inference. + We’ve also added opt-in support on Android and iOS for + [XNNPACK](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/delegates/xnnpack), + a highly optimized set of CPU kernels, as well as opt-in support for + [executing quantized models on the GPU](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/gpu_advanced.md#running-quantized-models-experimental). + +* Libtensorflow packages are available in GCS starting this release. We have + also started to + [release a nightly version of these packages](https://github.com/tensorflow/tensorflow#official-builds). + +* The experimental Python API + [`tf.debugging.experimental.enable_dump_debug_info()`](https://www.tensorflow.org/api_docs/python/tf/debugging/experimental/enable_dump_debug_info) + now allows you to instrument a TensorFlow program and dump debugging + information to a directory on the file system. The directory can be read and + visualized by a new interactive dashboard in TensorBoard 2.3 called + [Debugger V2](https://www.tensorflow.org/tensorboard/debugger_v2), which + reveals the details of the TensorFlow program including graph structures, + history of op executions at the Python (eager) and intra-graph levels, the + runtime dtype, shape, and numerical composition of tensors, as well as their + code locations. ## Breaking Changes -* Increases the **minimum bazel version** required to build TF to **3.1.0**. -* `tf.data` - * Makes the following (breaking) changes to the `tf.data`. - * C++ API: - `IteratorBase::RestoreInternal`, `IteratorBase::SaveInternal`, and `DatasetBase::CheckExternalState` become pure-virtual and subclasses are now expected to provide an implementation. - * The deprecated `DatasetBase::IsStateful` method is removed in favor of `DatasetBase::CheckExternalState`. - * Deprecated overrides of `DatasetBase::MakeIterator` and `MakeIteratorFromInputElement` are removed. - * The signature of `tensorflow::data::IteratorBase::SaveInternal` and `tensorflow::data::IteratorBase::SaveInput` has been extended with `SerializationContext` argument to enable overriding the default policy for the handling external state during iterator checkpointing. This is not a backwards compatible change and all subclasses of `IteratorBase` *need to be updated* accordingly. -* `tf.keras` - * Add a new `BackupAndRestore` callback for handling distributed training failures & restarts. Please take a look at this [tutorial](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras) for details on how to use the callback. -* `tf.image.extract_glimpse` has been updated to correctly process the case - where `centered=False` and `normalized=False`. 
This is a breaking change as - the output is different from (incorrect) previous versions. Note this - breaking change only impacts `tf.image.extract_glimpse` and - `tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of - `tf.compat.v1.image.extract_glimpse` does not change. The behavior of - exsiting C++ kernel `ExtractGlimpse` does not change either, so saved - models using `tf.raw_ops.ExtractGlimpse` will not be impacted. + +* Increases the **minimum bazel version** required to build TF to **3.1.0**. +* `tf.data` + * Makes the following (breaking) changes to the `tf.data`. + * C++ API: - `IteratorBase::RestoreInternal`, + `IteratorBase::SaveInternal`, and `DatasetBase::CheckExternalState` + become pure-virtual and subclasses are now expected to provide an + implementation. + * The deprecated `DatasetBase::IsStateful` method is removed in favor of + `DatasetBase::CheckExternalState`. + * Deprecated overrides of `DatasetBase::MakeIterator` and + `MakeIteratorFromInputElement` are removed. + * The signature of `tensorflow::data::IteratorBase::SaveInternal` and + `tensorflow::data::IteratorBase::SaveInput` has been extended with + `SerializationContext` argument to enable overriding the default policy + for the handling external state during iterator checkpointing. This is + not a backwards compatible change and all subclasses of `IteratorBase` + *need to be updated* accordingly. +* `tf.keras` + * Add a new `BackupAndRestore` callback for handling distributed training + failures & restarts. Please take a look at this + [tutorial](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras) + for details on how to use the callback. +* `tf.image.extract_glimpse` has been updated to correctly process the case + where `centered=False` and `normalized=False`. This is a breaking change as + the output is different from (incorrect) previous versions. Note this + breaking change only impacts `tf.image.extract_glimpse` and + `tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of + `tf.compat.v1.image.extract_glimpse` does not change. The behavior of + existing C++ kernel `ExtractGlimpse` does not change either, so saved models + using `tf.raw_ops.ExtractGlimpse` will not be impacted. ## Known Caveats * `tf.lite` @@ -791,7 +1246,7 @@ This release contains contributions from many people at Google, as well as: 8bitmp3, Aaron Ma, AbdüLhamit Yilmaz, Abhai Kollara, aflc, Ag Ramesh, Albert Z. 
Guo, Alex Torres, amoitra, Andrii Prymostka, angeliand, Anshuman Tripathy, Anthony Barbier, Anton Kachatkou, Anubh-V, Anuja Jakhade, Artem Ryabov, autoih, Bairen Yi, Bas Aarts, Basit Ayantunde, Ben Barsdell, Bhavani Subramanian, Brett Koonce, candy.dc, Captain-Pool, caster, cathy, Chong Yan, Choong Yin Thong, Clayne Robison, Colle, Dan Ganea, David Norman, David Refaeli, dengziming, Diego Caballero, Divyanshu, djshen, Douman, Duncan Riach, EFanZh, Elena Zhelezina, Eric Schweitz, Evgenii Zheltonozhskii, Fei Hu, fo40225, Fred Reiss, Frederic Bastien, Fredrik Knutsson, fsx950223, fwcore, George Grzegorz Pawelczak, George Sterpu, Gian Marco Iodice, Giorgio Arena, giuros01, Gomathi Ramamurthy, Guozhong Zhuang, Haifeng Jin, Haoyu Wu, HarikrishnanBalagopal, HJYOO, Huang Chen-Yi, Ilham Firdausi Putra, Imran Salam, Jared Nielsen, Jason Zaman, Jasper Vicenti, Jeff Daily, Jeff Poznanovic, Jens Elofsson, Jerry Shih, jerryyin, Jesper Dramsch, jim.meyer, Jongwon Lee, Jun Wan, Junyuan Xie, Kaixi Hou, kamalkraj, Kan Chen, Karthik Muthuraman, Keiji Ariyama, Kevin Rose, Kevin Wang, Koan-Sin Tan, kstuedem, Kwabena W. Agyeman, Lakshay Tokas, latyas, Leslie-Fang-Intel, Li, Guizi, Luciano Resende, Lukas Folle, Lukas Geiger, Mahmoud Abuzaina, Manuel Freiberger, Mark Ryan, Martin Mlostek, Masaki Kozuki, Matthew Bentham, Matthew Denton, mbhuiyan, mdfaijul, Muhwan Kim, Nagy Mostafa, nammbash, Nathan Luehr, Nathan Wells, Niranjan Hasabnis, Oleksii Volkovskyi, Olivier Moindrot, olramde, Ouyang Jin, OverLordGoldDragon, Pallavi G, Paul Andrey, Paul Wais, pkanwar23, Pooya Davoodi, Prabindh Sundareson, Rajeshwar Reddy T, Ralovich, Kristof, Refraction-Ray, Richard Barnes, richardbrks, Robert Herbig, Romeo Kienzler, Ryan Mccormick, saishruthi, Saket Khandelwal, Sami Kama, Sana Damani, Satoshi Tanaka, Sergey Mironov, Sergii Khomenko, Shahid, Shawn Presser, ShengYang1, Siddhartha Bagaria, Simon Plovyt, skeydan, srinivasan.narayanamoorthy, Stephen Mugisha, sunway513, Takeshi Watanabe, Taylor Jakobson, TengLu, TheMindVirus, ThisIsIsaac, Tim Gates, Timothy Liu, Tomer Gafner, Trent Lo, Trevor Hickey, Trevor Morris, vcarpani, Wei Wang, Wen-Heng (Jack) Chung, wenshuai, Wenshuai-Xiaomi, wenxizhu, william, William D. Irons, Xinan Jiang, Yannic, Yasir Modak, Yasuhiro Matsumoto, Yong Tang, Yongfeng Gu, Youwei Song, Zaccharie Ramzi, Zhang, Zhenyu Guo, 王振华 (Zhenhua Wang), 韩董, 이중건 Isaac Lee # Release 1.15.0 -This is the last 1.x release for TensorFlow. We do not expect to update the 1.x branch with features, although we will issue patch releases to fix vulnerabilities for at least one year. +This is the last 1.x release for TensorFlow. We do not expect to update the 1.x branch with features, although we will issue patch releases to fix vulnerabilities for at least one year. ## Major Features and Improvements * As [announced](https://groups.google.com/a/tensorflow.org/forum/#!topic/developers/iRCt5m4qUz0), `tensorflow` pip package will by default include GPU support (same as `tensorflow-gpu` now) for the platforms we currently have GPU support (Linux and Windows). It will work on machines with and without Nvidia GPUs. `tensorflow-gpu` will still be available, and CPU-only packages can be downloaded at `tensorflow-cpu` for users who are concerned about package size. @@ -801,7 +1256,7 @@ This enables writing forward compatible code: by explicitly importing either `te * Add toggles `tf.enable_control_flow_v2()` and `tf.disable_control_flow_v2()` for enabling/disabling v2 control flow. 
* Enable v2 control flow as part of `tf.enable_v2_behavior()` and `TF2_BEHAVIOR=1`. * AutoGraph translates Python control flow into TensorFlow expressions, allowing users to write regular Python inside `tf.function`-decorated functions. AutoGraph is also applied in functions used with `tf.data`, `tf.distribute` and `tf.keras` APIS. -* Adds `enable_tensor_equality()`, which switches the behavior such that: +* Adds `enable_tensor_equality()`, which switches the behavior such that: * Tensors are no longer hashable. * Tensors can be compared with `==` and `!=`, yielding a Boolean Tensor with element-wise comparison results. This will be the default behavior in 2.0. @@ -957,12 +1412,12 @@ For information on upgrading your existing TensorFlow 1.x models, please refer t * TensorFlow 2.0.0 is built using devtoolset7 (GCC7) on Ubuntu 16. This may lead to ABI incompatibilities with extensions built against earlier versions of TensorFlow. * Tensorflow code now produces 2 different pip packages: tensorflow_core containing all the code (in the future it will contain only the private implementation) and tensorflow which is a virtual pip package doing forwarding to tensorflow_core (and in the future will contain only the public API of tensorflow). We don't expect this to be breaking, unless you were importing directly from the implementation. Removed the `freeze_graph` command line tool; `SavedModel` should be used in place of frozen graphs. - + * `tf.contrib`: * `tf.contrib` has been deprecated, and functionality has been either migrated to the core TensorFlow API, to an ecosystem project such as [tensorflow/addons](https://www.github.com/tensorflow/addons) or [tensorflow/io](https://www.github.com/tensorflow/io), or removed entirely. * Remove `tf.contrib.timeseries` dependency on TF distributions. * Replace contrib references with `tf.estimator.experimental.*` for apis in `early_stopping.py`. - + * `tf.estimator`: * Premade estimators in the tf.estimator.DNN/Linear/DNNLinearCombined family have been updated to use `tf.keras.optimizers` instead of the `tf.compat.v1.train.Optimizer`s. If you do not pass in an `optimizer=` arg or if you use a string, the premade estimator will use the Keras optimizer. This is checkpoint breaking, as the optimizers have separate variables. A checkpoint converter tool for converting optimizers is included with the release, but if you want to avoid any change, switch to the v1 version of the estimator: `tf.compat.v1.estimator.DNN/Linear/DNNLinearCombined*`. * Default aggregation for canned Estimators is now `SUM_OVER_BATCH_SIZE`. To maintain previous default behavior, please pass `SUM` as the loss aggregation method. @@ -970,13 +1425,13 @@ For information on upgrading your existing TensorFlow 1.x models, please refer t * `Estimator.export_savedmodel` has been renamed to `export_saved_model`. * When saving to SavedModel, Estimators will strip default op attributes. This is almost always the correct behavior, as it is more forwards compatible, but if you require that default attributes to be saved with the model, please use `tf.compat.v1.Estimator`. * Feature Columns have been upgraded to be more Eager-friendly and to work with Keras. As a result, `tf.feature_column.input_layer` has been deprecated in favor of `tf.keras.layers.DenseFeatures`. v1 feature columns have direct analogues in v2 except for `shared_embedding_columns`, which are not cross-compatible with v1 and v2. Use `tf.feature_column.shared_embeddings` instead. 
- + * `tf.keras`: * `OMP_NUM_THREADS` is no longer used by the default Keras config. To configure the number of threads, use `tf.config.threading` APIs. * `tf.keras.model.save_model` and `model.save` now defaults to saving a TensorFlow SavedModel. HDF5 files are still supported. * Deprecated `tf.keras.experimental.export_saved_model` and `tf.keras.experimental.function`. Please use `tf.keras.models.save_model(..., save_format='tf')` and `tf.keras.models.load_model` instead. * Layers now default to float32, and automatically cast their inputs to the layer's dtype. If you had a model that used float64, it will probably silently use float32 in TensorFlow 2, and a warning will be issued that starts with `Layer ` is casting an input tensor from dtype float64 to the layer's dtype of float32. To fix, either set the default dtype to float64 with `tf.keras.backend.set_floatx('float64')`, or pass `dtype='float64'` to each of the Layer constructors. See `tf.keras.layers.Layer` for more information. - + * `tf.lite`: * Removed `lite.OpHint`, `lite.experimental`, and `lite.constant` from 2.0 API. * Tensors are no longer hashable, but instead compare element-wise with `==` and `!=`. Use `tf.compat.v1.disable_tensor_equality()` to return to the previous behavior. @@ -1211,8 +1666,8 @@ If you experience any snags when using TF 2.0, please let us know at the [TF 2.0 conversion. TensorRT initialization arguments are now passed wrapped in a named-tuple, `TrtConversionParams`, rather than as separate arguments as in `TrtGraphConverter`. - * Changed API to optimize TensorRT enginges during graph optimization. - This is now done by calling `converter.build()` where previously + * Changed API to optimize TensorRT engines during graph optimization. This + is now done by calling `converter.build()` where previously `is_dynamic_op=False` would be set. * `converter.convert()` no longer returns a `tf.function`. Now the function must be accessed from the saved model. @@ -2222,7 +2677,7 @@ Ag Ramesh, Alex Wiltschko, Alexander Pantyukhin, Amogh Mannekote, An Jiaoyang, A * [`tf.contrib.estimator.RNNEstimator`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/estimator/RNNClassifier) * The [distributions.Bijector](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/distributions/bijectors/Bijector) API supports broadcasting for Bijectors with new API changes. - + ## Breaking Changes * If you're opening empty variable scopes; replace `variable_scope('', ...)` by `variable_scope(tf.get_variable_scope(), ...)`. @@ -2701,7 +3156,7 @@ Samuel He, Sandeep Dcunha, sandipmgiri, Sang Han, scott, Scott Mudge, Se-Won Kim Simone Cirillo, Steffen Schmitz, Suvojit Manna, Sylvus, Taehoon Lee, Ted Chang, Thomas Deegan, Till Hoffmann, Tim, Toni Kunic, Toon Verstraelen, Tristan Rice, Urs KöSter, Utkarsh Upadhyay, Vish (Ishaya) Abrams, Winnie Tsang, Yan Chen, Yan Facai (颜发才), Yi Yang, Yong Tang, -Youssef Hesham, Yuan (Terry) Tang, Zhengsheng Wei, zxcqwe4906, 张志豪, 田传武 +Youssef Hesham, Yuan (Terry) Tang, Zhengsheng Wei, zxcqwe4906, 张志豪, 田传武 We are also grateful to all who filed issues or helped resolve them, asked and answered questions, and were part of inspiring discussions. 
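As a quick illustration of the Debugger V2 workflow mentioned in the 2.3.0 notes above (a minimal sketch; the dump directory and the tiny traced function are arbitrary examples, not part of this change):

```python
import tensorflow as tf

# Instrument the program and dump debug info for TensorBoard's Debugger V2.
tf.debugging.experimental.enable_dump_debug_info(
    "/tmp/tfdbg2_logdir",             # arbitrary example path
    tensor_debug_mode="FULL_HEALTH",
    circular_buffer_size=-1)          # -1 keeps all debug events

@tf.function
def f(x):
  return tf.math.log(x) / (x - 1.0)   # yields inf/nan near x == 0 and x == 1

print(f(tf.constant([0.0, 1.0, 2.0])))
# Inspect the dump afterwards with:
#   tensorboard --logdir /tmp/tfdbg2_logdir
# and open the "Debugger V2" dashboard.
```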
diff --git a/configure.py b/configure.py index 9524eada3cd..e381c8c20db 100644 --- a/configure.py +++ b/configure.py @@ -38,9 +38,6 @@ _DEFAULT_CUDNN_VERSION = '7' _DEFAULT_TENSORRT_VERSION = '6' _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,7.0' -_TF_OPENCL_VERSION = '1.2' -_DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp' -_DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include' _SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15, 16, 17, 18] _DEFAULT_PROMPT_ASK_ATTEMPTS = 10 @@ -1114,62 +1111,6 @@ def set_host_c_compiler(environ_cp): write_action_env_to_bazelrc('HOST_C_COMPILER', host_c_compiler) -def set_computecpp_toolkit_path(environ_cp): - """Set COMPUTECPP_TOOLKIT_PATH.""" - - def toolkit_exists(toolkit_path): - """Check if a computecpp toolkit path is valid.""" - if is_linux(): - sycl_rt_lib_path = 'lib/libComputeCpp.so' - else: - sycl_rt_lib_path = '' - - sycl_rt_lib_path_full = os.path.join(toolkit_path, sycl_rt_lib_path) - exists = os.path.exists(sycl_rt_lib_path_full) - if not exists: - print('Invalid SYCL %s library path. %s cannot be found' % - (_TF_OPENCL_VERSION, sycl_rt_lib_path_full)) - return exists - - computecpp_toolkit_path = prompt_loop_or_load_from_env( - environ_cp, - var_name='COMPUTECPP_TOOLKIT_PATH', - var_default=_DEFAULT_COMPUTECPP_TOOLKIT_PATH, - ask_for_var=( - 'Please specify the location where ComputeCpp for SYCL %s is ' - 'installed.' % _TF_OPENCL_VERSION), - check_success=toolkit_exists, - error_msg='Invalid SYCL compiler path. %s cannot be found.', - suppress_default_error=True) - - write_action_env_to_bazelrc('COMPUTECPP_TOOLKIT_PATH', - computecpp_toolkit_path) - - -def set_trisycl_include_dir(environ_cp): - """Set TRISYCL_INCLUDE_DIR.""" - - ask_trisycl_include_dir = ('Please specify the location of the triSYCL ' - 'include directory. 
(Use --config=sycl_trisycl ' - 'when building with Bazel) ' - '[Default is %s]: ') % ( - _DEFAULT_TRISYCL_INCLUDE_DIR) - - while True: - trisycl_include_dir = get_from_env_or_user_or_default( - environ_cp, 'TRISYCL_INCLUDE_DIR', ask_trisycl_include_dir, - _DEFAULT_TRISYCL_INCLUDE_DIR) - if os.path.exists(trisycl_include_dir): - break - - print('Invalid triSYCL include directory, %s cannot be found' % - (trisycl_include_dir)) - - # Set TRISYCL_INCLUDE_DIR - environ_cp['TRISYCL_INCLUDE_DIR'] = trisycl_include_dir - write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', trisycl_include_dir) - - def system_specific_test_config(environ_cp): """Add default build and test flags required for TF tests to bazelrc.""" write_to_bazelrc('test --flaky_test_attempts=3') @@ -1397,8 +1338,6 @@ def main(): setup_python(environ_cp) if is_windows(): - environ_cp['TF_NEED_OPENCL_SYCL'] = '0' - environ_cp['TF_NEED_COMPUTECPP'] = '0' environ_cp['TF_NEED_OPENCL'] = '0' environ_cp['TF_CUDA_CLANG'] = '0' environ_cp['TF_NEED_TENSORRT'] = '0' @@ -1415,21 +1354,6 @@ def main(): if environ_cp.get('TF_ENABLE_XLA', '1') == '1': write_to_bazelrc('build --config=xla') - set_action_env_var( - environ_cp, - 'TF_NEED_OPENCL_SYCL', - 'OpenCL SYCL', - False, - bazel_config_name='sycl') - if environ_cp.get('TF_NEED_OPENCL_SYCL') == '1': - set_host_cxx_compiler(environ_cp) - set_host_c_compiler(environ_cp) - set_action_env_var(environ_cp, 'TF_NEED_COMPUTECPP', 'ComputeCPP', True) - if environ_cp.get('TF_NEED_COMPUTECPP') == '1': - set_computecpp_toolkit_path(environ_cp) - else: - set_trisycl_include_dir(environ_cp) - set_action_env_var( environ_cp, 'TF_NEED_ROCM', 'ROCm', False, bazel_config_name='rocm') if (environ_cp.get('TF_NEED_ROCM') == '1' and @@ -1442,6 +1366,11 @@ def main(): write_action_env_to_bazelrc('ROCM_PATH', environ_cp.get('ROCM_PATH')) write_action_env_to_bazelrc('ROCM_ROOT', environ_cp.get('ROCM_PATH')) + if ((environ_cp.get('TF_NEED_ROCM') == '1') and + (environ_cp.get('TF_ENABLE_MLIR_GENERATED_GPU_KERNELS') == '1')): + write_to_bazelrc( + 'build:rocm --define tensorflow_enable_mlir_generated_gpu_kernels=1') + environ_cp['TF_NEED_CUDA'] = str( int(get_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False))) if (environ_cp.get('TF_NEED_CUDA') == '1' and @@ -1523,17 +1452,15 @@ def main(): # use it for the CPU build. set_tf_download_clang(environ_cp) - # SYCL / ROCm / CUDA are mutually exclusive. + # ROCm / CUDA are mutually exclusive. # At most 1 GPU platform can be configured. gpu_platform_count = 0 - if environ_cp.get('TF_NEED_OPENCL_SYCL') == '1': - gpu_platform_count += 1 if environ_cp.get('TF_NEED_ROCM') == '1': gpu_platform_count += 1 if environ_cp.get('TF_NEED_CUDA') == '1': gpu_platform_count += 1 if gpu_platform_count >= 2: - raise UserInputError('SYCL / CUDA / ROCm are mututally exclusive. ' + raise UserInputError('CUDA / ROCm are mututally exclusive. ' 'At most 1 GPU platform can be configured.') set_cc_opt_flags(environ_cp) @@ -1558,6 +1485,7 @@ def main(): 'adding "--config=<>" to your build command. 
See .bazelrc for more ' 'details.') config_info_line('mkl', 'Build with MKL support.') + config_info_line('mkl_aarch64', 'Build with oneDNN support for Aarch64.') config_info_line('monolithic', 'Config for mostly static monolithic build.') config_info_line('ngraph', 'Build with Intel nGraph support.') config_info_line('numa', 'Build with NUMA support.') diff --git a/tensorflow/BUILD b/tensorflow/BUILD index d1c1d7dcdef..8946b45cacb 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -497,13 +497,20 @@ config_setting( visibility = ["//visibility:public"], ) -# This flag enables experimental MLIR bridge support. +# This flag forcibly enables experimental MLIR bridge support. config_setting( name = "enable_mlir_bridge", values = {"define": "enable_mlir_bridge=true"}, visibility = ["//visibility:public"], ) +# This flag forcibly disables experimental MLIR bridge support. +config_setting( + name = "disable_mlir_bridge", + values = {"define": "enable_mlir_bridge=false"}, + visibility = ["//visibility:public"], +) + # This flag enables experimental TPU support config_setting( name = "with_tpu_support", @@ -562,33 +569,17 @@ selects.config_setting_group( package_group( name = "internal", packages = [ - "//learning/brain/swift/x10/...", - "//perftools/accelerators/xprof/api/...", + "//learning/lib/ami/simple_ml/...", "//tensorflow/...", - "//tensorflow_estimator/python/estimator/...", - "//tensorflow_models/official/...", - "//third_party/py/autograph/...", - "//third_party/swift/tensorflow/x10/...", - "//third_party/swift/tensorflow_apis/...", ], ) -package_group( - name = "ndarray_tensor_allow_list", - packages = ["//learning/pathways/..."], -) - -# Packages that use composite tensors or dispatch. -# TODO(b/154762408) Remove this package group once it's no longer needed. -# If this is modified, then copy.bara.sky must also be modified. -package_group(name = "composite_tensor_whitelist") +package_group(name = "ndarray_tensor_allow_list") # Packages that use private types symbols, until they are exported. # TODO(b/154650521) Remove. -package_group( - name = "types_whitelist", - packages = ["//learning/deepmind/tensorflow/replicator/..."], -) +# If this is modified, then copy.bara.sky must also be modified. +package_group(name = "types_whitelist") # Packages that use StructuredTensors. # TODO(b/159007891) Remove this package once StructuredTensor is exported. @@ -714,8 +705,12 @@ tf_cc_shared_object( soversion = VERSION, visibility = ["//visibility:public"], deps = [ + "//tensorflow/c/experimental/filesystem:filesystem_interface", + "//tensorflow/c/experimental/stream_executor:stream_executor_hdrs", + "//tensorflow/c:kernels_hdrs", + "//tensorflow/c:ops_hdrs", "//tensorflow/cc/saved_model:loader_lite_impl", - "//tensorflow/core:core_cpu_impl", + "//tensorflow/core/common_runtime:core_cpu_impl", "//tensorflow/core:framework_internal_impl", "//tensorflow/core/common_runtime/gpu:gpu_runtime_impl", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl", diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index 5932dda514d..99a278a14a4 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -138,12 +138,12 @@ if _running_from_pip_package(): for _s in _site_packages_dirs: # Load first party dynamic kernels. _main_dir = _os.path.join(_s, 'tensorflow/core/kernels') - if _fi.file_exists(_main_dir): + if _os.path.exists(_main_dir): _ll.load_library(_main_dir) # Load third party dynamic kernels. 
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins') - if _fi.file_exists(_plugin_dir): + if _os.path.exists(_plugin_dir): _ll.load_library(_plugin_dir) # Add module aliases diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py index 0d1d2e56fae..ae82f7b4792 100644 --- a/tensorflow/api_template_v1.__init__.py +++ b/tensorflow/api_template_v1.__init__.py @@ -148,12 +148,12 @@ if _running_from_pip_package(): for _s in _site_packages_dirs: # Load first party dynamic kernels. _main_dir = _os.path.join(_s, 'tensorflow/core/kernels') - if _fi.file_exists(_main_dir): + if _os.path.exists(_main_dir): _ll.load_library(_main_dir) # Load third party dynamic kernels. _plugin_dir = _os.path.join(_s, 'tensorflow-plugins') - if _fi.file_exists(_plugin_dir): + if _os.path.exists(_plugin_dir): _ll.load_library(_plugin_dir) # Delete modules that should be hidden from dir(). diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 9d8032aca52..1628bf05fd6 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -1,6 +1,7 @@ # Description: # C API for TensorFlow, for use by client language bindings. +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", "tf_cc_test", @@ -9,6 +10,11 @@ load( "tf_custom_op_library", "tf_kernel_library", ) + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "filegroup") + +# buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") package( @@ -211,6 +217,8 @@ tf_cuda_library( "//tensorflow/core:lib_internal", "//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/kernels:logging_ops", + "//tensorflow/compiler/mlir/tfr:node_expansion_pass", + "//tensorflow/compiler/mlir/tfr:graph_decompose_pass", ], }), alwayslink = 1, @@ -248,6 +256,30 @@ tf_cuda_library( }), ) +cc_library( + name = "tf_shape", + srcs = ["tf_shape.cc"], + hdrs = ["tf_shape.h"], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = [ + ":c_api_macros", + ":tf_shape_internal", + "//tensorflow/core:framework", + ], +) + +cc_library( + name = "tf_shape_internal", + hdrs = ["tf_shape_internal.h"], + copts = tf_copts(), + visibility = ["//tensorflow:internal"], + deps = [ + ":conversion_macros", + "//tensorflow/core:framework", + ], +) + cc_library( name = "tf_status", srcs = ["tf_status.cc"], @@ -377,6 +409,7 @@ tf_cuda_library( "//tensorflow/c/eager:tfe_op_internal", "//tensorflow/c/eager:tfe_tensorhandle_internal", "//tensorflow/compiler/jit:flags", + "//tensorflow/compiler/jit:get_compiler_ir", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -387,6 +420,7 @@ tf_cuda_library( "//tensorflow/core/common_runtime/eager:eager_operation", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "//tensorflow/core/platform", + "//tensorflow/core/platform:blocking_counter", "@com_google_absl//absl/strings", ], alwayslink = 1, @@ -477,6 +511,18 @@ tf_cuda_library( ], ) +cc_library( + name = "kernels_hdrs", + hdrs = ["kernels.h"], + visibility = ["//tensorflow:internal"], + deps = [ + ":c_api_internal", + ":tf_datatype", + ":tf_status", + ":tf_tensor", + ], +) + tf_cuda_library( name = "kernels", srcs = [ @@ -530,6 +576,16 @@ tf_cuda_library( alwayslink = 1, ) +cc_library( + name = "ops_hdrs", + hdrs = ["ops.h"], + visibility = ["//tensorflow:internal"], + deps = [ + ":tf_datatype", + ":tf_status", + ], +) + # ----------------------------------------------------------------------------- # Tests diff 
--git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 2e1759ecea0..a03e9227a75 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -2488,6 +2488,48 @@ TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) { return ret; } +void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, + TF_Status* status) { + using tensorflow::RecordMutation; + mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(&new_src.oper->node); + + if (ic->num_outputs() <= new_src.index) { + status->status = tensorflow::errors::OutOfRange( + "Cannot update edge. Output index [", new_src.index, + "] is greater than the number of total outputs [", ic->num_outputs(), + "]."); + return; + } + tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index); + + tensorflow::shape_inference::InferenceContext* ic_dst = + graph->refiner.GetContext(&dst.oper->node); + if (ic_dst->num_inputs() <= dst.index) { + status->status = tensorflow::errors::OutOfRange( + "Cannot update edge. Input index [", dst.index, + "] is greater than the number of total inputs [", ic_dst->num_inputs(), + "]."); + return; + } + if (!ic_dst->MergeInput(dst.index, shape)) { + status->status = tensorflow::errors::InvalidArgument( + "Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape), + " and ", ic_dst->DebugString(ic_dst->input(dst.index)), "."); + return; + } + status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index, + &dst.oper->node, dst.index); + + if (TF_GetCode(status) == TF_OK) { + // This modification only updates the destination node for + // the purposes of running this graph in a session. Thus, we don't + // record the source node as being modified. + RecordMutation(graph, *dst.oper, "updating input tensor"); + } +} + // TF_Server functions ---------------------------------------------- #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 0b4d9993e4d..db5f8fd68f8 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -1524,6 +1524,10 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status); TF_CAPI_EXPORT extern TF_Buffer* TF_GetRegisteredKernelsForOp( const char* name, TF_Status* status); +// Update edge, switch input/ output in a node +TF_CAPI_EXPORT extern void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, + TF_Input dst, TF_Status* status); + // -------------------------------------------------------------------------- // In-process TensorFlow server functionality, for use in distributed training. // A Server instance encapsulates a set of devices and a Session target that diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index b4297033b6d..81fb9d1a2b8 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -35,6 +35,7 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/platform/blocking_counter.h" #include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/net.h" @@ -560,6 +561,21 @@ TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx, collective_executor_handle->get()->StartAbort(status->status); } +TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx, + const char* task, + TF_Status* status) { + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); + auto collective_executor_handle = context->GetCollectiveExecutorHandle(); + tensorflow::Notification done; + collective_executor_handle->get()->remote_access()->CheckPeerHealth( + task, [&done, status](const Status& s) { + status->status = s; + done.Notify(); + }); + done.WaitForNotification(); +} + TF_ShapeAndTypeList* TF_NewShapeAndTypeList(int num_items) { TF_ShapeAndTypeList* result = new TF_ShapeAndTypeList; result->num_items = num_items; diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index ebd14b4b571..c9c74f4e874 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -231,13 +231,20 @@ TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx, TF_Status* status); // Aborts all ongoing collectives with the specified status. After abortion, -// subsequent collectives will error with this status immediately. +// subsequent collectives will error with this status immediately. To reset the +// collectives, create a new EagerContext. // -// This is intended to be used when a peer failure is detected. There's yet no -// way to reset the collectives other than restarting the program. +// This is intended to be used when a peer failure is detected. TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx, TF_Status* status); +// Checks the health of collective ops peers. Explicit health check is needed in +// multi worker collective ops to detect failures in the cluster. If a peer is +// down, collective ops may hang. +TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx, + const char* task, + TF_Status* status); + // Information about the shape of a Tensor and its type. struct TF_ShapeAndType { // Number of dimensions. -1 indicates unknown rank. diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index 3fff9bcd371..ec8cfe4a31a 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -1704,66 +1704,5 @@ TEST_F(CApiFunctionTest, GetFunctionsFromGraph) { TF_DeleteFunction(func1); } -// This test only works when the TF build includes XLA compiler. One way to set -// this up is via bazel build option "--define with_xla_support=true". -// -// FIXME: generalize the macro name TENSORFLOW_EAGER_USE_XLA to -// something like TENSORFLOW_CAPI_USE_XLA. 
-#ifdef TENSORFLOW_EAGER_USE_XLA -TEST_F(CApiFunctionTest, StatelessIf_XLA) { - TF_Function* func; - const std::string funcName = "BranchFunc"; - DefineFunction(funcName.c_str(), &func); - TF_GraphCopyFunction(host_graph_, func, nullptr, s_); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - - TF_Operation* feed = Placeholder(host_graph_, s_); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - - TF_Operation* true_cond = ScalarConst(true, host_graph_, s_); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - - TF_OperationDescription* desc = - TF_NewOperation(host_graph_, "StatelessIf", "IfNode"); - TF_AddInput(desc, {true_cond, 0}); - TF_Output inputs[] = {{feed, 0}}; - TF_AddInputList(desc, inputs, TF_ARRAYSIZE(inputs)); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - TF_SetAttrType(desc, "Tcond", TF_BOOL); - TF_DataType inputType = TF_INT32; - TF_SetAttrTypeList(desc, "Tin", &inputType, 1); - TF_SetAttrTypeList(desc, "Tout", &inputType, 1); - TF_SetAttrFuncName(desc, "then_branch", funcName.data(), funcName.size()); - TF_SetAttrFuncName(desc, "else_branch", funcName.data(), funcName.size()); - TF_SetDevice(desc, "/device:XLA_CPU:0"); - auto op = TF_FinishOperation(desc, s_); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - ASSERT_NE(op, nullptr); - - // Create a session for this graph. - CSession csession(host_graph_, s_, /*use_XLA*/ true); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - - // Run the graph. - csession.SetInputs({{feed, Int32Tensor(17)}}); - csession.SetOutputs({op}); - csession.Run(s_); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - TF_Tensor* out = csession.output_tensor(0); - ASSERT_TRUE(out != nullptr); - EXPECT_EQ(TF_INT32, TF_TensorType(out)); - EXPECT_EQ(0, TF_NumDims(out)); // scalar - ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out)); - int32* output_contents = static_cast(TF_TensorData(out)); - EXPECT_EQ(-17, *output_contents); - - // Clean up - csession.CloseAndDelete(s_); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - - TF_DeleteFunction(func); -} -#endif // TENSORFLOW_EAGER_USE_XLA - } // namespace } // namespace tensorflow diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index bbbbb8f7d56..fc1fdccee16 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -634,6 +634,40 @@ TEST(CAPI, Graph) { TF_DeleteStatus(s); } +TEST(CAPI, UpdateEdge) { + TF_Status* s = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + + // Make two scalar constants. + TF_Operation* one = ScalarConst(1, graph, s, "one"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + TF_Operation* two = ScalarConst(2, graph, s, "two"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Add oper. + TF_Operation* add = Add(one, two, graph, s, "add"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Add another oper to the graph. 
+ TF_Operation* neg = Neg(add, graph, s, "neg"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + NodeDef node_def_neg; + ASSERT_TRUE(GetNodeDef(neg, &node_def_neg)); + EXPECT_EQ(string("add"), node_def_neg.input(0)); + + // update edge of neg + TF_UpdateEdge(graph, TF_Output{one, 0}, TF_Input{neg, 0}, s); + + ASSERT_TRUE(GetNodeDef(neg, &node_def_neg)); + EXPECT_EQ(string("one:0"), node_def_neg.input(0)); + + // Clean up + TF_DeleteGraph(graph); + TF_DeleteStatus(s); +} + /* TODO(skyewm): this test currently DCHECKs, change to bad status diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index cc02d83fe01..08b3c73ed02 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -1,13 +1,23 @@ # Experimental extensions to the C API for eager execution of kernels. +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", + "if_libtpu", "tf_cc_test", "tf_copts", "tf_cuda_cc_test", "tf_cuda_library", - "tfe_xla_copts", ) + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "cc_header_only_library") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "filegroup") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "internal_tfrt_deps") load( "//tensorflow/core/platform:build_config.bzl", "tf_kernel_tests_linkstatic", @@ -31,7 +41,7 @@ tf_cuda_library( "c_api_unified_experimental.h", ], hdrs = ["c_api.h"], - copts = tf_copts() + tfe_xla_copts(), + copts = tf_copts(), visibility = ["//visibility:public"], deps = select({ "//tensorflow:android": [ @@ -72,13 +82,6 @@ tf_cuda_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/profiler/lib:traceme", ], - }) + select({ - "//tensorflow:with_xla_support": [ - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/jit", - "//tensorflow/compiler/jit:xla_device", - ], - "//conditions:default": [], }) + [ "@com_google_absl//absl/memory", "//tensorflow/core/common_runtime/eager:eager_operation", @@ -95,7 +98,7 @@ tf_cuda_library( "//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/distributed_runtime:worker_env", "//tensorflow/core:gpu_runtime", - ], + ] + internal_tfrt_deps(), alwayslink = 1, ) @@ -109,11 +112,16 @@ filegroup( "c_api_experimental.h", "c_api_internal.h", "c_api_unified_experimental.h", + "c_api_unified_experimental_internal.h", "dlpack.h", + "gradients.h", + "gradients_internal.h", "immediate_execution_context.h", "immediate_execution_operation.h", "immediate_execution_tensor_handle.h", + "tape.h", "tfe_cancellation_manager_internal.h", + "tfe_context_internal.h", "tfe_executor_internal.h", "tfe_monitoring_internal.h", "tfe_op_attrs_internal.h", @@ -172,27 +180,20 @@ cc_library( ) cc_library( - name = "gradients", - srcs = [ - "gradients.cc", - "gradients_internal.h", - ], + name = "tracing_utils", + srcs = ["tracing_utils.cc"], hdrs = [ - "gradients.h", + "tracing_utils.h", ], visibility = [ "//tensorflow:internal", ], deps = [ - ":abstract_context", ":abstract_operation", - ":abstract_tensor_handle", ":c_api_unified_internal", - ":tape", - "//tensorflow/core/common_runtime/eager:attr_builder", + "//tensorflow/c/experimental/gradients/tape:tape_operation", "//tensorflow/core/lib/llvm_rtti", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings", + "//tensorflow/core/platform:errors", ], ) @@ -228,10 +229,10 @@ tf_cuda_cc_test( "gradients_test.cc", ], args = ["--heap_check=local"], - extra_copts = 
tfe_xla_copts(), linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags() + ["nomac"], deps = [ + ":abstract_context", ":abstract_tensor_handle", ":c_api_experimental", ":c_api_test_util", @@ -242,7 +243,8 @@ tf_cuda_cc_test( "//tensorflow/c:tf_status_helper", "//tensorflow/c/experimental/gradients:array_grad", "//tensorflow/c/experimental/gradients:math_grad", - "//tensorflow/c/experimental/ops:array_ops", + "//tensorflow/c/experimental/gradients/tape:tape_context", + "//tensorflow/c/experimental/ops", "//tensorflow/cc/profiler", "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", "//tensorflow/core:lib", @@ -256,6 +258,46 @@ tf_cuda_cc_test( ], ) +cc_library( + name = "gradients_util", + srcs = [ + "gradients_util.cc", + ], + hdrs = [ + "gradients_util.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":abstract_context", + ":abstract_operation", + ":abstract_tensor_handle", + ":c_api", + ":c_api_experimental", + ":c_api_unified_internal", + ":gradients_internal", + ":tape", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "//tensorflow/c:c_api", + "//tensorflow/c:tf_status_helper", + "//tensorflow/c/experimental/ops:array_ops", + "//tensorflow/c/experimental/ops:math_ops", + "//tensorflow/c/experimental/ops:nn_ops", + "//tensorflow/cc/profiler", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/lib/llvm_rtti", + ] + if_libtpu( + if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"], + if_true = [], + ), +) + cc_library( name = "mnist_gradients_testutil", srcs = [ @@ -272,17 +314,93 @@ cc_library( ":c_api_experimental", ":c_api_unified_internal", ":gradients_internal", - "//tensorflow/c:tf_status_helper", - "//tensorflow/c:tf_tensor", + ":gradients_util", + ":tape", + "//tensorflow/c/experimental/gradients/tape:tape_context", "//tensorflow/c/experimental/ops:array_ops", "//tensorflow/c/experimental/ops:math_ops", "//tensorflow/c/experimental/ops:nn_ops", "//tensorflow/core/lib/llvm_rtti", + "//tensorflow/core/platform:status", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/types:span", ], ) +cc_library( + name = "gradient_checker", + srcs = [ + "gradient_checker.cc", + ], + hdrs = [ + "gradient_checker.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":abstract_tensor_handle", + ":c_api_experimental", + ":c_api_unified_internal", + ":gradients_internal", + ":gradients_util", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "//tensorflow/c:c_api", + "//tensorflow/c:tf_status_helper", + "//tensorflow/c/experimental/gradients:math_grad", + "//tensorflow/c/experimental/gradients:nn_grad", + "//tensorflow/c/experimental/ops:array_ops", + "//tensorflow/c/experimental/ops:math_ops", + "//tensorflow/c/experimental/ops:nn_ops", + "//tensorflow/cc/profiler", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/lib/llvm_rtti", + ] + if_libtpu( + if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"], + if_true = [], + ), +) + +tf_cuda_cc_test( + name = "gradient_checker_test", + size = "small", + srcs = [ + "gradient_checker_test.cc", + ], + args = ["--heap_check=local"], + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags() + ["nomac"], + deps = [ + ":abstract_tensor_handle", 
+ ":c_api_experimental", + ":c_api_test_util", + ":c_api_unified_internal", + ":gradient_checker", + ":gradients_internal", + ":gradients_util", + ":mnist_gradients_testutil", + "//tensorflow/c:c_api", + "//tensorflow/c:c_test_util", + "//tensorflow/c:tf_status_helper", + "//tensorflow/c/experimental/gradients:math_grad", + "//tensorflow/c/experimental/gradients:nn_grad", + "//tensorflow/c/experimental/ops:array_ops", + "//tensorflow/c/experimental/ops:math_ops", + "//tensorflow/c/experimental/ops:nn_ops", + "//tensorflow/cc/profiler", + "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/lib/llvm_rtti", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + tf_cuda_cc_test( name = "mnist_gradients_test", size = "small", @@ -290,19 +408,16 @@ tf_cuda_cc_test( "mnist_gradients_test.cc", ], args = ["--heap_check=local"], - extra_copts = tfe_xla_copts(), linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags() + [ "nomac", - "notap", # TODO(b/166150182): Enable - "no_oss", # TODO(b/166150182): Enable ], deps = [ ":abstract_tensor_handle", ":c_api_experimental", - ":c_api_test_util", ":c_api_unified_internal", ":gradients_internal", + ":gradients_util", ":mnist_gradients_testutil", "//tensorflow/c:c_api", "//tensorflow/c:c_test_util", @@ -526,6 +641,19 @@ cc_library( ], ) +cc_header_only_library( + name = "tfe_tensorhandle_internal_hdrs_only", + extra_deps = [ + "@com_google_absl//absl/strings", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":tfe_tensorhandle_internal", + ], +) + tf_cuda_library( name = "c_api_test_util", testonly = 1, @@ -539,6 +667,8 @@ tf_cuda_library( ":c_api", ":c_api_experimental", "//tensorflow/c:c_test_util", + "//tensorflow/c:tf_datatype", + "//tensorflow/c:tf_tensor", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -553,7 +683,6 @@ tf_cuda_cc_test( "c_api_debug_test.cc", "c_api_test.cc", ], - extra_copts = tfe_xla_copts(), tags = [ "noguitar", # TODO(b/155445984): flaky #"guitar", @@ -608,7 +737,6 @@ tf_cuda_cc_test( ], # TODO(b/136478427): Figure out how to correctly shut the server down args = ["--heap_check=local"], - extra_copts = tfe_xla_copts(), tags = [ "no_windows", ], @@ -641,7 +769,6 @@ tf_cuda_cc_test( ], # TODO(b/136478427): Figure out how to correctly shut the server down args = ["--heap_check=local"], - extra_copts = tfe_xla_copts(), tags = [ "no_windows", ], @@ -660,7 +787,6 @@ tf_cuda_cc_test( ], # TODO(b/136478427): Figure out how to correctly shut the server down args = ["--heap_check=local"], - extra_copts = tfe_xla_copts(), tags = [ "no_windows", "noasan", # leaks gRPC server instances @@ -694,7 +820,6 @@ tf_cuda_cc_test( ], # TODO(b/136478427): Figure out how to correctly shut the server down args = ["--heap_check=local"], - extra_copts = tfe_xla_copts(), tags = [ "no_windows", ], @@ -729,7 +854,7 @@ tf_cuda_library( "c_api_experimental.h", "c_api_unified_experimental.h", ], - copts = tf_copts() + tfe_xla_copts(), + copts = tf_copts(), visibility = ["//visibility:public"], deps = select({ "//tensorflow:android": [ @@ -801,7 +926,6 @@ tf_cuda_cc_test( "c_api_experimental_test.cc", ], args = ["--heap_check=local"], - extra_copts = tfe_xla_copts(), linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags() + ["nomac"], deps = [ @@ -814,6 +938,7 @@ tf_cuda_cc_test( 
"//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/platform:status", "@com_google_absl//absl/strings", ], ) @@ -825,7 +950,6 @@ tf_cuda_cc_test( "c_api_unified_experimental_test.cc", ], args = ["--heap_check=local"], - extra_copts = tfe_xla_copts(), linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags() + ["nomac"], deps = [ @@ -834,6 +958,7 @@ tf_cuda_cc_test( ":c_api_test_util", "//tensorflow/c:c_api", "//tensorflow/c:c_test_util", + "//tensorflow/c:tf_status_helper", "//tensorflow/cc/profiler", "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", "//tensorflow/core:lib", @@ -943,7 +1068,13 @@ filegroup( "c_api_unified_experimental_eager.cc", "c_api_unified_experimental_graph.cc", "c_api_unified_experimental_internal.h", + "gradient_checker.cc", + "gradient_checker.h", "gradients.cc", # Uses RTTI. + "gradients_util.cc", + "gradients_util.h", + "tracing_utils.h", + "tracing_utils.cc", "*test*", "*dlpack*", ], diff --git a/tensorflow/c/eager/abstract_context.h b/tensorflow/c/eager/abstract_context.h index b488255d150..d31b1e13611 100644 --- a/tensorflow/c/eager/abstract_context.h +++ b/tensorflow/c/eager/abstract_context.h @@ -32,7 +32,7 @@ namespace tensorflow { // environment, a traced representation etc. class AbstractContext { protected: - enum AbstractContextKind { kGraph, kMlir, kEager, kTfrt }; + enum AbstractContextKind { kGraph, kMlir, kEager, kTfrt, kTape }; explicit AbstractContext(AbstractContextKind kind) : kind_(kind) {} virtual ~AbstractContext() {} diff --git a/tensorflow/c/eager/abstract_operation.h b/tensorflow/c/eager/abstract_operation.h index b332679cc7c..4c630528f5d 100644 --- a/tensorflow/c/eager/abstract_operation.h +++ b/tensorflow/c/eager/abstract_operation.h @@ -30,7 +30,7 @@ namespace tensorflow { // tracing or immediate execution mode. class AbstractOperation { protected: - enum AbstractOperationKind { kGraph, kMlir, kEager, kTfrt }; + enum AbstractOperationKind { kGraph, kMlir, kEager, kTfrt, kTape }; explicit AbstractOperation(AbstractOperationKind kind) : kind_(kind) {} virtual ~AbstractOperation() {} diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index fefa753c608..5f388bfe0cd 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -39,7 +39,7 @@ limitations under the License. #include "tensorflow/c/eager/tfe_op_internal.h" #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/c/tf_tensor_internal.h" -#ifdef PLATFORM_GOOGLE +#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE) #include "tensorflow/core/tfrt/eager/c_api_tfrt.h" #endif #include "tensorflow/core/common_runtime/device.h" @@ -51,9 +51,6 @@ limitations under the License. #include "tensorflow/core/protobuf/device_filters.pb.h" #include "tensorflow/core/protobuf/error_codes.pb.h" #include "tensorflow/core/util/device_name_utils.h" -#ifdef TENSORFLOW_EAGER_USE_XLA -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#endif // TENSORFLOW_EAGER_USE_XLA #include "tensorflow/core/common_runtime/copy_tensor.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" @@ -629,21 +626,30 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( "targets will fail."; } } else { - // The master's context_view_id will be incremented by one - // the UpdateRemoteMaster call later. 
We want all new workers and - existing workers to also have the updated context_view_id, so - we must set their context_view_id to the existing master's - context_view_id + 1. - sg.Update(CreateRemoteContexts( - ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs, - server_def, remote_eager_workers.get(), context->Executor().Async(), - context->LazyCopyFunctionRemoteInputs(), base_request)); + if (sg.ok()) { + // Create remote contexts on the newly added workers only if the master + // has collected all device information from them (i.e., the + // GetAllRemoteDevices call returns successfully). Note that in rare cases + // GetAllRemoteDevices can still fail even with RPCs configured to wait + // until the remote workers become alive. If the master creates remote + // contexts on the workers whose devices are still not collected, those + // workers will be treated as existing workers subsequently, so the master + // will never get devices from them even with retrying UpdateServerDef. + sg.Update(CreateRemoteContexts( + ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs, + server_def, remote_eager_workers.get(), context->Executor().Async(), + context->LazyCopyFunctionRemoteInputs(), base_request)); + } if (!existing_workers.empty()) { if (VLOG_IS_ON(1)) { for (const string& w : existing_workers) { VLOG(1) << "Updating cluster with existing worker " << w; } } + // The master's context_view_id will be incremented by one in the + // UpdateRemoteMaster call later. We want existing workers to also have + // the updated context_view_id, so we must set their context_view_id to + // the master's current context_view_id + 1. sg.Update(UpdateRemoteContexts(ctx, existing_workers, added_workers, removed_workers, context_id, context_view_id + 1, server_def, @@ -723,7 +729,7 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { if (opts->use_tfrt) { -#ifdef PLATFORM_GOOGLE +#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE) return tensorflow::wrap(new tfrt::tf::ContextInterface(opts->async)); #else status->status = tensorflow::errors::Unimplemented("TFRT is not supported"); @@ -745,10 +751,8 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { opts->session_options.options, static_cast( opts->device_placement_policy), - static_cast(opts->mirroring_policy), opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(), - /*device_mgr_owned*/ true, r, - tensorflow::GetDefaultCustomKernelCreator())); + /*device_mgr_owned*/ true, r)); } void TFE_DeleteContext(TFE_Context* ctx) { @@ -851,20 +855,9 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx, #else // !defined(IS_MOBILE_PLATFORM) tensorflow::EagerContext* context = tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - tensorflow::GrpcServer* grpc_server = - static_cast(context->GetServer()); - - std::unique_ptr remote_eager_workers; - status->status = grpc_server->master_env()->worker_cache->GetEagerClientCache( - &remote_eager_workers); - if (!status->status.ok()) { - LOG(ERROR) << "Failed to get client cache for remote workers."; - return false; - } - // TODO(yuefengz): support partially specified `worker_name`.
tensorflow::core::RefCountPtr eager_client; - status->status = remote_eager_workers->GetClient(worker_name, &eager_client); + status->status = context->GetClient(worker_name, &eager_client); if (!status->status.ok()) { return false; } @@ -911,9 +904,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx, void TFE_ContextSetThreadLocalDevicePlacementPolicy( TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - context->SetThreadLocalDevicePlacementPolicy( + tensorflow::unwrap(ctx)->SetThreadLocalDevicePlacementPolicy( static_cast(policy)); } @@ -922,10 +913,8 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy( // safe to call this function from the async EagerExecutor threads. extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy( TFE_Context* ctx) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); return static_cast( - context->GetDevicePlacementPolicy()); + tensorflow::unwrap(ctx)->GetDevicePlacementPolicy()); } TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) { @@ -1149,26 +1138,23 @@ void TFE_DeleteOp(TFE_Op* op) { tensorflow::unwrap(op)->Release(); } +const char* TFE_OpGetName(const TFE_Op* op, TF_Status* status) { + return tensorflow::unwrap(op)->Name().c_str(); +} + +TFE_Context* TFE_OpGetContext(const TFE_Op* op, TF_Status* status) { + return tensorflow::wrap( + &(OperationFromInterface(tensorflow::unwrap(op))->EagerContext())); +} + void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) { status->status = tensorflow::unwrap(op)->SetDeviceName(device_name); } -const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) { +const char* TFE_OpGetDevice(const TFE_Op* op, TF_Status* status) { return tensorflow::unwrap(op)->DeviceName().c_str(); } -void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) { -#ifdef TENSORFLOW_EAGER_USE_XLA - tensorflow::Status s = tensorflow::unwrap(op)->SetUseXla(enable); - if (!s.ok()) { - LOG(ERROR) << "Could not enable XLA compilation for op: " << s; - } -#else - LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not " - "built with XLA support."; -#endif // TENSORFLOW_EAGER_USE_XLA -} - void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) { status->status = tensorflow::unwrap(op)->AddInput(tensorflow::unwrap(input)); } @@ -1181,6 +1167,15 @@ void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs, static_cast(num_inputs)}); } +extern int TFE_OpGetFlatInputCount(const TFE_Op* op, TF_Status* status) { + return tensorflow::unwrap(op)->GetInputs().size(); +} + +extern TFE_TensorHandle* TFE_OpGetFlatInput(const TFE_Op* op, int index, + TF_Status* status) { + return tensorflow::wrap(tensorflow::unwrap(op)->GetInputs()[index]); +} + TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, unsigned char* is_list, TF_Status* status) { TF_AttrType ret = TF_ATTR_INT; @@ -1430,21 +1425,15 @@ void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name, } unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - return context->FindFunctionDef(name) != nullptr; + return tensorflow::unwrap(ctx)->FindFunctionDef(name) != nullptr; } void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { - tensorflow::EagerContext* context = - 
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - context->SetShouldStoreGraphs(true); + tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true); } void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - context->SetShouldStoreGraphs(false); + tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false); } } // extern "C" @@ -1486,7 +1475,7 @@ void TFE_ContextEndStep(TFE_Context* ctx) { tensorflow::unwrap(ctx)->EndStep(); } -const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op) { +const TFE_OpAttrs* TFE_OpGetAttrs(const TFE_Op* op) { return tensorflow::wrap( &OperationFromInterface(tensorflow::unwrap(op))->Attrs()); } @@ -1551,8 +1540,67 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, TFE_OpSetAttrFunction(op, attr_name, func_op); TFE_DeleteOp(func_op); } break; - case tensorflow::AttrValue::kList: - TF_FALLTHROUGH_INTENDED; + case tensorflow::AttrValue::kList: { + // String + if (const int s_size = default_value.list().s_size()) { + absl::InlinedVector values_vector; + absl::InlinedVector lengths_vector; + for (int i = 0; i < s_size; ++i) { + const string& v = default_value.list().s(i); + values_vector.push_back(v.data()); + lengths_vector.push_back(v.size()); + } + TFE_OpSetAttrStringList(op, attr_name, values_vector.data(), + lengths_vector.data(), s_size); + } + + // Int + if (const int i_size = default_value.list().i_size()) { + absl::InlinedVector i_vector; + for (int i = 0; i < i_size; ++i) { + i_vector.push_back(default_value.list().i(i)); + } + TFE_OpSetAttrIntList(op, attr_name, i_vector.data(), i_size); + } + // Float + if (const int f_size = default_value.list().f_size()) { + absl::InlinedVector f_vector; + for (int i = 0; i < f_size; ++i) { + f_vector.push_back(default_value.list().f(i)); + } + TFE_OpSetAttrFloatList(op, attr_name, f_vector.data(), f_size); + } + // Bool + if (const int b_size = default_value.list().b_size()) { + absl::InlinedVector b_vector; + for (int i = 0; i < b_size; i++) { + b_vector.push_back(default_value.list().b(i)); + } + TFE_OpSetAttrBoolList(op, attr_name, b_vector.data(), b_size); + } + // Type + if (const int type_size = default_value.list().type_size()) { + absl::InlinedVector type_vector; + for (int i = 0; i < type_size; ++i) { + type_vector.push_back(default_value.list().type(i)); + } + TFE_OpSetAttrTypeList( + op, attr_name, + reinterpret_cast(type_vector.data()), + type_size); + } + + // Rest are not supported. 
+ if (default_value.list().shape_size() > 0 || + default_value.list().func_size() > 0 || + default_value.list().tensor_size() > 0) { + TF_SetStatus( + status, TF_UNIMPLEMENTED, + tensorflow::strings::StrCat("Unable to set attribute from default value: ", + default_value.DebugString()) + .data()); + } + } break; case tensorflow::AttrValue::kTensor: TF_FALLTHROUGH_INTENDED; case tensorflow::AttrValue::kPlaceholder: @@ -1612,19 +1660,12 @@ class CustomDeviceAPI : public tensorflow::CustomDevice { return status.status; } - tensorflow::Status Execute(tensorflow::EagerOperation* op, + tensorflow::Status Execute(const tensorflow::EagerOperation* op, tensorflow::TensorHandle** retvals, int* num_retvals) override { - std::vector inputs; - inputs.reserve(op->Inputs().size()); - for (int i = 0; i < op->Inputs().size(); ++i) { - op->Inputs()[i]->Ref(); - inputs.push_back(tensorflow::wrap(op->Inputs()[i])); - } std::vector outputs(*num_retvals); TF_Status status; - device_.execute(context_, inputs.size(), inputs.data(), op->Name().c_str(), - wrap(&op->Attrs()), num_retvals, outputs.data(), &status, + device_.execute(tensorflow::wrap(op), num_retvals, outputs.data(), &status, info_); if (status.status.ok()) { for (int i = 0; i < *num_retvals; ++i) { @@ -1634,10 +1675,6 @@ class CustomDeviceAPI : public tensorflow::CustomDevice { TFE_DeleteTensorHandle(outputs[i]); } } - - for (auto inp : inputs) { - TFE_DeleteTensorHandle(inp); - } return status.status; } diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 5afe3047dd7..0afb69bb82c 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -74,7 +74,7 @@ typedef enum TFE_ContextDevicePlacementPolicy { // Placement policy which silently copies int32 tensors but not other dtypes. TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3, } TFE_ContextDevicePlacementPolicy; -// LINT.ThenChange(//tensorflow/core/common_runtime/eager/context.h) +// LINT.ThenChange(//tensorflow/c/eager/immediate_execution_context.h) // Sets the default execution mode (sync/async). Note that this can be // overridden per thread using TFE_ContextSetExecutorForThread. @@ -248,22 +248,22 @@ typedef struct TFE_Op TFE_Op; TF_CAPI_EXPORT extern TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, TF_Status* status); - TF_CAPI_EXPORT extern void TFE_DeleteOp(TFE_Op* op); +// Returns the op or function name `op` will execute. +// +// The returned string remains valid throughout the lifetime of 'op'. +TF_CAPI_EXPORT extern const char* TFE_OpGetName(const TFE_Op* op, + TF_Status* status); +TF_CAPI_EXPORT extern TFE_Context* TFE_OpGetContext(const TFE_Op* op, + TF_Status* status); + TF_CAPI_EXPORT extern void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status); // The returned string remains valid throughout the lifetime of 'op'. -TF_CAPI_EXPORT extern const char* TFE_OpGetDevice(TFE_Op* op, +TF_CAPI_EXPORT extern const char* TFE_OpGetDevice(const TFE_Op* op, TF_Status* status); -// When 'enable' is set to 1, and if TensorFlow library is built with XLA -// support, a subsequent TFE_Execute() call on `op` will run the op via XLA. -// -// If the library is not built with XLA support, this call would be a no-op.
-TF_CAPI_EXPORT extern void TFE_OpSetXLACompilation(TFE_Op* op, - unsigned char enable); - TF_CAPI_EXPORT extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status); @@ -272,6 +272,23 @@ TF_CAPI_EXPORT extern void TFE_OpAddInputList(TFE_Op* op, int num_inputs, TF_Status* status); +// Fetches the current number of inputs attached to `op`. +// +// Does not use the operation's definition to determine how many inputs should +// be attached. It is intended for use with TFE_OpGetFlatInput to inspect an +// already-finalized operation. +// +// Note that TFE_OpGetFlatInputCount and TFE_OpGetFlatInput operate on a flat +// sequence of inputs, unlike TFE_OpGetInputLength (for getting the length of a +// particular named input list, which may only be part of the op's inputs). +TF_CAPI_EXPORT extern int TFE_OpGetFlatInputCount(const TFE_Op* op, + TF_Status* status); +// Returns a borrowed reference to one of `op`'s inputs. Use +// `TFE_TensorHandleCopySharingTensor` to make a new reference. +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_OpGetFlatInput(const TFE_Op* op, + int index, + TF_Status* status); + TF_CAPI_EXPORT extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, unsigned char* is_list, diff --git a/tensorflow/c/eager/c_api_debug.cc b/tensorflow/c/eager/c_api_debug.cc index dd55f05283b..b5721cdab0a 100644 --- a/tensorflow/c/eager/c_api_debug.cc +++ b/tensorflow/c/eager/c_api_debug.cc @@ -22,9 +22,6 @@ limitations under the License. #include "tensorflow/c/tf_status_internal.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/platform/status.h" -#ifdef TENSORFLOW_EAGER_USE_XLA -#include "tensorflow/compiler/jit/xla_device.h" -#endif // TENSORFLOW_EAGER_USE_XLA using tensorflow::string; @@ -64,87 +61,6 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( return nullptr; } -#ifdef TENSORFLOW_EAGER_USE_XLA - auto* device = absl::get(handle->device()); - - // If tensor resides on an XLA device, use XLA device's PaddedShapeFn. - auto* xla_device = dynamic_cast(device); - if (xla_device != nullptr) { - tensorflow::XlaDevice::PaddedShapeFn shape_fn = - xla_device->metadata().padded_shape_fn(); - xla::Shape padded_shape; - status->status = shape_fn(*tensor, &padded_shape); - if (!status->status.ok()) { - return nullptr; - } - if (VLOG_IS_ON(3)) { - std::vector shape_to_log = - TensorShapeAsVector(*handle, &status->status); - if (!status->status.ok()) { - // Ignore the status here as we are simply logging. - status->status = tensorflow::Status::OK(); - } else { - VLOG(3) << "Fully padded shape of [" - << absl::StrJoin(shape_to_log, ", ") << "] is " - << padded_shape.DebugString(); - } - } - - if (padded_shape.IsTuple()) { - if (xla::ShapeUtil::TupleElementCount(padded_shape) != 2) { - // Currently, the only case of XlaTensor containing a tuple shape is to - // represent 64 bit ints, doubles, and complex numbers (we don't support - // 64bit complex numbers). - status->status = tensorflow::errors::InvalidArgument( - "XlaTensors should only contain tuples of size 2. Shape: ", - padded_shape.DebugString()); - return nullptr; - } - - // shape0 is not a const& because we will assign it to padded_shape below. - // It is illegal to assign a part of a message to itself. 
- xla::Shape shape0 = xla::ShapeUtil::GetTupleElementShape(padded_shape, 0); - const xla::Shape& shape1 = - xla::ShapeUtil::GetTupleElementShape(padded_shape, 1); - if (shape0.IsTuple() || shape1.IsTuple()) { - status->status = tensorflow::errors::InvalidArgument( - "XlaTensors should not contain nested tuples. Shape: ", - padded_shape.DebugString()); - return nullptr; - } - if (!xla::ShapeUtil::Equal(shape0, shape1)) { - status->status = tensorflow::errors::InvalidArgument( - "Subshapes of XlaTensors should be the same. Shape: ", - padded_shape.DebugString()); - return nullptr; - } - - // Since the only case we handle here are two equal subshapes, we - // simply return one of them. The caller will interpret it as this - // shape directly storing the 64bit types. This approximation is good - // enough for this API's debugging use case. - padded_shape = shape0; - } - - int rank = padded_shape.dimensions_size(); - std::vector dev_dims; - dev_dims.reserve(rank); - if (rank == 1) { - // Rank 1 tensors might not have padded_shape.layout.minor_to_major set, - dev_dims.push_back(padded_shape.dimensions(0)); - } else { - for (int i = rank - 1; i >= 0; --i) { - tensorflow::int64 dim_index = padded_shape.layout().minor_to_major(i); - dev_dims.push_back(padded_shape.dimensions(dim_index)); - } - } - status->status = tensorflow::Status::OK(); - return new TFE_TensorDebugInfo(dev_dims); - } -#endif // TENSORFLOW_EAGER_USE_XLA - - // If the tensor is not an XLA tensor, the device shape is - // the same as regular tensor shape. std::vector dev_dims = TensorShapeAsVector(*handle, &status->status); if (!status->status.ok()) { diff --git a/tensorflow/c/eager/c_api_distributed_test.cc b/tensorflow/c/eager/c_api_distributed_test.cc index cf35c2d634d..d21cadfd0cb 100644 --- a/tensorflow/c/eager/c_api_distributed_test.cc +++ b/tensorflow/c/eager/c_api_distributed_test.cc @@ -121,25 +121,6 @@ string AddVariablesFunction() { return def.SerializeAsString(); } -void VarIsInitialized(TFE_Context* ctx, TFE_TensorHandle* var_handle) { - TF_Status* status = TF_NewStatus(); - TFE_Op* op = TFE_NewOp(ctx, "VarIsInitializedOp", status); - EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); - TFE_OpAddInput(op, var_handle, status); - TFE_TensorHandle* is_initialized[1] = {nullptr}; - int num_retvals = 1; - TFE_Execute(op, &is_initialized[0], &num_retvals, status); - CHECK_EQ(1, num_retvals); - TF_Tensor* t = TFE_TensorHandleResolve(is_initialized[0], status); - bool initialized = false; - memcpy(&initialized, TF_TensorData(t), TF_TensorByteSize(t)); - EXPECT_EQ(initialized, true); - TF_DeleteTensor(t); - TFE_DeleteTensorHandle(is_initialized[0]); - TFE_DeleteOp(op); - delete status; -} - void TestFunctionWithPackedInput(const bool remote) { tensorflow::ServerDef server_def = GetServerDef(3); @@ -182,9 +163,8 @@ void TestFunctionWithPackedInput(const bool remote) { // Add a sync point in order to make sure that variables have been initialized // before the function execution starts. - // TODO(b/155789951): Remove once b/155789951 is fixed. - VarIsInitialized(ctx, h1); - VarIsInitialized(ctx, h2); + TFE_ContextAsyncWait(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); // Pack 3 variable handles into one TFE_TensorHandle. // When remote is false, function device is placed on task0. 
Handle types are @@ -396,6 +376,8 @@ TEST(CAPI, DistributedFunctionGraphPassOnlyOnce) { TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name); EXPECT_NE(var_handle, nullptr); + TFE_ContextAsyncWait(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); const string function_def = VariableAddFunction(); TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(), @@ -517,6 +499,8 @@ void TestDistributedFunctionCancellation(bool inject_error) { TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name); EXPECT_NE(var_handle, nullptr); + TFE_ContextAsyncWait(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); const string function_def = inject_error ? VariableAddFunctionWithGraphError() : VariableAddFunction(); @@ -561,7 +545,9 @@ TEST(CAPI, DistributedFunctionNoError) { TestDistributedFunctionCancellation(false); } -TEST(CAPI, DistributedFunctionCancelledOnError) { +// TODO(b/170399182): Update test once an alternative to using the function +// optimization hook is in place. +TEST(CAPI, DISABLED_DistributedFunctionCancelledOnError) { TestDistributedFunctionCancellation(true); } diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index 7390cf243be..1ef536a66f6 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -49,15 +49,11 @@ void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name, } void TFE_ContextEnableGraphCollection(TFE_Context* ctx) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - context->SetShouldStoreGraphs(true); + tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true); } void TFE_ContextDisableGraphCollection(TFE_Context* ctx) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - context->SetShouldStoreGraphs(false); + tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false); } uint64_t TFE_GetContextId(TFE_Context* ctx) { @@ -486,29 +482,6 @@ TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2( static_cast(sampler->sampler->GetCell(label1, label2))); } -void TFE_ContextOptionsSetMirroringPolicy(TFE_ContextOptions* options, - TFE_ContextMirroringPolicy policy) { - options->mirroring_policy = policy; -} - -void TFE_ContextSetThreadLocalMirroringPolicy( - TFE_Context* ctx, TFE_ContextMirroringPolicy policy) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - context->SetThreadLocalMirroringPolicy( - static_cast(policy)); -} - -// Note: this function looks up a thread local policy. So it should be called in -// the appropriate client thread. In particular, in async mode, it may not be -// safe to call this function from the async EagerExecutor threads. 
-extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy( - TFE_Context* ctx) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - return static_cast(context->GetMirroringPolicy()); -} - void TFE_ContextOptionsSetLazyRemoteInputsCopy(TFE_ContextOptions* options, bool lazy_copy) { options->lazy_remote_inputs_copy = lazy_copy; @@ -567,22 +540,16 @@ void TFE_ExecutorClearError(TFE_Executor* executor) { } void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - context->SetExecutorForThread(executor->executor()); + tensorflow::unwrap(ctx)->SetExecutorForThread(executor->executor()); } TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - return new TFE_Executor(&context->Executor()); + return new TFE_Executor(&tensorflow::unwrap(ctx)->Executor()); } void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); auto address_space = tensorflow::DeviceNameUtils::AddressSpace( - context->HostCPU()->parsed_name()); + tensorflow::unwrap(ctx)->HostCPUParsedName()); auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space); void* data = tensorflow::port::Malloc(str.length()); str.copy(static_cast(data), str.length(), 0); @@ -595,9 +562,7 @@ void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) { void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name, TF_Buffer* buf, TF_Status* status) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - auto* function_def = context->FindFunctionDef(function_name); + auto* function_def = tensorflow::unwrap(ctx)->FindFunctionDef(function_name); if (function_def == nullptr) { status->status = tensorflow::errors::NotFound( "Unable to find FunctionDef with name: ", function_name); @@ -666,14 +631,26 @@ TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx, void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable, TF_Status* status) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - context->SetAllowSoftPlacement(enable); + tensorflow::unwrap(ctx)->SetAllowSoftPlacement(enable); } void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable, TF_Status* status) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - context->SetLogDevicePlacement(enable); + tensorflow::unwrap(ctx)->SetLogDevicePlacement(enable); +} + +const char* TFE_TensorHandleDeviceType(TFE_TensorHandle* h, TF_Status* status) { + if (h == nullptr) { + status->status = tensorflow::errors::InvalidArgument("Invalid handle"); + return nullptr; + } + return tensorflow::unwrap(h)->DeviceType(&status->status); +} + +int TFE_TensorHandleDeviceID(TFE_TensorHandle* h, TF_Status* status) { + if (h == nullptr) { + status->status = tensorflow::errors::InvalidArgument("Invalid handle"); + return -1; + } + return tensorflow::unwrap(h)->DeviceId(&status->status); } diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index 1af76c01154..d0739a5437d 100644 --- a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -265,33 
+265,6 @@ TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler2( TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2( TFE_MonitoringSampler2* sampler, const char* label1, const char* label2); -// LINT.IfChange -// Note: Keep in sync with internal copy of enum in eager/context.h. -typedef enum TFE_ContextMirroringPolicy { - // Do not maintain mirrors in a TensorHandle, instead make new TensorHandle - // copies with their own lifetime. - TFE_MIRRORING_NONE = 0, - // Mirroring any remote tensor handles, associating them with the lifetime of - // the local TensorHandle. - TFE_MIRRORING_ALL = 1, -} TFE_ContextMirroringPolicy; -// LINT.ThenChange(//tensorflow/core/common_runtime/eager/context.h) - -TF_CAPI_EXPORT extern void TFE_ContextOptionsSetMirroringPolicy( - TFE_ContextOptions*, TFE_ContextMirroringPolicy); - -// Sets a thread-local mirroring policy. After this call, other calls to -// TFE_Execute in the same thread will use the mirroring policy specified here -// instead of the mirroring policy used to construct the context. This has no -// effect on the mirroring policy used by other program threads. -TF_CAPI_EXPORT extern void TFE_ContextSetThreadLocalMirroringPolicy( - TFE_Context*, TFE_ContextMirroringPolicy); - -// Returns the mirroring policy to be used by this context in the current -// thread. -TF_CAPI_EXPORT extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy( - TFE_Context*); - // Sets whether to copy the remote inputs of a function lazily. TF_CAPI_EXPORT extern void TFE_ContextOptionsSetLazyRemoteInputsCopy( TFE_ContextOptions*, bool lazy_copy); @@ -441,7 +414,7 @@ typedef struct TFE_OpAttrs TFE_OpAttrs; // Fetch a reference to `op`'s attributes. The returned reference is only valid // while `op` is alive. -const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op); +TF_CAPI_EXPORT extern const TFE_OpAttrs* TFE_OpGetAttrs(const TFE_Op* op); // Add attributes in `attrs` to `op`. // // Does not overwrite or update existing attributes, but adds new ones. @@ -462,7 +435,11 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrValueProto(const TFE_Op* op, size_t proto_len, TF_Status* status); -#define TFE_CUSTOM_DEVICE_VERSION 2 +// TODO(b/166642410): It would be nice, for custom devices and for other users, +// to have a non-string representation of devices (TF_Device) extracted from +// tensors/ops/etc. and usable in APIs like OpSetDevice/ResetOp/etc. + +#define TFE_CUSTOM_DEVICE_VERSION 3 // Struct to be filled in typedef struct TFE_CustomDevice { @@ -481,9 +458,16 @@ typedef struct TFE_CustomDevice { void* device_info); // Method to execute an operation. - void (*execute)(TFE_Context* context, int num_inputs, - TFE_TensorHandle** inputs, const char* operation_name, - const TFE_OpAttrs* attributes, int* num_outputs, + // + // Arguments provide enough information to reconstruct the original `TFE_Op`, + // or construct a transformed version, by inspecting the passed `op`. + // + // TFE_OpGetDevice(op) records the original placement of the operation. It may + // be an empty string if no device was explicitly requested, but will + // otherwise be the name of this custom device. Ops are placed onto a custom + // device if any of their inputs are on that custom device, but custom devices + // are free to set a bad status in order to require explicit placement. + void (*execute)(const TFE_Op* op, int* num_outputs, TFE_TensorHandle** outputs, TF_Status* s, void* device_info); // Method to delete a device. 
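For illustration, a minimal sketch of a version-3 custom device `execute` callback that rebuilds the op it is handed using only the const accessors this change introduces (TFE_OpGetName, TFE_OpGetAttrs, TFE_OpGetFlatInputCount, TFE_OpGetFlatInput) and forwards it to a wrapped device. `ForwardingDeviceInfo`, `underlying_device`, and `ForwardingExecute` are hypothetical names used for this sketch only and are not part of the change; error handling is kept minimal.

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"

// Hypothetical per-device state supplied as `device_info` at registration time.
typedef struct ForwardingDeviceInfo {
  TFE_Context* ctx;               // context used to re-create the op
  const char* underlying_device;  // physical device to forward execution to
} ForwardingDeviceInfo;

// Version-3 `execute` callback: inspect the const op, rebuild it on the
// wrapped device, and run it.
void ForwardingExecute(const TFE_Op* original_op, int* num_outputs,
                       TFE_TensorHandle** outputs, TF_Status* s,
                       void* device_info) {
  ForwardingDeviceInfo* info = (ForwardingDeviceInfo*)device_info;
  TFE_Op* op = TFE_NewOp(info->ctx, TFE_OpGetName(original_op, s), s);
  if (TF_GetCode(s) != TF_OK) return;
  TFE_OpAddAttrs(op, TFE_OpGetAttrs(original_op));
  TFE_OpSetDevice(op, info->underlying_device, s);
  int num_inputs = TFE_OpGetFlatInputCount(original_op, s);
  for (int i = 0; TF_GetCode(s) == TF_OK && i < num_inputs; ++i) {
    // TFE_OpGetFlatInput returns a borrowed reference; do not delete it here.
    TFE_OpAddInput(op, TFE_OpGetFlatInput(original_op, i, s), s);
  }
  if (TF_GetCode(s) == TF_OK) TFE_Execute(op, outputs, num_outputs, s);
  TFE_DeleteOp(op);
}

The CloneOp helper added to c_api_test.cc later in this change exercises the same inspect-and-rebuild pattern end to end.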
@@ -569,6 +553,14 @@ TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable, TF_Status* status); +// Returns the device type of the operation that produced `h`. +TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceType( + TFE_TensorHandle* h, TF_Status* status); + +// Returns the device ID of the operation that produced `h`. +TF_CAPI_EXPORT extern int TFE_TensorHandleDeviceID(TFE_TensorHandle* h, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/c_api_experimental_test.cc b/tensorflow/c/eager/c_api_experimental_test.cc index a4d31417073..4fe83b5116d 100644 --- a/tensorflow/c/eager/c_api_experimental_test.cc +++ b/tensorflow/c/eager/c_api_experimental_test.cc @@ -316,86 +316,6 @@ TEST(CAPI, Function_ident_CPU) { TF_DeleteStatus(status); } -#ifdef TENSORFLOW_EAGER_USE_XLA -TEST(CAPI, Function_ident_XLA_CPU) { - // First create a simple identity function. - TF_Graph* function_graph = TF_NewGraph(); - TF_OperationDescription* arg_descr = - TF_NewOperation(function_graph, "Placeholder", "arg"); - TF_SetAttrType(arg_descr, "dtype", TF_INT32); - TF_Status* status = TF_NewStatus(); - TF_Operation* arg = TF_FinishOperation(arg_descr, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_OperationDescription* id_descr = - TF_NewOperation(function_graph, "Identity", "id"); - TF_SetAttrType(id_descr, "T", TF_INT32); - TF_AddInput(id_descr, {arg, 0}); - TF_Operation* id = TF_FinishOperation(id_descr, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_Output input{arg, 0}; - TF_Output output{id, 0}; - TF_Function* fn = - TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1, - &output, nullptr, nullptr, "test", status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_DeleteGraph(function_graph); - TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_Context* ctx = TFE_NewContext(opts, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TFE_DeleteContextOptions(opts); - TFE_ContextAddFunction(ctx, fn, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_DeleteFunction(fn); - - for (bool async : {false, true, false}) { - TFE_Executor* old_executor = TFE_ContextGetExecutorForThread(ctx); - TFE_Executor* executor = TFE_NewExecutor(async); - TFE_ContextSetExecutorForThread(ctx, executor); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK); - TF_Tensor* t = - TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32)); - *reinterpret_cast(TF_TensorData(t)) = 42; - TFE_TensorHandle* h = TFE_NewTensorHandle(t, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_DeleteTensor(t); - - TFE_Op* op = TFE_NewOp(ctx, "ident", status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TFE_OpAddInput(op, h, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - - // Now run it via XLA. 
- TFE_OpSetXLACompilation(op, true); - - std::vector result; - result.push_back(nullptr); - int num_retvals = 1; - TFE_Execute(op, result.data(), &num_retvals, status); - TFE_DeleteOp(op); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - ASSERT_EQ(num_retvals, 1); - - TF_Tensor* r = TFE_TensorHandleResolve(result[0], status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - EXPECT_EQ(*reinterpret_cast(TF_TensorData(r)), 42); - TFE_ContextSetExecutorForThread(ctx, old_executor); - TFE_ExecutorWaitForAllPendingNodes(executor, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_DeleteExecutor(executor); - TFE_DeleteExecutor(old_executor); - TFE_DeleteTensorHandle(h); - TF_DeleteTensor(r); - TFE_DeleteTensorHandle(result[0]); - } - TFE_ContextRemoveFunction(ctx, "ident", status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TFE_DeleteContext(ctx); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_DeleteStatus(status); -} -#endif // TENSORFLOW_EAGER_USE_XLA - void Executor_MatMul_CPU(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); @@ -491,5 +411,109 @@ TEST(CAPI, TensorHandleOnDeviceMemory) { TF_DeleteStatus(status); } +TEST(CAPI, TensorHandleNullptr) { + TFE_TensorHandle* h = nullptr; + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + + const char* device_type = TFE_TensorHandleDeviceType(h, status.get()); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); + ASSERT_EQ(device_type, nullptr); + ASSERT_EQ("Invalid handle", string(TF_Message(status.get()))); + + TF_SetStatus(status.get(), TF_OK, ""); + + int device_id = TFE_TensorHandleDeviceID(h, status.get()); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); + ASSERT_EQ(device_id, -1); + ASSERT_EQ("Invalid handle", string(TF_Message(status.get()))); +} + +TEST(CAPI, TensorHandleDevices) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status.get()); + TFE_DeleteContextOptions(opts); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx); + const char* device_type = TFE_TensorHandleDeviceType(hcpu, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type; + int device_id = TFE_TensorHandleDeviceID(hcpu, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_EQ(0, device_id) << device_id; + + // Disable the test if no GPU is present. 
+ string gpu_device_name; + if (GetDeviceName(ctx, &gpu_device_name, "GPU")) { + TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice( + hcpu, ctx, gpu_device_name.c_str(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + TFE_Op* shape_op = ShapeOp(ctx, hgpu); + TFE_OpSetDevice(shape_op, gpu_device_name.c_str(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(shape_op, &retvals[0], &num_retvals, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + device_type = TFE_TensorHandleDeviceType(retvals[0], status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_TRUE(absl::StrContains(device_type, "GPU")) << device_type; + + device_id = TFE_TensorHandleDeviceID(retvals[0], status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_EQ(0, device_id) << device_id; + + TFE_DeleteOp(shape_op); + TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteTensorHandle(hgpu); + } + + TFE_DeleteTensorHandle(hcpu); + TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); + TFE_ExecutorWaitForAllPendingNodes(executor, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteExecutor(executor); + TFE_DeleteContext(ctx); +} + +TEST(CAPI, TensorHandleDefaults) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status.get()); + TFE_DeleteContextOptions(opts); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TFE_TensorHandle* h_default = TestMatrixTensorHandle(ctx); + const char* device_type = TFE_TensorHandleDeviceType(h_default, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type; + int device_id = TFE_TensorHandleDeviceID(h_default, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_EQ(0, device_id) << device_id; + + TFE_TensorHandle* h_cpu = TFE_TensorHandleCopyToDevice( + h_default, ctx, "/device:CPU:0", status.get()); + const char* device_type_cpu = TFE_TensorHandleDeviceType(h_cpu, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_TRUE(absl::StrContains(device_type_cpu, "CPU")) << device_type_cpu; + int device_id_cpu = TFE_TensorHandleDeviceID(h_cpu, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_EQ(0, device_id_cpu) << device_id_cpu; + + TFE_DeleteTensorHandle(h_default); + TFE_DeleteTensorHandle(h_cpu); + TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); + TFE_ExecutorWaitForAllPendingNodes(executor, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteExecutor(executor); + TFE_DeleteContext(ctx); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 4d9be0c2501..356476c2186 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -32,7 +32,6 @@ struct TFE_ContextOptions { bool async = false; TFE_ContextDevicePlacementPolicy device_placement_policy{ TFE_DEVICE_PLACEMENT_SILENT}; - TFE_ContextMirroringPolicy 
mirroring_policy{TFE_MIRRORING_NONE}; // If true, lazily copy the remote inputs of a function to the target devices. bool lazy_remote_inputs_copy = true; // If true, use TFRT backend diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 724176505ba..fd208c6770d 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include // clang-format off +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/platform/platform.h" // clang-format on @@ -876,89 +877,6 @@ TEST(CAPI, Execute_Min_CPU) { TF_DeleteStatus(status); } -#ifdef TENSORFLOW_EAGER_USE_XLA -void Execute_MatMul_XLA_CPU(bool async) { - TF_Status* status = TF_NewStatus(); - TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_ContextOptionsSetAsync(opts, static_cast(async)); - TFE_Context* ctx = TFE_NewContext(opts, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_DeleteContextOptions(opts); - - TFE_TensorHandle* m = TestMatrixTensorHandle(ctx); - TFE_Op* matmul = MatMulOp(ctx, m, m); - - TFE_OpSetXLACompilation(matmul, true); - - TFE_TensorHandle* retvals[1] = {nullptr}; - int num_retvals = 1; - TFE_Execute(matmul, &retvals[0], &num_retvals, status); - // Running a primitive TF operator via XLA is not yet supported. - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - - TFE_DeleteOp(matmul); - TFE_DeleteTensorHandle(m); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - - EXPECT_EQ(1, num_retvals); - - TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); - TFE_DeleteTensorHandle(retvals[0]); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - float product[4] = {0}; - EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); - memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); - TF_DeleteTensor(t); - EXPECT_EQ(7, product[0]); - EXPECT_EQ(10, product[1]); - EXPECT_EQ(15, product[2]); - EXPECT_EQ(22, product[3]); - TFE_DeleteContext(ctx); - TF_DeleteStatus(status); -} -TEST(CAPI, Execute_MatMul_XLA_CPU) { Execute_MatMul_XLA_CPU(false); } -TEST(CAPI, Execute_MatMul_XLA_CPUAsync) { Execute_MatMul_XLA_CPU(true); } - -void Execute_Min_XLA_CPU(bool async) { - TF_Status* status = TF_NewStatus(); - TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_ContextOptionsSetAsync(opts, static_cast(async)); - TFE_Context* ctx = TFE_NewContext(opts, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_DeleteContextOptions(opts); - - TFE_TensorHandle* input = TestMatrixTensorHandle(ctx); - TFE_TensorHandle* axis = TestAxisTensorHandle(ctx); - TFE_Op* minOp = MinOp(ctx, input, axis); - - TFE_OpSetXLACompilation(minOp, true); - - TFE_TensorHandle* retvals[1] = {nullptr}; - int num_retvals = 1; - TFE_Execute(minOp, &retvals[0], &num_retvals, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_DeleteOp(minOp); - TFE_DeleteTensorHandle(input); - TFE_DeleteTensorHandle(axis); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - ASSERT_EQ(1, num_retvals); - - TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); - TFE_DeleteTensorHandle(retvals[0]); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - float output[2] = {0}; - EXPECT_EQ(sizeof(output), TF_TensorByteSize(t)); - memcpy(&output[0], TF_TensorData(t), TF_TensorByteSize(t)); - TF_DeleteTensor(t); - EXPECT_EQ(1, output[0]); - EXPECT_EQ(3, output[1]); - TFE_DeleteContext(ctx); - TF_DeleteStatus(status); -} 
-TEST(CAPI, Execute_Min_XLA_CPU) { Execute_Min_XLA_CPU(false); } -TEST(CAPI, Execute_Min_XLA_CPUAsync) { Execute_Min_XLA_CPU(true); } -#endif // TENSORFLOW_EAGER_USE_XLA - void ExecuteWithTracing(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); @@ -1274,6 +1192,68 @@ TEST(CAPI, StringAttributes) { TF_DeleteStatus(status); } +// Same test as above, except it uses SetOpAttrValueScalar to set attrs. +TEST(CAPI, TestTFE_SetOpAttrs) { + // Test that attributes set via SetOpAttrValueScalar are applied to the op + // as expected. + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + std::vector dims(4, 1); + TFE_Op* op = TFE_NewOp(ctx, "AvgPool", status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_Tensor* tensor = + TF_AllocateTensor(TF_FLOAT, dims.data(), dims.size(), sizeof(float)); + float tensor_data[] = {1}; + memcpy(TF_TensorData(tensor), tensor_data, TF_TensorByteSize(tensor)); + TFE_TensorHandle* tensor_handle = TFE_NewTensorHandle(tensor, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, tensor_handle, status); + TF_DeleteTensor(tensor); + TFE_DeleteTensorHandle(tensor_handle); + + tensorflow::AttrValue i_list_values; + for (int i = 0; i < 4; ++i) { + i_list_values.mutable_list()->add_i(1); + } + SetOpAttrValueScalar(ctx, op, i_list_values, "ksize", status); + SetOpAttrValueScalar(ctx, op, i_list_values, "strides", status); + + tensorflow::AttrValue padding_value; + *padding_value.mutable_s() = "VALID"; + tensorflow::SetOpAttrValueScalar(ctx, op, padding_value, "padding", status); + + tensorflow::AttrValue data_format_value; + *data_format_value.mutable_s() = "NHWC"; + tensorflow::SetOpAttrValueScalar(ctx, op, data_format_value, "data_format", + status); + + TFE_OpSetAttrType(op, "T", TF_FLOAT); + + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(op, &retvals[0], &num_retvals, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(1, num_retvals); + + tensor = TFE_TensorHandleResolve(retvals[0], status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ(4, TF_TensorByteSize(tensor)); + TF_DeleteTensor(tensor); + TFE_DeleteTensorHandle(retvals[0]); + + TFE_DeleteOp(op); + + TFE_DeleteContext(ctx); + TF_DeleteStatus(status); +} + TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); @@ -1620,4 +1600,91 @@ TEST(CAPI, TestTFE_OpAttrsSerialize) { TFE_DeleteContext(ctx); } +// Needs to work with a const TFE_Op since custom devices should not modify the +// op they are called with.
+TFE_Op* CloneOp(const TFE_Op* other) { + TF_Status* status = TF_NewStatus(); + TFE_Context* context = TFE_OpGetContext(other, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + const char* op_name = TFE_OpGetName(other, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_Op* ret = TFE_NewOp(context, op_name, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + const char* device = TFE_OpGetDevice(other, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpSetDevice(ret, device, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddAttrs(ret, TFE_OpGetAttrs(other)); + int num_inputs = TFE_OpGetFlatInputCount(other, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + for (int input_index = 0; input_index < num_inputs; ++input_index) { + TFE_TensorHandle* input = TFE_OpGetFlatInput(other, input_index, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(ret, input, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + } + TF_DeleteStatus(status); + return ret; +} + +TEST(CAPI, TestTFE_OpRecreation) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + // Clone an op with attributes and a device set. + TFE_Op* original_var_op = TFE_NewOp(ctx, "VarHandleOp", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpSetAttrType(original_var_op, "dtype", TF_INT64); + TFE_OpSetAttrShape(original_var_op, "shape", {}, 0, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ("", std::string(TFE_OpGetDevice(original_var_op, status))); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpSetDevice(original_var_op, + "/job:localhost/replica:0/task:0/device:CPU:0", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_Op* cloned = CloneOp(original_var_op); + + EXPECT_EQ("/job:localhost/replica:0/task:0/device:CPU:0", + std::string(TFE_OpGetDevice(cloned, status))); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ("VarHandleOp", std::string(TFE_OpGetName(cloned, status))); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + int num_retvals = 1; + TFE_TensorHandle* ret; + TFE_Execute(cloned, &ret, &num_retvals, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteTensorHandle(ret); + + // Clone an op with inputs and no device set. 
+ TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx); + TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx); + TFE_Op* original_identity = TFE_NewOp(ctx, "IdentityN", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_TensorHandle* inputs[] = {input1, input2}; + TFE_OpAddInputList(original_identity, inputs, 2, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_Op* cloned_identity = CloneOp(original_identity); + EXPECT_EQ("", std::string(TFE_OpGetDevice(cloned_identity, status))); + TFE_TensorHandle* identity_ret[] = {nullptr, nullptr}; + num_retvals = 2; + TFE_Execute(cloned_identity, identity_ret, &num_retvals, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_DeleteTensorHandle(input1); + TFE_DeleteTensorHandle(input2); + TFE_DeleteTensorHandle(identity_ret[0]); + TFE_DeleteTensorHandle(identity_ret[1]); + + TFE_DeleteOp(cloned_identity); + TFE_DeleteOp(original_identity); + TFE_DeleteOp(original_var_op); + TFE_DeleteOp(cloned); + TF_DeleteStatus(status); + TFE_DeleteContext(ctx); +} + } // namespace diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc index fd68866f502..6eb5b521c50 100644 --- a/tensorflow/c/eager/c_api_test_util.cc +++ b/tensorflow/c/eager/c_api_test_util.cc @@ -17,12 +17,16 @@ limitations under the License. #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/c/tf_tensor.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/strcat.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/tstring.h" #include "tensorflow/core/protobuf/cluster.pb.h" using tensorflow::string; +using tensorflow::tstring; TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value) { float data[] = {value}; @@ -36,6 +40,19 @@ TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value) { return th; } +TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, + const tensorflow::tstring& value) { + TF_Status* status = TF_NewStatus(); + TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_STRING, nullptr, 0, status); + tstring* data = static_cast(TF_TensorData(t)); + *data = value; + TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, int value) { int data[] = {value}; TF_Status* status = TF_NewStatus(); diff --git a/tensorflow/c/eager/c_api_test_util.h b/tensorflow/c/eager/c_api_test_util.h index 2f77ae5cf44..ad0c7c6340f 100644 --- a/tensorflow/c/eager/c_api_test_util.h +++ b/tensorflow/c/eager/c_api_test_util.h @@ -16,6 +16,7 @@ limitations under the License. 
#define TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_ #include "tensorflow/c/eager/c_api.h" +#include "tensorflow/core/platform/tstring.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h" @@ -28,6 +29,10 @@ TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, int value); // Return a tensor handle containing a bool scalar TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, bool value); +// Return a tensor handle containing a tstring scalar +TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, + const tensorflow::tstring& value); + // Return a tensor handle containing a 2x2 matrix of doubles TFE_TensorHandle* DoubleTestMatrixTensorHandle(TFE_Context* ctx); diff --git a/tensorflow/c/eager/c_api_unified_experimental.cc b/tensorflow/c/eager/c_api_unified_experimental.cc index 8408f7ef60f..2d290df19ce 100644 --- a/tensorflow/c/eager/c_api_unified_experimental.cc +++ b/tensorflow/c/eager/c_api_unified_experimental.cc @@ -39,7 +39,7 @@ static FactoriesMap& GetFactories() { return *factories; } -static const char* default_factory = ""; +static tracing::FactoryFunction default_factory; void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) { assert((!GetFactories().count(name)) || @@ -48,15 +48,15 @@ void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) { GetFactories()[name] = factory; } -void SetDefaultTracingEngine(const char* name) { default_factory = name; } - -static TracingContext* CreateTracingExecutionContext(const char* fn_name, - TF_Status* s) { - auto entry = GetFactories().find(default_factory); - if (entry != GetFactories().end()) return entry->second(fn_name, s); +Status SetDefaultTracingEngine(const char* name) { + auto entry = GetFactories().find(name); + if (entry != GetFactories().end()) { + default_factory = GetFactories().find(name)->second; + return Status::OK(); + } string msg = absl::StrCat( - "No tracing engine factory has been registered with the key '", - default_factory, "' (available: "); + "No tracing engine factory has been registered with the key '", name, + "' (available: "); // Ensure deterministic (sorted) order in the error message std::set factories_sorted; for (const auto& factory : GetFactories()) @@ -68,7 +68,16 @@ static TracingContext* CreateTracingExecutionContext(const char* fn_name, } msg += ")"; - TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str()); + return errors::InvalidArgument(msg.c_str()); +} + +static TracingContext* CreateTracingExecutionContext(const char* fn_name, + TF_Status* s) { + if (default_factory) { + return default_factory(fn_name, s); + } + Set_TF_Status_from_Status( + s, errors::FailedPrecondition("default_factory is nullptr")); return nullptr; } @@ -99,8 +108,8 @@ using tensorflow::tracing::TracingContext; using tensorflow::tracing::TracingOperation; using tensorflow::tracing::TracingTensorHandle; -void TF_SetTracingImplementation(const char* name) { - SetDefaultTracingEngine(name); +void TF_SetTracingImplementation(const char* name, TF_Status* s) { + Set_TF_Status_from_Status(s, SetDefaultTracingEngine(name)); } // Creates a new TensorFlow function, it is an execution context attached to a diff --git a/tensorflow/c/eager/c_api_unified_experimental.h b/tensorflow/c/eager/c_api_unified_experimental.h index b66869b4290..d216b4e694b 100644 --- a/tensorflow/c/eager/c_api_unified_experimental.h +++ b/tensorflow/c/eager/c_api_unified_experimental.h @@ -52,7 +52,7 @@ typedef struct TF_AbstractFunction TF_AbstractFunction; // This 
allows the client to swap the implementation of the tracing engine. // Any future call to TF_CreateFunction will use the implementation defined // here. -void TF_SetTracingImplementation(const char* name); +void TF_SetTracingImplementation(const char* name, TF_Status*); // Creates a new TensorFlow function. A Function is an execution context, and as // such it can trace operations through TF_ExecuteOperation. After completing diff --git a/tensorflow/c/eager/c_api_unified_experimental_graph.cc b/tensorflow/c/eager/c_api_unified_experimental_graph.cc index 9d064039141..0e9d6c18157 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_graph.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_graph.cc @@ -365,9 +365,10 @@ class GraphContext : public TracingContext { } auto s = TF_NewStatus(); - func->func = TF_GraphToFunction( - graph_.get(), name_, 0, -1, nullptr, inputs_.size(), inputs_.data(), - graph_outputs.size(), graph_outputs.data(), nullptr, nullptr, name_, s); + func->func = TF_GraphToFunction(graph_.get(), name_.data(), 0, -1, nullptr, + inputs_.size(), inputs_.data(), + graph_outputs.size(), graph_outputs.data(), + nullptr, nullptr, name_.data(), s); TF_RETURN_IF_ERROR(StatusFromTF_Status(s)); TF_DeleteStatus(s); *f = func.release(); @@ -391,7 +392,7 @@ class GraphContext : public TracingContext { private: std::unique_ptr graph_; std::vector inputs_; - const char* name_; + string name_; }; static TracingContext* GraphTracingFactory(const char* name, TF_Status* s) { @@ -401,7 +402,7 @@ static TracingContext* GraphTracingFactory(const char* name, TF_Status* s) { // Register the tracing implemented in this file as the default tracing engine. static bool register_tracing = [] { RegisterTracingEngineFactory("graphdef", GraphTracingFactory); - SetDefaultTracingEngine("graphdef"); + SetDefaultTracingEngine("graphdef").IgnoreError(); return true; }(); diff --git a/tensorflow/c/eager/c_api_unified_experimental_internal.h b/tensorflow/c/eager/c_api_unified_experimental_internal.h index c00e04d98af..9433fe8f120 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_internal.h +++ b/tensorflow/c/eager/c_api_unified_experimental_internal.h @@ -120,7 +120,7 @@ class TracingContext : public AbstractContext { }; typedef TracingContext* (*FactoryFunction)(const char* fn_name, TF_Status*); -void SetDefaultTracingEngine(const char* name); +Status SetDefaultTracingEngine(const char* name); void RegisterTracingEngineFactory(const ::tensorflow::string& name, FactoryFunction factory); } // namespace tracing diff --git a/tensorflow/c/eager/c_api_unified_experimental_test.cc b/tensorflow/c/eager/c_api_unified_experimental_test.cc index 7b3a497a0c5..432ddb4b2d4 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_test.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_test.cc @@ -22,10 +22,15 @@ limitations under the License. 
#include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_status_helper.h" #include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/test.h" +using tensorflow::Status; using tensorflow::string; +using tensorflow::TF_StatusPtr; namespace tensorflow { namespace { @@ -37,7 +42,10 @@ class UnifiedCAPI : public ::testing::TestWithParam> { protected: void SetUp() override { - TF_SetTracingImplementation(std::get<0>(GetParam())); + TF_StatusPtr status(TF_NewStatus()); + TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); + Status s = StatusFromTF_Status(status.get()); + CHECK_EQ(errors::OK, s.code()) << s.error_message(); } }; diff --git a/tensorflow/c/eager/custom_device_test.cc b/tensorflow/c/eager/custom_device_test.cc index 1c078d4f42c..b058c79a17b 100644 --- a/tensorflow/c/eager/custom_device_test.cc +++ b/tensorflow/c/eager/custom_device_test.cc @@ -36,7 +36,8 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) { bool arrived = false; bool executed = false; const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; - RegisterLoggingDevice(context, name, &arrived, &executed, status.get()); + RegisterLoggingDevice(context, name, /*strict_scope_placement=*/true, + &arrived, &executed, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); TFE_TensorHandle* hcpu = TestMatrixTensorHandle(context); ASSERT_FALSE(arrived); @@ -73,7 +74,8 @@ TEST(CUSTOM_DEVICE, ResetOperation) { bool executed = false; const char* custom_device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; - RegisterLoggingDevice(context.get(), custom_device_name, &arrived, &executed, + RegisterLoggingDevice(context.get(), custom_device_name, + /*strict_scope_placement=*/true, &arrived, &executed, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); @@ -103,7 +105,8 @@ TEST(CUSTOM_DEVICE, MakeVariable) { bool arrived = false; bool executed = false; const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; - RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get()); + RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true, + &arrived, &executed, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); // Create a variable handle placed on the custom device. @@ -187,7 +190,8 @@ TEST(CUSTOM_DEVICE, AccessVariableOnCustomDevice) { bool arrived = false; bool executed = false; const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; - RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get()); + RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/false, + &arrived, &executed, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); // Create a variable handle placed on the custom device. 
@@ -264,10 +268,12 @@ TEST(CUSTOM_DEVICE, InputBasedPlacement) { const char* custom1 = "/job:localhost/replica:0/task:0/device:CUSTOM:1"; bool arrived = false; bool executed = false; - RegisterLoggingDevice(context.get(), custom0, &arrived, &executed, + RegisterLoggingDevice(context.get(), custom0, + /*strict_scope_placement=*/false, &arrived, &executed, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - RegisterLoggingDevice(context.get(), custom1, &arrived, &executed, + RegisterLoggingDevice(context.get(), custom1, + /*strict_scope_placement=*/true, &arrived, &executed, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); @@ -314,14 +320,34 @@ TEST(CUSTOM_DEVICE, InputBasedPlacement) { ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0)); ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom1)); - // Custom device: mix of custom/physical fails. + // Custom device: mix of custom/physical places the op on the custom device. matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get())); num_retvals = 1; + executed = false; TFE_Execute(matmul.get(), &retval, &num_retvals, status.get()); - ASSERT_NE(TF_OK, TF_GetCode(status.get())); - ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0)); - ASSERT_TRUE( - absl::StrContains(TF_Message(status.get()), "[]")); // kVariantDeviceNull + EXPECT_TRUE(executed); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_DeleteTensorHandle(retval); + + // Explicit placement still forces the op onto the requested device + matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get())); + TFE_OpSetDevice(matmul.get(), "/job:localhost/replica:0/task:0/device:CPU:0", + status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + num_retvals = 1; + executed = false; + TFE_Execute(matmul.get(), &retval, &num_retvals, status.get()); + EXPECT_FALSE(executed); + ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK); + + // Custom devices can refuse to do type-based dispatch (as hcustom1 is + // configured to do) + matmul.reset(MatMulOp(context.get(), hcustom1.get(), hcpu.get())); + num_retvals = 1; + executed = false; + TFE_Execute(matmul.get(), &retval, &num_retvals, status.get()); + EXPECT_FALSE(executed); + ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK); } TEST(CUSTOM_DEVICE, InvalidRegistrationError) { @@ -334,21 +360,24 @@ TEST(CUSTOM_DEVICE, InvalidRegistrationError) { ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); bool arrived = false; bool executed = false; - RegisterLoggingDevice(context.get(), "/device:CUSTOM:0", &arrived, &executed, + RegisterLoggingDevice(context.get(), "/device:CUSTOM:0", + /*strict_scope_placement=*/true, &arrived, &executed, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT) << TF_Message(status.get()); const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; - RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get()); + RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true, + &arrived, &executed, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS) - << TF_Message(status.get()); - - RegisterLoggingDevice(context.get(), - "/job:localhost/replica:0/task:0/device:CPU:0", + 
RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true, &arrived, &executed, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS) << TF_Message(status.get()); + + RegisterLoggingDevice( + context.get(), "/job:localhost/replica:0/task:0/device:CPU:0", + /*strict_scope_placement=*/true, &arrived, &executed, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS) + << TF_Message(status.get()); } diff --git a/tensorflow/c/eager/custom_device_testutil.cc b/tensorflow/c/eager/custom_device_testutil.cc index 28de3665653..014abe38368 100644 --- a/tensorflow/c/eager/custom_device_testutil.cc +++ b/tensorflow/c/eager/custom_device_testutil.cc @@ -33,6 +33,9 @@ struct LoggingDevice { bool* arrived_flag; // Set to true whenever an operation is executed bool* executed_flag; + // If true, only explicit op placements are accepted. If false, uses + // type-based dispatch. + bool strict_scope_placement; }; struct LoggedTensor { @@ -84,18 +87,35 @@ TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context, return nullptr; } -void LoggingDeviceExecute(TFE_Context* context, int num_inputs, - TFE_TensorHandle** inputs, const char* operation_name, - const TFE_OpAttrs* attributes, int* num_outputs, +void LoggingDeviceExecute(const TFE_Op* original_op, int* num_outputs, TFE_TensorHandle** outputs, TF_Status* s, void* device_info) { + const char* requested_placement = TFE_OpGetDevice(original_op, s); + if (TF_GetCode(s) != TF_OK) return; + LoggingDevice* dev = reinterpret_cast(device_info); + if (dev->strict_scope_placement && *requested_placement == '\0') { + TF_SetStatus(s, TF_INTERNAL, + "Ops must be placed on the device explicitly, or their inputs " + "first copied to other devices."); + return; + } + TFE_Context* context = TFE_OpGetContext(original_op, s); + if (TF_GetCode(s) != TF_OK) return; + const char* operation_name = TFE_OpGetName(original_op, s); + if (TF_GetCode(s) != TF_OK) return; + const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op); + TFE_Op* op(TFE_NewOp(context, operation_name, s)); if (TF_GetCode(s) != TF_OK) return; TFE_OpAddAttrs(op, attributes); TFE_OpSetDevice(op, dev->underlying_device.c_str(), s); + if (TF_GetCode(s) != TF_OK) return; + int num_inputs = TFE_OpGetFlatInputCount(original_op, s); + if (TF_GetCode(s) != TF_OK) return; for (int j = 0; j < num_inputs; ++j) { - TFE_TensorHandle* input = inputs[j]; + TFE_TensorHandle* input = TFE_OpGetFlatInput(original_op, j, s); + if (TF_GetCode(s) != TF_OK) return; const char* input_device = TFE_TensorHandleDeviceName(input, s); if (TF_GetCode(s) != TF_OK) return; if (dev->device_name == input_device) { @@ -131,8 +151,8 @@ void DeleteLoggingDevice(void* device_info) { } // namespace void RegisterLoggingDevice(TFE_Context* context, const char* name, - bool* arrived_flag, bool* executed_flag, - TF_Status* status) { + bool strict_scope_placement, bool* arrived_flag, + bool* executed_flag, TF_Status* status) { TFE_CustomDevice custom_device; custom_device.copy_tensor_to_device = &CopyToLoggingDevice; custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice; @@ -143,6 +163,7 @@ void RegisterLoggingDevice(TFE_Context* context, const char* name, device->executed_flag = executed_flag; device->device_name = name; device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0"; + device->strict_scope_placement = strict_scope_placement; TFE_RegisterCustomDevice(context, custom_device, name, device, status); } @@ -168,5 +189,6 @@ void 
AllocateLoggingDevice(const char* name, bool* arrived_flag, logging_device->device_name = name; logging_device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0"; + logging_device->strict_scope_placement = true; *device_info = reinterpret_cast(logging_device); } diff --git a/tensorflow/c/eager/custom_device_testutil.h b/tensorflow/c/eager/custom_device_testutil.h index 509df7d3e3e..a7c60080adf 100644 --- a/tensorflow/c/eager/custom_device_testutil.h +++ b/tensorflow/c/eager/custom_device_testutil.h @@ -25,8 +25,8 @@ limitations under the License. #include "tensorflow/c/tf_status.h" void RegisterLoggingDevice(TFE_Context* context, const char* name, - bool* arrived_flag, bool* executed_flag, - TF_Status* status); + bool strict_scope_placement, bool* arrived_flag, + bool* executed_flag, TF_Status* status); void AllocateLoggingDevice(const char* name, bool* arrived_flag, bool* executed_flag, TFE_CustomDevice** device, void** device_info); diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index 45048bd6efb..df8e9ace997 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -109,7 +109,8 @@ DLDataType GetDlDataType(TF_DataType data_type, TF_Status* status) { // Gets DLPack's DLContext from eager tensor handle. DLContext GetDlContext(TFE_TensorHandle* h, TF_Status* status) { DLContext ctx; - const char* device_name = tensorflow::unwrap(h)->DeviceName(&status->status); + const char* device_name = + tensorflow::unwrap(h)->BackingDeviceName(&status->status); DeviceNameUtils::ParsedName parsed_name; tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name); std::string device_type = parsed_name.type; @@ -248,21 +249,36 @@ void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) { } void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) { + auto tf_dlm_context = GetDlContext(h, status); + if (!status->status.ok()) { + return nullptr; + } + + auto* tf_dlm_data = TFE_TensorHandleDevicePointer(h, status); + if (!status->status.ok()) { + return nullptr; + } + const Tensor* tensor = GetTensorFromHandle(h, status); TF_DataType data_type = static_cast(tensor->dtype()); - TensorReference tensor_ref(*tensor); // This will call buf_->Ref() + auto tf_dlm_type = GetDlDataType(data_type, status); + if (!status->status.ok()) { + return nullptr; + } + + TensorReference tensor_ref(*tensor); // This will call buf_->Ref() auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref); tf_dlm_tensor_ctx->reference = tensor_ref; DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor; dlm_tensor->manager_ctx = tf_dlm_tensor_ctx; dlm_tensor->deleter = &DLManagedTensorDeleter; - dlm_tensor->dl_tensor.ctx = GetDlContext(h, status); + dlm_tensor->dl_tensor.ctx = tf_dlm_context; int ndim = tensor->dims(); dlm_tensor->dl_tensor.ndim = ndim; - dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status); - dlm_tensor->dl_tensor.dtype = GetDlDataType(data_type, status); + dlm_tensor->dl_tensor.data = tf_dlm_data; + dlm_tensor->dl_tensor.dtype = tf_dlm_type; std::vector* shape_arr = &tf_dlm_tensor_ctx->shape; std::vector* stride_arr = &tf_dlm_tensor_ctx->strides; @@ -275,13 +291,14 @@ void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) { (*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1]; } - dlm_tensor->dl_tensor.shape = &(*shape_arr)[0]; + dlm_tensor->dl_tensor.shape = shape_arr->data(); // There are two ways to represent compact row-major data // 1) nullptr indicates tensor is compact and row-majored. 
// 2) fill in the strides array as the real case for compact row-major data. // Here we choose option 2, since some frameworks didn't handle the strides // argument properly. - dlm_tensor->dl_tensor.strides = &(*stride_arr)[0]; + dlm_tensor->dl_tensor.strides = stride_arr->data(); + dlm_tensor->dl_tensor.byte_offset = 0; // TF doesn't handle the strides and byte_offsets here return static_cast(dlm_tensor); diff --git a/tensorflow/c/eager/gradient_checker.cc b/tensorflow/c/eager/gradient_checker.cc new file mode 100644 index 00000000000..640edc7228a --- /dev/null +++ b/tensorflow/c/eager/gradient_checker.cc @@ -0,0 +1,201 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/eager/gradient_checker.h" + +#include + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/eager/gradients.h" +#include "tensorflow/c/eager/gradients_internal.h" +#include "tensorflow/c/experimental/gradients/math_grad.h" +#include "tensorflow/c/experimental/gradients/nn_grad.h" +#include "tensorflow/c/experimental/ops/array_ops.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" +#include "tensorflow/core/platform/errors.h" + +namespace tensorflow { +namespace gradients { + +using namespace std; + +// ================== Helper functions ================= + +// Fills data with values [start,end) with given step size. +void Range(vector* data, int start, int end, int step = 1) { + for (int i = start; i < end; i += step) { + (*data)[i] = i; + } +} + +// Returns AbstractTensorHandlePtr containing [0, ..., n-1]. +AbstractTensorHandlePtr GetRangeTensorHandleUtil(AbstractContext* ctx, int n) { + vector vals(n); + int64_t vals_shape[] = {n}; + Range(&vals, 0, n); + AbstractTensorHandlePtr r = + GetTensorHandleUtilInt(ctx, vals.data(), vals_shape, 1); + return r; +} + +// Fills out_dims with the dimensions of the given tensor. +void GetDims(const TF_Tensor* t, int64_t* out_dims) { + int num_dims = TF_NumDims(t); + for (int i = 0; i < num_dims; i++) { + out_dims[i] = TF_Dim(t, i); + } +} + +// Runs model as is if output is a scalar, +// else sums the output tensor before returning. +Status RunAndMaybeSum(AbstractContext* ctx, Model forward, + absl::Span inputs, + absl::Span outputs, + bool use_function) { + GradientRegistry registry; + std::vector model_outputs(1); + + // Run the model. 
+ TF_RETURN_IF_ERROR(RunModel(forward, ctx, inputs, + absl::MakeSpan(model_outputs), use_function, + registry)); + AbstractTensorHandle* model_out = model_outputs[0]; + + TF_Tensor* model_out_tensor; + TF_RETURN_IF_ERROR(GetValue(model_out, &model_out_tensor)); + int num_dims_out = TF_NumDims(model_out_tensor); + + // If the output is a scalar, then return the scalar output + if (num_dims_out == 0) { + outputs[0] = model_out; + return Status::OK(); + } + + // Else, reduce sum the output to get a scalar + + // Will sum all dimensions, so get a Tensor containing [0,...,num_dims_out-1]. + AbstractTensorHandlePtr sum_dims = + GetRangeTensorHandleUtil(ctx, num_dims_out); + + // Reduce sum the output on all dimensions. + std::vector sum_inputs(2); + sum_inputs[0] = model_out; + sum_inputs[1] = sum_dims.get(); + + TF_RETURN_IF_ERROR(ops::Sum(ctx, absl::MakeSpan(sum_inputs), + absl::MakeSpan(model_outputs), "sum_output")); + outputs[0] = model_outputs[0]; + return Status::OK(); +} +// ========================= End Helper Functions============================== + +Status CalcNumericalGrad(AbstractContext* ctx, Model forward, + absl::Span inputs, + int input_index, bool use_function, + AbstractTensorHandle** numerical_grad) { + AbstractTensorHandle* theta = + inputs[input_index]; // parameter we are grad checking + + // Convert from AbstractTensor to TF_Tensor. + TF_Tensor* theta_tensor; + TF_RETURN_IF_ERROR(GetValue(theta, &theta_tensor)); + + // Get number of elements and fill data. + int num_elems = TF_TensorElementCount(theta_tensor); + vector theta_data(num_elems); + memcpy(theta_data.data(), TF_TensorData(theta_tensor), + TF_TensorByteSize(theta_tensor)); + + // Initialize space for the numerical gradient. + vector dtheta_approx(num_elems); + + // Get theta shape and store in theta_dims. + int num_dims = TF_NumDims(theta_tensor); + vector theta_dims(num_dims); + GetDims(theta_tensor, theta_dims.data()); + + // Initialize auxilary data structures. + vector thetaPlus_data(num_elems); + vector thetaMinus_data(num_elems); + std::vector f_outputs(1); + + // Numerical Grad Check + for (int i = 0; i < num_elems; i++) { + // Get relative epsilon value + float epsilon = + std::abs(theta_data[i] * 1e-4 + 1e-4); // add 1e-4 to prevent div by 0 + AbstractTensorHandlePtr two_eps = + GetScalarTensorHandleUtil(ctx, 2 * epsilon); + + // Initialize theta[i] + epsilon. + memcpy(thetaPlus_data.data(), TF_TensorData(theta_tensor), + TF_TensorByteSize(theta_tensor)); + thetaPlus_data[i] += epsilon; + AbstractTensorHandlePtr thetaPlus = GetTensorHandleUtilFloat( + ctx, thetaPlus_data.data(), theta_dims.data(), num_dims); + + // Initialize theta[i] - epsilon. + memcpy(&thetaMinus_data[0], TF_TensorData(theta_tensor), + TF_TensorByteSize(theta_tensor)); + thetaMinus_data[i] -= epsilon; + AbstractTensorHandlePtr thetaMinus = GetTensorHandleUtilFloat( + ctx, thetaMinus_data.data(), theta_dims.data(), num_dims); + + // Get f(theta + eps): + inputs[input_index] = thetaPlus.get(); + TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, inputs, + absl::MakeSpan(f_outputs), use_function)); + AbstractTensorHandle* fPlus = f_outputs[0]; + + // Get f(theta - eps): + inputs[input_index] = thetaMinus.get(); + TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, inputs, + absl::MakeSpan(f_outputs), use_function)); + AbstractTensorHandle* fMinus = f_outputs[0]; + + // Take Difference of both estimates: (f(theta + eps) - f(theta - eps)). 
+ TF_RETURN_IF_ERROR( + ops::Sub(ctx, {fPlus, fMinus}, absl::MakeSpan(f_outputs), "sub_top")); + AbstractTensorHandle* fDiff = f_outputs[0]; + + // Calculate using the difference quotient definition: + // (f(theta + eps) - f(theta - eps)) / (2 * eps). + TF_RETURN_IF_ERROR(ops::DivNoNan(ctx, {fDiff, two_eps.get()}, + absl::MakeSpan(f_outputs), + "diff_quotient")); + AbstractTensorHandle* diff_quotient = f_outputs[0]; + + TF_Tensor* grad_tensor; + TF_RETURN_IF_ERROR(GetValue(diff_quotient, &grad_tensor)); + float grad_data[1]; + memcpy(&grad_data[0], TF_TensorData(grad_tensor), + TF_TensorByteSize(grad_tensor)); + + dtheta_approx[i] = grad_data[0]; + } + + // Populate *numerical_grad with the data from dtheta_approx. + TF_RETURN_IF_ERROR(TensorHandleWithDimsFloat( + ctx, dtheta_approx.data(), theta_dims.data(), num_dims, numerical_grad)); + return Status::OK(); +} + +} // namespace gradients +} // namespace tensorflow diff --git a/tensorflow/c/eager/gradient_checker.h b/tensorflow/c/eager/gradient_checker.h new file mode 100644 index 00000000000..8497f5af48e --- /dev/null +++ b/tensorflow/c/eager/gradient_checker.h @@ -0,0 +1,53 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/eager/gradients.h" +#include "tensorflow/c/eager/gradients_internal.h" +#include "tensorflow/c/eager/gradients_util.h" +#include "tensorflow/c/experimental/gradients/math_grad.h" +#include "tensorflow/c/experimental/gradients/nn_grad.h" +#include "tensorflow/c/experimental/ops/array_ops.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" +#include "tensorflow/core/platform/errors.h" + +namespace tensorflow { +namespace gradients { + +/* Returns numerical grad inside `dtheta_approx` given `forward` model and + * parameter specified by `input_index`. + * + * I.e. if y = and w = inputs[input_index], + * this will calculate dy/dw numerically. + * + * `use_function` indicates whether to use graph mode(true) or eager(false). + * + * `numerical_grad` is the pointer to the AbstractTensorHandle* which will + * hold the numerical gradient data at the end of the function. 
+ */ +Status CalcNumericalGrad(AbstractContext* ctx, Model forward, + absl::Span inputs, + int input_index, bool use_function, + AbstractTensorHandle** numerical_grad); + +} // namespace gradients +} // namespace tensorflow diff --git a/tensorflow/c/eager/gradient_checker_test.cc b/tensorflow/c/eager/gradient_checker_test.cc new file mode 100644 index 00000000000..7a438085fb5 --- /dev/null +++ b/tensorflow/c/eager/gradient_checker_test.cc @@ -0,0 +1,265 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/eager/gradient_checker.h" + +#include + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/eager/gradients.h" +#include "tensorflow/c/eager/gradients_internal.h" +#include "tensorflow/c/eager/gradients_util.h" +#include "tensorflow/c/eager/mnist_gradients_testutil.h" +#include "tensorflow/c/experimental/gradients/math_grad.h" +#include "tensorflow/c/experimental/gradients/nn_grad.h" +#include "tensorflow/c/experimental/ops/array_ops.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace gradients { +namespace internal { +namespace { + +class GradientCheckerTest + : public ::testing::TestWithParam> { + protected: + void SetUp() override { + TF_StatusPtr status(TF_NewStatus()); + TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); + Status s = StatusFromTF_Status(status.get()); + CHECK_EQ(errors::OK, s.code()) << s.error_message(); + } +}; + +Status RegisterGradients(GradientRegistry* registry) { + TF_RETURN_IF_ERROR(registry->Register("MatMul", MatMulRegisterer)); + TF_RETURN_IF_ERROR( + registry->Register("SparseSoftmaxCrossEntropyWithLogits", + SparseSoftmaxCrossEntropyWithLogitsRegisterer)); + return Status::OK(); +} + +TEST_P(GradientCheckerTest, TestGradCheckMatMul) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f}; + int64_t A_dims[] = {2, 2}; + float B_vals[] = {.5f, -1.0f, 1.0f, 1.0f}; + int64_t B_dims[] = {2, 2}; + int num_dims = 2; + + AbstractTensorHandlePtr A = + GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims); + AbstractTensorHandlePtr B = + GetTensorHandleUtilFloat(ctx.get(), B_vals, B_dims, num_dims); + + std::vector inputs; + inputs.push_back(A.get()); + inputs.push_back(B.get()); 
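+  // [Editorial annotation, not part of the patch] The quantity checked here is
+  // f(A, B) = reduce_sum(matmul(A, B)) once RunAndMaybeSum collapses the 2x2
+  // product to a scalar, so analytically df/dA[i][j] = sum_k B[j][k], the j-th
+  // row-sum of B. With B = [[0.5, -1], [1, 1]] the row sums are {-0.5, 2},
+  // which is where the expected_dA values {-0.5, 2, -0.5, 2} asserted below
+  // come from.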
+ + AbstractTensorHandle* grad_approx; + Status s = CalcNumericalGrad( + ctx.get(), MatMulModel, absl::MakeSpan(inputs), /*input_index=*/0, + /*use_function=*/!std::get<2>(GetParam()), &grad_approx); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + TF_Tensor* gt; + s = GetValue(grad_approx, >); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + float result_data[4] = {0}; + memcpy(&result_data[0], TF_TensorData(gt), TF_TensorByteSize(gt)); + + float expected_dA[4] = {-.5f, 2.0f, -.5f, 2.0f}; + float tolerance = 1e-2; + for (int j = 0; j < 4; j++) { + ASSERT_NEAR(expected_dA[j], result_data[j], tolerance); + } + TF_DeleteTensor(gt); +} + +TEST_P(GradientCheckerTest, TestGradCheckMul) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + AbstractTensorHandlePtr x; + { + AbstractTensorHandle* x_raw = nullptr; + Status s = ScalarTensorHandle(ctx.get(), 2.0f, &x_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + x.reset(x_raw); + } + + AbstractTensorHandlePtr y; + { + AbstractTensorHandle* y_raw = nullptr; + Status s = ScalarTensorHandle(ctx.get(), 7.0f, &y_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + y.reset(y_raw); + } + + // Will perform z = x*y. + // dz/dx = y + + std::vector inputs; + inputs.push_back(x.get()); + inputs.push_back(y.get()); + AbstractTensorHandle* g; + + Status s = CalcNumericalGrad(ctx.get(), MulModel, absl::MakeSpan(inputs), + /*input_index=*/0, + /*use_function=*/!std::get<2>(GetParam()), &g); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + TF_Tensor* gt; + s = GetValue(g, >); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + float result_data[1] = {0}; + memcpy(&result_data[0], TF_TensorData(gt), TF_TensorByteSize(gt)); + + ASSERT_NEAR(result_data[0], 7.0f, /*abs_error=*/1e-2); + TF_DeleteTensor(gt); +} + +TEST_P(GradientCheckerTest, TestGradCheckSoftmax) { + bool use_function = !std::get<2>(GetParam()); + if (use_function) { + // TODO(b/168850692): Enable this. + GTEST_SKIP() << "Can't take gradient of " + "SparseSoftmaxCrossEntropyWithLogits in tracing mode."; + } + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + + /** Test to show how to use this API with analytical gradients: + * + * We have `SoftmaxLossGradModel`, which is a wrapper for the + * Softmax analytical gradient found in c/experimental/nn_grads. + * + * We will use the GradientChecker by applying finite differences + * to the forward pass wrapped in `SoftmaxModel` and verify that + * both the analytical and numerical gradients are relatively + * close. 
+ * + */ + + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + // X = scores + float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 0.0f, 1.0f}; + int64_t X_dims[] = {3, 3}; + int num_dims = 2; + AbstractTensorHandlePtr X = + GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims); + + // y = labels + int y_vals[] = {1, 0, 1}; + int64_t y_dims[] = {3}; + num_dims = sizeof(y_dims) / sizeof(y_dims[0]); + AbstractTensorHandlePtr y = + GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims); + + GradientRegistry registry; + Status s = RegisterGradients(®istry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + std::vector inputs; + inputs.push_back(X.get()); + inputs.push_back(y.get()); + + // Run analytical gradient and get its data. + std::vector outputs(2); + s = RunModel(SoftmaxLossGradModel, ctx.get(), absl::MakeSpan(inputs), + absl::MakeSpan(outputs), + /*use_function=*/!std::get<2>(GetParam()), registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + TF_Tensor* dX_tensor; + s = GetValue(outputs[0], &dX_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + float danalytical[9] = {0}; // Contains data from analytical gradient. + memcpy(&danalytical[0], TF_TensorData(dX_tensor), + TF_TensorByteSize(dX_tensor)); + + // Run numerical gradient approximation using the GradientChecker API. + AbstractTensorHandle* g; // Will contain numerical approximation data. + s = CalcNumericalGrad(ctx.get(), SoftmaxModel, absl::MakeSpan(inputs), + /*input_index=*/0, + /*use_function=*/!std::get<2>(GetParam()), &g); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + TF_Tensor* gt; + s = GetValue(g, >); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + float dnumerical[9] = {0}; + memcpy(&dnumerical[0], TF_TensorData(gt), TF_TensorByteSize(gt)); + + // Now compare the two implementations: + for (int j = 0; j < 9; j++) { + ASSERT_NEAR(dnumerical[j], danalytical[j], /*abs_error=*/1e-2); + } + + // Only Unref() first output as 2nd is nullptr grad for labels + outputs[0]->Unref(); + TF_DeleteTensor(dX_tensor); + TF_DeleteTensor(gt); +} + +#ifdef PLATFORM_GOOGLE +INSTANTIATE_TEST_SUITE_P( + UnifiedCAPI, GradientCheckerTest, + ::testing::Combine(::testing::Values("graphdef"), + /*tfrt*/ ::testing::Values(false), + /*executing_eagerly*/ ::testing::Values(true, false))); +#else +INSTANTIATE_TEST_SUITE_P( + UnifiedCAPI, GradientCheckerTest, + ::testing::Combine(::testing::Values("graphdef"), + /*tfrt*/ ::testing::Values(false), + /*executing_eagerly*/ ::testing::Values(true, false))); +#endif +} // namespace +} // namespace internal +} // namespace gradients +} // namespace tensorflow diff --git a/tensorflow/c/eager/gradients.cc b/tensorflow/c/eager/gradients.cc index 9bcd0d0fea0..58ffcf247cf 100644 --- a/tensorflow/c/eager/gradients.cc +++ b/tensorflow/c/eager/gradients.cc @@ -122,14 +122,12 @@ int64 ToId(AbstractTensorHandle* t) { return static_cast(reinterpret_cast(t)); } -TapeTensor::TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx) - : handle_(handle), ctx_(ctx) { +TapeTensor::TapeTensor(AbstractTensorHandle* handle) : handle_(handle) { handle_->Ref(); } TapeTensor::TapeTensor(const TapeTensor& other) { handle_ = other.handle_; handle_->Ref(); - ctx_ = other.ctx_; } TapeTensor::~TapeTensor() { handle_->Unref(); } @@ -138,33 +136,7 @@ 
tensorflow::int64 TapeTensor::GetID() const { return ToId(handle_); } tensorflow::DataType TapeTensor::GetDType() const { return handle_->DataType(); } - -AbstractTensorHandle* TapeTensor::OnesLike() const { - AbstractOperationPtr op(ctx_->CreateOperation()); - Status s = op->Reset("OnesLike", /*raw_device_name=*/nullptr); - if (!s.ok()) { - return nullptr; - } - if (isa(op.get())) { - s = dyn_cast(op.get())->SetOpName( - absl::StrCat("OnesLike", ToId(handle_)).c_str()); - if (!s.ok()) { - return nullptr; - } - } - s = op->AddInput(handle_); - if (!s.ok()) { - return nullptr; - } - int num_outputs = 1; - // TODO(srbs): Figure out who is in charge of releasing this. - std::vector outputs(num_outputs); - s = op->Execute(absl::Span(outputs), &num_outputs); - if (!s.ok()) { - return nullptr; - } - return outputs[0]; -} +AbstractTensorHandle* TapeTensor::GetHandle() const { return handle_; } AbstractTensorHandle* TapeTensor::ZerosLike() const { return nullptr; } @@ -219,6 +191,23 @@ Status TapeVSpace::CallBackwardFunction( &ctx, incoming_gradients, result); } +Status TapeVSpace::BuildOnesLike(const TapeTensor& t, + AbstractTensorHandle** result) const { + AbstractOperationPtr op(ctx_->CreateOperation()); + TF_RETURN_IF_ERROR(op->Reset("OnesLike", /*raw_device_name=*/nullptr)); + if (isa(op.get())) { + TF_RETURN_IF_ERROR(dyn_cast(op.get())->SetOpName( + absl::StrCat("OnesLike", ToId(t.GetHandle())).c_str())); + } + TF_RETURN_IF_ERROR(op->AddInput(t.GetHandle())); + int num_outputs = 1; + std::vector outputs(num_outputs); + TF_RETURN_IF_ERROR( + op->Execute(absl::Span(outputs), &num_outputs)); + *result = outputs[0]; + return Status::OK(); +} + // Looks up the ID of a Gradient. int64 TapeVSpace::TensorId(AbstractTensorHandle* tensor) const { return ToId(tensor); @@ -226,7 +215,7 @@ int64 TapeVSpace::TensorId(AbstractTensorHandle* tensor) const { // Converts a Gradient to a TapeTensor. TapeTensor TapeVSpace::TapeTensorFromGradient(AbstractTensorHandle* g) const { - return TapeTensor(g, ctx_); + return TapeTensor(g); } void TapeVSpace::MarkAsResult(AbstractTensorHandle* gradient) const {} @@ -242,6 +231,7 @@ namespace internal { Status Reset(AbstractOperation* op_, const char* op, const char* raw_device_name, ForwardOperation* forward_op_) { forward_op_->op_name = op; + forward_op_->attrs.Reset(op); return op_->Reset(op, raw_device_name); } Status AddInput(AbstractOperation* op_, AbstractTensorHandle* input, @@ -418,9 +408,14 @@ Status Execute(AbstractOperation* op_, AbstractContext* ctx, // TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs. forward_op_->outputs.push_back(retvals[i]); } + // TODO(b/166669239): This is needed to support AttrBuilder::Get for string + // attributes. Number type attrs and DataType attrs work fine without this. + // Consider getting rid of this and making the behavior between number types + // and string consistent. 
+ forward_op_->attrs.BuildNodeDef(); std::vector tape_tensors; for (auto t : retvals) { - tape_tensors.push_back(TapeTensor(t, ctx)); + tape_tensors.push_back(TapeTensor(t)); } tape->RecordOperation( op_->Name(), tape_tensors, input_ids, input_dtypes, diff --git a/tensorflow/c/eager/gradients.h b/tensorflow/c/eager/gradients.h index 04e11291404..f7d80cbeb34 100644 --- a/tensorflow/c/eager/gradients.h +++ b/tensorflow/c/eager/gradients.h @@ -80,7 +80,6 @@ struct ForwardOperation { std::vector inputs; std::vector outputs; AttrBuilder attrs; - AbstractContext* ctx; }; // Interface for building default zeros gradients for op outputs which are @@ -181,10 +180,6 @@ int64 ToId(AbstractTensorHandle* t); // allow us to trace the data dependencies between operations and hence compute // gradients. // -// This also implements `OnesLike` to create the default -// incoming gradients for tensors which do not already have an incoming -// gradient. -// // `ZerosLike` is not expected to be called and returns a nullptr. The creation // of default zeros grads is handled by the `DefaultGradientFunction` registered // for each op. @@ -193,20 +188,19 @@ int64 ToId(AbstractTensorHandle* t); // TODO(srbs): Should ZerosLike check-fail instead of returning nullptr? class TapeTensor { public: - TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx); + explicit TapeTensor(AbstractTensorHandle* handle); TapeTensor(const TapeTensor& other); ~TapeTensor(); tensorflow::int64 GetID() const; tensorflow::DataType GetDType() const; - AbstractTensorHandle* OnesLike() const; AbstractTensorHandle* ZerosLike() const; + AbstractTensorHandle* GetHandle() const; + private: AbstractTensorHandle* handle_; - // The context where OnesLike ops are to be created. - AbstractContext* ctx_; }; // Vector space for actually computing gradients. Implements methods for calling @@ -234,6 +228,10 @@ class TapeVSpace gtl::ArraySlice output_gradients, std::vector* result) const override; + // Builds a tensor filled with ones with the same shape and dtype as `t`. + Status BuildOnesLike(const TapeTensor& t, + AbstractTensorHandle** result) const override; + // Looks up the ID of a Gradient. int64 TensorId(AbstractTensorHandle* tensor) const override; diff --git a/tensorflow/c/eager/gradients_test.cc b/tensorflow/c/eager/gradients_test.cc index 80b1f157074..7fafd6eaa07 100644 --- a/tensorflow/c/eager/gradients_test.cc +++ b/tensorflow/c/eager/gradients_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_context.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_test_util.h" @@ -26,7 +27,9 @@ limitations under the License. 
#include "tensorflow/c/eager/gradients_internal.h" #include "tensorflow/c/experimental/gradients/array_grad.h" #include "tensorflow/c/experimental/gradients/math_grad.h" +#include "tensorflow/c/experimental/gradients/tape/tape_context.h" #include "tensorflow/c/experimental/ops/array_ops.h" +#include "tensorflow/c/experimental/ops/math_ops.h" #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/c/tf_tensor.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" @@ -38,84 +41,32 @@ namespace gradients { namespace internal { namespace { using std::vector; +using tensorflow::TF_StatusPtr; using tracing::TracingOperation; class CppGradients : public ::testing::TestWithParam> { protected: void SetUp() override { - TF_SetTracingImplementation(std::get<0>(GetParam())); + TF_StatusPtr status(TF_NewStatus()); + TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); + Status s = StatusFromTF_Status(status.get()); + CHECK_EQ(errors::OK, s.code()) << s.error_message(); } }; Status RegisterGradients(GradientRegistry* registry) { - TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer)); + // TODO(srbs): Rename ops::Add to ops::AddV2 and AddRegister to + // AddV2Registerer. + TF_RETURN_IF_ERROR(registry->Register("AddV2", AddRegisterer)); TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer)); TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer)); + TF_RETURN_IF_ERROR(registry->Register("Sqrt", SqrtRegisterer)); + TF_RETURN_IF_ERROR(registry->Register("Neg", NegRegisterer)); + TF_RETURN_IF_ERROR(registry->Register("Sub", SubRegisterer)); return Status::OK(); } -// Computes `inputs[0] + inputs[1]` and records it on the tape. -Status Add(AbstractContext* ctx, Tape* tape, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry) { - AbstractOperationPtr add_op(ctx->CreateOperation()); - ForwardOperation forward_op; - forward_op.ctx = ctx; - TF_RETURN_IF_ERROR( - Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op)); - if (isa(add_op.get())) { - TF_RETURN_IF_ERROR( - dyn_cast(add_op.get())->SetOpName("my_add")); - } - TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op)); - TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op)); - int num_retvals = 1; - return Execute(add_op.get(), ctx, outputs, &num_retvals, &forward_op, tape, - registry); -} - -// Computes `exp(inputs[0])` and records it on the tape. -Status Exp(AbstractContext* ctx, Tape* tape, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry) { - AbstractOperationPtr exp_op(ctx->CreateOperation()); - ForwardOperation forward_op; - forward_op.ctx = ctx; - TF_RETURN_IF_ERROR( - Reset(exp_op.get(), "Exp", /*raw_device_name=*/nullptr, &forward_op)); - if (isa(exp_op.get())) { - TF_RETURN_IF_ERROR( - dyn_cast(exp_op.get())->SetOpName("my_exp")); - } - TF_RETURN_IF_ERROR(AddInput(exp_op.get(), inputs[0], &forward_op)); - int num_retvals = 1; - return Execute(exp_op.get(), ctx, outputs, &num_retvals, &forward_op, tape, - registry); -} - -// Computes `IdentityN(inputs)` and records it on the tape. 
-Status IdentityN(AbstractContext* ctx, Tape* tape, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry) { - AbstractOperationPtr identity_n_op(ctx->CreateOperation()); - ForwardOperation forward_op; - forward_op.ctx = ctx; - TF_RETURN_IF_ERROR(Reset(identity_n_op.get(), "IdentityN", - /*raw_device_name=*/nullptr, &forward_op)); - if (isa(identity_n_op.get())) { - TF_RETURN_IF_ERROR(dyn_cast(identity_n_op.get()) - ->SetOpName("my_identity_n")); - } - TF_RETURN_IF_ERROR(AddInputList(identity_n_op.get(), inputs, &forward_op)); - int num_retvals = outputs.size(); - return Execute(identity_n_op.get(), ctx, outputs, &num_retvals, &forward_op, - tape, registry); -} - // Computes // y = inputs[0] + inputs[1] // return grad(y, {inputs[0], inputs[1]}) @@ -128,8 +79,10 @@ Status AddGradModel(AbstractContext* ctx, tape->Watch(ToId(inputs[0])); // Watch x. tape->Watch(ToId(inputs[1])); // Watch y. std::vector add_outputs(1); - TF_RETURN_IF_ERROR(Add(ctx, tape, inputs, absl::MakeSpan(add_outputs), - registry)); // Compute x+y. + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR(ops::Add(tape_ctx.get(), inputs, + absl::MakeSpan(add_outputs), + "Add")); // Compute x+y. std::unordered_map source_tensors_that_are_targets; @@ -160,8 +113,9 @@ Status ExpGradModel(AbstractContext* ctx, auto tape = new Tape(/*persistent=*/false); tape->Watch(ToId(inputs[0])); // Watch x. std::vector exp_outputs(1); - TF_RETURN_IF_ERROR(Exp(ctx, tape, inputs, absl::MakeSpan(exp_outputs), - registry)); // Compute x+y. + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR( + ops::Exp(tape_ctx.get(), inputs, absl::MakeSpan(exp_outputs), "Exp")); std::unordered_map source_tensors_that_are_targets; @@ -179,6 +133,37 @@ Status ExpGradModel(AbstractContext* ctx, return Status::OK(); } +// Computes +// y = sqrt(inputs[0]) +// return grad(y, {inputs[0]}) +Status SqrtGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + TapeVSpace vspace(ctx); + auto tape = new Tape(/*persistent=*/false); + tape->Watch(ToId(inputs[0])); // Watch x. 
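+  // [Editorial annotation, not part of the patch] The TapeContext created
+  // below is what replaces the hand-rolled Add/Exp/IdentityN tracing helpers
+  // deleted above: any op executed through tape_ctx (here ops::Sqrt) is
+  // recorded on `tape` against `registry`, so the tape->ComputeGradient call
+  // that follows can use the Sqrt gradient registered via SqrtRegisterer.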
+ std::vector sqrt_outputs(1); + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR( + ops::Sqrt(tape_ctx.get(), inputs, absl::MakeSpan(sqrt_outputs), "Sqrt")); + std::unordered_map + source_tensors_that_are_targets; + + std::vector out_grads; + TF_RETURN_IF_ERROR(tape->ComputeGradient( + vspace, /*target_tensor_ids=*/{ToId(sqrt_outputs[0])}, + /*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets, + /*output_gradients=*/{}, &out_grads, + /*build_default_zeros_grads=*/false)); + for (auto sqrt_output : sqrt_outputs) { + sqrt_output->Unref(); + } + outputs[0] = out_grads[0]; + delete tape; + return Status::OK(); +} + // Computes // ignored, y = IdentityN(inputs[0], inputs[1]) // return grad(y, {inputs[0], inputs[1]}) @@ -193,8 +178,9 @@ Status IdentityNGradModel(AbstractContext* ctx, tape->Watch(ToId(inputs[1])); vector identity_n_outputs(2); - TF_RETURN_IF_ERROR(IdentityN(ctx, tape, inputs, - absl::MakeSpan(identity_n_outputs), registry)); + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR(ops::IdentityN( + tape_ctx.get(), inputs, absl::MakeSpan(identity_n_outputs), "IdentityN")); std::unordered_map source_tensors_that_are_targets; @@ -214,6 +200,73 @@ Status IdentityNGradModel(AbstractContext* ctx, return Status::OK(); } +// Computes +// y = - inputs[0] +// return grad(y, {inputs[0]}) +Status NegGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + TapeVSpace vspace(ctx); + auto tape = new Tape(/*persistent=*/false); + tape->Watch(ToId(inputs[0])); + + std::vector neg_outputs(1); + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR( + ops::Neg(tape_ctx.get(), inputs, absl::MakeSpan(neg_outputs), "Neg")); + + std::unordered_map + source_tensors_that_are_targets; + std::vector out_grads; + TF_RETURN_IF_ERROR(tape->ComputeGradient( + vspace, /*target_tensor_ids=*/{ToId(neg_outputs[0])}, + /*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets, + /*output_gradients=*/{}, &out_grads, + /*build_default_zeros_grads=*/false)); + for (auto neg_output : neg_outputs) { + neg_output->Unref(); + } + outputs[0] = out_grads[0]; + delete tape; + return Status::OK(); +} + +// Computes +// y = inputs[0] - inputs[1] +// return grad(y, {inputs[0], inputs[1]}) +Status SubGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + TapeVSpace vspace(ctx); + auto tape = new Tape(/*persistent=*/false); + tape->Watch(ToId(inputs[0])); // Watch x. + tape->Watch(ToId(inputs[1])); // Watch y. + std::vector sub_outputs(1); + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR(ops::Sub(tape_ctx.get(), inputs, + absl::MakeSpan(sub_outputs), + "Sub")); // Compute x-y. 
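+  // [Editorial annotation, not part of the patch] For z = x - y the registered
+  // Sub gradient yields dz/dx = 1 and dz/dy = -1, which is what TestSubGrad
+  // later asserts on outputs[0] and outputs[1].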
+ std::unordered_map + source_tensors_that_are_targets; + + std::vector out_grads; + TF_RETURN_IF_ERROR(tape->ComputeGradient( + vspace, /*target_tensor_ids=*/{ToId(sub_outputs[0])}, + /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])}, + source_tensors_that_are_targets, + /*output_gradients=*/{}, &out_grads, + /*build_default_zeros_grads=*/false)); + for (auto sub_output : sub_outputs) { + sub_output->Unref(); + } + outputs[0] = out_grads[0]; + outputs[1] = out_grads[1]; + delete tape; + return Status::OK(); +} + AbstractContext* BuildFunction(const char* fn_name) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); @@ -448,6 +501,50 @@ TEST_P(CppGradients, TestExpGrad) { result_tensor = nullptr; } +TEST_P(CppGradients, TestSqrtGrad) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + AbstractTensorHandlePtr x; + { + AbstractTensorHandle* x_raw = nullptr; + Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + x.reset(x_raw); + } + + GradientRegistry registry; + Status s = RegisterGradients(®istry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + // Pseudo-code: + // + // tape.watch(x) + // y = sqrt(x) + // outputs = tape.gradient(y, x) + std::vector outputs(1); + s = RunModel(SqrtGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs), + /*use_function=*/!std::get<2>(GetParam()), registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + TF_Tensor* result_tensor; + s = getValue(outputs[0], &result_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + auto result_value = static_cast(TF_TensorData(result_tensor)); + EXPECT_NEAR(*result_value, 0.5, 0.001); + outputs[0]->Unref(); + TF_DeleteTensor(result_tensor); + result_tensor = nullptr; +} + TEST_P(CppGradients, TestIdentityNGrad) { // Pseudo-code: // @@ -507,6 +604,161 @@ TEST_P(CppGradients, TestIdentityNGrad) { result_tensor = nullptr; } +TEST_P(CppGradients, TestNegGrad) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + AbstractTensorHandlePtr x; + { + AbstractTensorHandle* x_raw = nullptr; + Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + x.reset(x_raw); + } + + GradientRegistry registry; + Status s = RegisterGradients(®istry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + // Pseudo-code: + // + // tape.watch(x) + // y = - x + // outputs = tape.gradient(y, x) + std::vector outputs(1); + s = RunModel(NegGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs), + /*use_function=*/!std::get<2>(GetParam()), registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + TF_Tensor* result_tensor; + s = getValue(outputs[0], &result_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + auto result_value = static_cast(TF_TensorData(result_tensor)); + EXPECT_EQ(*result_value, -1.0); + outputs[0]->Unref(); + TF_DeleteTensor(result_tensor); + result_tensor = nullptr; +} + +TEST_P(CppGradients, TestSubGrad) { + std::unique_ptr status( + 
TF_NewStatus(), TF_DeleteStatus); + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + AbstractTensorHandlePtr x; + { + AbstractTensorHandle* x_raw = nullptr; + Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + x.reset(x_raw); + } + + AbstractTensorHandlePtr y; + { + AbstractTensorHandle* y_raw = nullptr; + Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &y_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + y.reset(y_raw); + } + + GradientRegistry registry; + Status s = RegisterGradients(®istry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + // Pseudo-code: + // + // tape.watch(x) + // tape.watch(y) + // y = x - y + // outputs = tape.gradient(y, [x, y]) + std::vector outputs(2); + s = RunModel(SubGradModel, ctx.get(), {x.get(), y.get()}, + absl::MakeSpan(outputs), + /*use_function=*/!std::get<2>(GetParam()), registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + TF_Tensor* result_tensor; + s = getValue(outputs[0], &result_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + auto result_value = static_cast(TF_TensorData(result_tensor)); + EXPECT_EQ(*result_value, 1.0); + outputs[0]->Unref(); + TF_DeleteTensor(result_tensor); + result_tensor = nullptr; + + s = getValue(outputs[1], &result_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + result_value = static_cast(TF_TensorData(result_tensor)); + EXPECT_EQ(*result_value, -1.0); + outputs[1]->Unref(); + TF_DeleteTensor(result_tensor); +} + +TEST_P(CppGradients, TestSetAttrString) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + AbstractTensorHandlePtr t; + { + AbstractTensorHandle* x_raw = nullptr; + Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + t.reset(x_raw); + } + + AbstractOperationPtr check_numerics_op(ctx->CreateOperation()); + ForwardOperation forward_op; + Status s = Reset(check_numerics_op.get(), "CheckNumerics", + /*raw_device_name=*/nullptr, &forward_op); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + if (isa(check_numerics_op.get())) { + s = dyn_cast(check_numerics_op.get()) + ->SetOpName("check_numerics"); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + } + s = AddInput(check_numerics_op.get(), t.get(), &forward_op); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + string message = "This is the way!"; + s = SetAttrString(check_numerics_op.get(), "message", message.data(), + message.length(), &forward_op); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + int num_retvals = 1; + std::vector outputs(1); + GradientRegistry registry; + std::unique_ptr tape(new Tape(/*persistent=*/false)); + s = Execute(check_numerics_op.get(), ctx.get(), absl::MakeSpan(outputs), + &num_retvals, &forward_op, tape.get(), registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + string read_message; + s = forward_op.attrs.Get("message", &read_message); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(read_message, message); +} + // 
TODO(b/164171226): Enable this test with tfrt after AddInputList is // supported. It is needed for IdentityN. #ifdef PLATFORM_GOOGLE diff --git a/tensorflow/c/eager/gradients_util.cc b/tensorflow/c/eager/gradients_util.cc new file mode 100644 index 00000000000..e53faf4a3f3 --- /dev/null +++ b/tensorflow/c/eager/gradients_util.cc @@ -0,0 +1,317 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/eager/gradients_util.h" + +#include + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/eager/gradients.h" +#include "tensorflow/c/eager/gradients_internal.h" +#include "tensorflow/c/experimental/ops/array_ops.h" +#include "tensorflow/c/experimental/ops/math_ops.h" +#include "tensorflow/c/experimental/ops/nn_ops.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" +#include "tensorflow/core/platform/errors.h" + +namespace tensorflow { +namespace gradients { + +using namespace std; + +Status ScalarTensorHandleHelper(TFE_Context* ctx, float value, + TFE_TensorHandle** result) { + float data[] = {value}; + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_Tensor* t = + TFE_AllocateHostTensor(ctx, TF_FLOAT, nullptr, 0, status.get()); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status.get()); + *result = th; + TF_DeleteTensor(t); + return StatusFromTF_Status(status.get()); +} + +Status TensorHandleWithDimsFloatHelper(TFE_Context* ctx, float data[], + int64_t dims[], int num_dims, + TFE_TensorHandle** result) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_Tensor* t = + TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], num_dims, status.get()); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status.get()); + *result = th; + TF_DeleteTensor(t); + return StatusFromTF_Status(status.get()); +} + +Status TensorHandleWithDimsIntHelper(TFE_Context* ctx, int data[], + int64_t dims[], int num_dims, + TFE_TensorHandle** result) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_Tensor* t = + TFE_AllocateHostTensor(ctx, TF_INT32, &dims[0], num_dims, status.get()); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status.get()); + *result = th; + TF_DeleteTensor(t); + return StatusFromTF_Status(status.get()); +} + +// Get a scalar TensorHandle with given value +Status ScalarTensorHandle(AbstractContext* ctx, float value, + AbstractTensorHandle** tensor) { + std::unique_ptr status( + 
TF_NewStatus(), TF_DeleteStatus); + TFE_Context* eager_ctx = + TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); + TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); + TFE_TensorHandle* input_eager; + TF_RETURN_IF_ERROR(ScalarTensorHandleHelper(eager_ctx, value, &input_eager)); + *tensor = + unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); + return StatusFromTF_Status(status.get()); +} + +// Get a TensorHandle with given float values and dimensions +Status TensorHandleWithDimsFloat(AbstractContext* ctx, float data[], + int64_t dims[], int num_dims, + AbstractTensorHandle** tensor) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_Context* eager_ctx = + TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); + TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); + TFE_TensorHandle* input_eager; + TF_RETURN_IF_ERROR(TensorHandleWithDimsFloatHelper(eager_ctx, data, dims, + num_dims, &input_eager)); + *tensor = + unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); + return StatusFromTF_Status(status.get()); +} + +// Get a TensorHandle with given int values and dimensions +Status TensorHandleWithDimsInt(AbstractContext* ctx, int data[], int64_t dims[], + int num_dims, AbstractTensorHandle** tensor) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_Context* eager_ctx = + TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); + TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); + TFE_TensorHandle* input_eager; + TF_RETURN_IF_ERROR(TensorHandleWithDimsIntHelper(eager_ctx, data, dims, + num_dims, &input_eager)); + *tensor = + unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); + return StatusFromTF_Status(status.get()); +} + +Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_TensorHandle* result_t = + TF_AbstractTensorGetEagerTensor(wrap(t), status.get()); + TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); + *result_tensor = TFE_TensorHandleResolve(result_t, status.get()); + return StatusFromTF_Status(status.get()); +} + +AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx, + float vals[], int64_t dims[], + int num_dims) { + AbstractTensorHandlePtr A; + AbstractTensorHandle* a_raw = nullptr; + Status s = TensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw); + if (s.ok()) { + A.reset(a_raw); + } + return A; +} + +AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[], + int64_t dims[], int num_dims) { + AbstractTensorHandlePtr A; + AbstractTensorHandle* a_raw = nullptr; + Status s = TensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw); + if (s.ok()) { + A.reset(a_raw); + } + return A; +} + +AbstractTensorHandlePtr GetScalarTensorHandleUtil(AbstractContext* ctx, + float val) { + AbstractTensorHandlePtr y; + AbstractTensorHandle* y_raw = nullptr; + Status s = ScalarTensorHandle(ctx, val, &y_raw); + if (s.ok()) { + y.reset(y_raw); + } + return y; +} + +Status UpdateWeights(AbstractContext* ctx, vector& grads, + vector& weights, + AbstractTensorHandle* learning_rate) { + /* Update weights one by one using gradient update rule: + * + * w -= lr*grad[w] + * + * NOTE: assuming learning rate is positive + */ + + int num_grads = grads.size(); + vector temp_outputs(1); + std::string update_str; + + // Negate learning rate for gradient descent + TF_RETURN_IF_ERROR(ops::Neg(ctx, {learning_rate}, + absl::MakeSpan(temp_outputs), + 
"neg_lr")); // Compute -lr + learning_rate = temp_outputs[0]; + + for (int i = 0; i < num_grads; i++) { + // Compute dW = -lr * grad(w[i]) + update_str = "update_mul_" + std::to_string(i); + TF_RETURN_IF_ERROR(ops::Mul(ctx, {learning_rate, grads[i]}, + absl::MakeSpan(temp_outputs), + update_str.c_str())); + + AbstractTensorHandle* dW = temp_outputs[0]; + + // Compute temp = weights[i] + dW + update_str = "update_add_" + std::to_string(i); + TF_RETURN_IF_ERROR(ops::Add(ctx, {weights[i], dW}, + absl::MakeSpan(temp_outputs), + update_str.c_str())); + + // Update the weights + weights[i] = temp_outputs[0]; + } + + return Status::OK(); +} + +AbstractContext* BuildFunction(const char* fn_name) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get()); + return unwrap(graph_ctx); +} + +Status CreateParamsForInputs(AbstractContext* ctx, + absl::Span inputs, + vector* params) { + tracing::TracingTensorHandle* handle = nullptr; + for (auto input : inputs) { + TF_RETURN_IF_ERROR(dyn_cast(ctx)->AddParameter( + input->DataType(), &handle)); + params->emplace_back(handle); + } + return Status::OK(); +} + +Status RunModel(Model model, AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, bool use_function, + const GradientRegistry& registry) { + if (use_function) { + const char* fn_name = "test_fn"; + std::unique_ptr scoped_func; + // Returning null tensors from a tf.function is not supported, so we keep + // track of indices in the model's outputs are nullptr in this set. + // The FunctionDef only outputs the non-null tensors. We later pad the + // function op outputs to have nullptrs at the `null_indices`. + absl::flat_hash_set null_indices; + { + AbstractContextPtr func_ctx(BuildFunction(fn_name)); + vector func_inputs; + func_inputs.reserve(inputs.size()); + TF_RETURN_IF_ERROR( + CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs)); + vector model_outputs; + model_outputs.resize(outputs.size()); + TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs), + absl::MakeSpan(model_outputs), registry)); + for (auto func_input : func_inputs) { + func_input->Unref(); + } + AbstractFunction* func = nullptr; + OutputList output_list; + output_list.expected_num_outputs = 0; + output_list.outputs.reserve(outputs.size()); + for (int i = 0; i < model_outputs.size(); i++) { + if (model_outputs[i]) { + output_list.outputs.emplace_back(model_outputs[i]); + output_list.expected_num_outputs += 1; + } else { + null_indices.insert(i); + } + } + TF_RETURN_IF_ERROR(dyn_cast(func_ctx.get()) + ->Finalize(&output_list, &func)); + scoped_func.reset(func); + for (auto output : output_list.outputs) { + output->Unref(); + } + TF_RETURN_IF_ERROR(ctx->RegisterFunction(func)); + } + + AbstractOperationPtr fn_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr)); + for (auto input : inputs) { + TF_RETURN_IF_ERROR(fn_op->AddInput(input)); + } + int retvals = outputs.size() - null_indices.size(); + vector fn_outputs(retvals); + TF_RETURN_IF_ERROR(fn_op->Execute( + absl::Span(fn_outputs.data(), fn_outputs.size()), + &retvals)); + int skipped_indices = 0; + for (int i = 0; i < outputs.size(); i++) { + if (!null_indices.contains(i)) { + outputs[i] = fn_outputs[i - skipped_indices]; + } else { + skipped_indices += 1; + } + } + TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name)); + return Status::OK(); + } else { + return model(ctx, inputs, outputs, registry); + } +} + +Status 
BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetTfrt(opts, use_tfrt); + *ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get())); + TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); + TFE_DeleteContextOptions(opts); + return Status::OK(); +} + +} // namespace gradients +} // namespace tensorflow \ No newline at end of file diff --git a/tensorflow/c/eager/gradients_util.h b/tensorflow/c/eager/gradients_util.h new file mode 100644 index 00000000000..cd0bbc0720d --- /dev/null +++ b/tensorflow/c/eager/gradients_util.h @@ -0,0 +1,88 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/eager/gradients.h" +#include "tensorflow/c/eager/gradients_internal.h" +#include "tensorflow/c/experimental/ops/array_ops.h" +#include "tensorflow/c/experimental/ops/math_ops.h" +#include "tensorflow/c/experimental/ops/nn_ops.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace gradients { + +// Get a scalar TensorHandle with given value +Status ScalarTensorHandle(AbstractContext* ctx, float value, + AbstractTensorHandle** tensor); + +// Get a TensorHandle with given float values and dimensions +Status TensorHandleWithDimsFloat(AbstractContext* ctx, float data[], + int64_t dims[], int num_dims, + AbstractTensorHandle** tensor); + +// Get a TensorHandle with given int values and dimensions +Status TensorHandleWithDimsInt(AbstractContext* ctx, int data[], int64_t dims[], + int num_dims, AbstractTensorHandle** tensor); + +// Places data from `t` into *result_tensor. +Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor); + +// Util function that wraps an AbstractTensorHandle* with given data and dims. +AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx, + float vals[], int64_t dims[], + int num_dims); + +// Util function that wraps an AbstractTensorHandle* with given data and dims. +AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[], + int64_t dims[], int num_dims); + +// Util function that wraps an AbstractTensorHandle* with given data. 
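// Sketch of how these helpers compose in a test (all functions below are
// declared in this header; `MyModel` and `registry` stand in for a
// test-provided Model and GradientRegistry, and ownership/error handling is
// collapsed into TF_CHECK_OK for brevity):
//
//   AbstractContext* ctx = nullptr;
//   TF_CHECK_OK(BuildImmediateExecutionContext(/*use_tfrt=*/false, &ctx));
//   float vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
//   int64_t dims[] = {2, 2};
//   AbstractTensorHandlePtr x =
//       GetTensorHandleUtilFloat(ctx, vals, dims, /*num_dims=*/2);
//   std::vector<AbstractTensorHandle*> outputs(1);
//   TF_CHECK_OK(RunModel(MyModel, ctx, {x.get()}, absl::MakeSpan(outputs),
//                        /*use_function=*/false, registry));
//   TF_Tensor* result = nullptr;
//   TF_CHECK_OK(GetValue(outputs[0], &result));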
+AbstractTensorHandlePtr GetScalarTensorHandleUtil(AbstractContext* ctx, + float val); + +// Performs gradient update for each weight using given learning rate. +Status UpdateWeights(AbstractContext* ctx, + std::vector& grads, + std::vector& weights, + AbstractTensorHandle* learning_rate); + +using Model = std::function, + absl::Span, const GradientRegistry&)>; + +// Runs given model in either graph or eager mode depending on value of +// use_function. +Status RunModel(Model model, AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, bool use_function, + const GradientRegistry& registry); + +// Builds context and returns inside *ctx. +Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx); + +} // namespace gradients +} // namespace tensorflow diff --git a/tensorflow/c/eager/immediate_execution_context.h b/tensorflow/c/eager/immediate_execution_context.h index 02a3320ef65..a3e3857b34b 100644 --- a/tensorflow/c/eager/immediate_execution_context.h +++ b/tensorflow/c/eager/immediate_execution_context.h @@ -29,8 +29,25 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/tstring.h" +#include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { +class EagerExecutor; + +// LINT.IfChange +// Note: Keep in sync with exported copy of enum in eager/c_api.h. +enum ContextDevicePlacementPolicy { + // Running operations with input tensors on the wrong device will fail. + DEVICE_PLACEMENT_EXPLICIT = 0, + // Copy the tensor to the right device but log a warning. + DEVICE_PLACEMENT_WARN = 1, + // Silently copy the tensor, which has a performance cost since the operation + // will be blocked till the copy completes. This is the default policy. + DEVICE_PLACEMENT_SILENT = 2, + // Placement policy which silently copies int32 tensors but not other dtypes. + DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3, +}; +// LINT.ThenChange(//tensorflow/c/eager/c_api.h) // Abstract interface to a context. // @@ -81,14 +98,6 @@ class ImmediateExecutionContext : public AbstractContext { // List attributes of available devices virtual void ListDevices(std::vector* devices) = 0; - virtual void ClearCachesAndThreadExecutors() = 0; - - // Initialize the step resource container for a training step. This is used - // in current TF runtime. For tfrt, it is used by fallback op handler. - virtual void StartStep() = 0; - // Destroy the step resource container for a training step. - virtual void EndStep() = 0; - // Block until all pending nodes are finished. virtual Status AsyncWait() = 0; @@ -97,11 +106,52 @@ class ImmediateExecutionContext : public AbstractContext { // already exists. virtual Status AddFunctionDef(const FunctionDef& fdef) = 0; + // Find and return a added function by its name. + virtual const FunctionDef* FindFunctionDef(const string& name) const = 0; + + // Return the ParsedName of Host CPU device. + virtual const DeviceNameUtils::ParsedName& HostCPUParsedName() const = 0; + + // Configure soft device placement policy. + virtual void SetAllowSoftPlacement(bool enable) = 0; + + // Configure device placement policy logging. + virtual void SetLogDevicePlacement(bool enable) = 0; + + // Sets the device placement policy for the current thread. + virtual void SetThreadLocalDevicePlacementPolicy( + ContextDevicePlacementPolicy policy) = 0; + // Returns the device placement policy for the current thread. 
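// As the LINT block above notes, ContextDevicePlacementPolicy is meant to stay
// in sync with TFE_ContextDevicePlacementPolicy in tensorflow/c/eager/c_api.h.
// A compile-time check a translation unit that sees both headers could add is
// sketched below (illustrative only, not something introduced by this change):
//
//   static_assert(static_cast<int>(DEVICE_PLACEMENT_SILENT) ==
//                     static_cast<int>(TFE_DEVICE_PLACEMENT_SILENT),
//                 "ContextDevicePlacementPolicy must match the C API enum");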
+ virtual ContextDevicePlacementPolicy GetDevicePlacementPolicy() const = 0; + // For LLVM style RTTI. static bool classof(const AbstractContext* ptr) { return ptr->getKind() == kEager || ptr->getKind() == kTfrt; } + //===--------------------------------------------------------------------===// + // Following are legacy features in TF Eager Runtime. + // TODO(tf-runtime): Figure out a way to deprecate following features after + // migrated to TFRT. + //===--------------------------------------------------------------------===// + // Clear pending nodes in thread executors and kernel caches. + virtual void ClearCachesAndThreadExecutors() = 0; + + // Initialize the step resource container for a training step. This is used + // in current TF runtime. For tfrt, it is used by fallback op handler. + virtual void StartStep() = 0; + // Destroy the step resource container for a training step. + virtual void EndStep() = 0; + + // Return the Eager Executor for current thread. Please note that Eager + // Executor is only used in current TF but not in TFRT. + virtual EagerExecutor& Executor() = 0; + // Update the Eager Executor for current thread. + virtual void SetExecutorForThread(EagerExecutor* executor) = 0; + + // Configure graph collection in RunMetadata. + virtual void SetShouldStoreGraphs(bool value) = 0; + protected: explicit ImmediateExecutionContext(AbstractContextKind kind) : AbstractContext(kind) {} diff --git a/tensorflow/c/eager/immediate_execution_operation.h b/tensorflow/c/eager/immediate_execution_operation.h index ee212b21a96..7b68ec2c9f4 100644 --- a/tensorflow/c/eager/immediate_execution_operation.h +++ b/tensorflow/c/eager/immediate_execution_operation.h @@ -47,9 +47,6 @@ class ImmediateExecutionOperation : public AbstractOperation { virtual Status InputLength(const char* input_name, int* length) = 0; virtual Status OutputLength(const char* output_name, int* length) = 0; - // Experimental - virtual Status SetUseXla(bool enable) = 0; - // Set stack trace to be used for potential async error reporting. virtual void SetStackTrace(AbstractStackTrace stack_trace) = 0; diff --git a/tensorflow/c/eager/immediate_execution_tensor_handle.h b/tensorflow/c/eager/immediate_execution_tensor_handle.h index 6d32d482747..bb6d471f12f 100644 --- a/tensorflow/c/eager/immediate_execution_tensor_handle.h +++ b/tensorflow/c/eager/immediate_execution_tensor_handle.h @@ -44,6 +44,10 @@ class ImmediateExecutionTensorHandle : public AbstractTensorHandle { virtual const char* DeviceName(Status* status) const = 0; // Returns the device where the tensor was placed. virtual const char* BackingDeviceName(Status* status) const = 0; + // Returns the device type which created the handle. + virtual const char* DeviceType(Status* status) const = 0; + // Returns the device ID which created the handle. + virtual int DeviceId(Status* status) const = 0; // Returns a tensor for the handle. If tensor is remote, it will be copied. virtual AbstractTensorInterface* Resolve(Status* status) = 0; diff --git a/tensorflow/c/eager/mnist_gradients_test.cc b/tensorflow/c/eager/mnist_gradients_test.cc index d6dd94806a7..4114f50a798 100644 --- a/tensorflow/c/eager/mnist_gradients_test.cc +++ b/tensorflow/c/eager/mnist_gradients_test.cc @@ -14,11 +14,11 @@ limitations under the License. 
#include "absl/types/span.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/c_api_experimental.h" -#include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/c/eager/c_api_unified_experimental.h" #include "tensorflow/c/eager/c_api_unified_experimental_internal.h" #include "tensorflow/c/eager/gradients.h" #include "tensorflow/c/eager/gradients_internal.h" +#include "tensorflow/c/eager/gradients_util.h" #include "tensorflow/c/eager/mnist_gradients_testutil.h" #include "tensorflow/c/experimental/gradients/math_grad.h" #include "tensorflow/c/experimental/gradients/nn_grad.h" @@ -33,12 +33,16 @@ namespace tensorflow { namespace gradients { namespace internal { namespace { +using tensorflow::TF_StatusPtr; class CppGradients : public ::testing::TestWithParam> { protected: void SetUp() override { - TF_SetTracingImplementation(std::get<0>(GetParam())); + TF_StatusPtr status(TF_NewStatus()); + TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); + Status s = StatusFromTF_Status(status.get()); + CHECK_EQ(errors::OK, s.code()) << s.error_message(); } }; @@ -49,89 +53,10 @@ Status RegisterGradients(GradientRegistry* registry) { TF_RETURN_IF_ERROR(registry->Register("Relu", ReluRegisterer)); TF_RETURN_IF_ERROR( registry->Register("SparseSoftmaxCrossEntropyWithLogits", - SparseSoftmaxCrossEntropyLossRegisterer)); + SparseSoftmaxCrossEntropyWithLogitsRegisterer)); return Status::OK(); } -// ========================= Test Util Functions ============================== - -// Get a scalar TensorHandle with given value -Status TestScalarTensorHandle(AbstractContext* ctx, float value, - AbstractTensorHandle** tensor) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TFE_Context* eager_ctx = - TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); - TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); - TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value); - *tensor = - unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); - return Status::OK(); -} - -// Get a Matrix TensorHandle with given float values and dimensions -Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float data[], - int64_t dims[], int num_dims, - AbstractTensorHandle** tensor) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TFE_Context* eager_ctx = - TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); - TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); - TFE_TensorHandle* input_eager = - TestTensorHandleWithDimsFloat(eager_ctx, data, dims, num_dims); - *tensor = - unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); - return Status::OK(); -} - -// Get a Matrix TensorHandle with given int values and dimensions -Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int data[], - int64_t dims[], int num_dims, - AbstractTensorHandle** tensor) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TFE_Context* eager_ctx = - TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); - TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); - TFE_TensorHandle* input_eager = - TestTensorHandleWithDimsInt(eager_ctx, data, dims, num_dims); - *tensor = - unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); - return Status::OK(); -} - -Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TFE_TensorHandle* result_t = - TF_AbstractTensorGetEagerTensor(wrap(t), 
status.get()); - TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); - *result_tensor = TFE_TensorHandleResolve(result_t, status.get()); - return Status::OK(); -} - -AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx, - float vals[], int64_t dims[], - int num_dims) { - AbstractTensorHandlePtr A; - AbstractTensorHandle* a_raw = nullptr; - Status s = TestTensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw); - A.reset(a_raw); - return A; -} - -AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[], - int64_t dims[], int num_dims) { - AbstractTensorHandlePtr A; - AbstractTensorHandle* a_raw = nullptr; - Status s = TestTensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw); - A.reset(a_raw); - return A; -} - -// =========================== Start Tests ================================ - TEST_P(CppGradients, TestMatMulGrad) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); @@ -465,6 +390,12 @@ TEST_P(CppGradients, TestReluGrad) { } TEST_P(CppGradients, TestSoftmaxLossGrad) { + bool use_function = !std::get<2>(GetParam()); + if (use_function) { + // TODO(b/168850692): Enable this. + GTEST_SKIP() << "Can't take gradient of " + "SparseSoftmaxCrossEntropyWithLogits in tracing mode."; + } std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); @@ -533,6 +464,12 @@ TEST_P(CppGradients, TestSoftmaxLossGrad) { } TEST_P(CppGradients, TestMNISTGrad) { + bool use_function = !std::get<2>(GetParam()); + if (use_function) { + // TODO(b/168850692): Enable this. + GTEST_SKIP() << "Can't take gradient of " + "SparseSoftmaxCrossEntropyWithLogits in tracing mode."; + } std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); AbstractContextPtr ctx; @@ -603,7 +540,6 @@ TEST_P(CppGradients, TestMNISTGrad) { TF_TensorByteSize(dW1_tensor)); float expected_dW1[4] = {0.0f, 3.2f, 0.0f, 4.8f}; - ; // dLoss for (int j = 0; j < 4; j++) { ASSERT_NEAR(result_data[j], expected_dW1[j], tolerance); } @@ -643,7 +579,7 @@ TEST_P(CppGradients, TestScalarMul) { AbstractTensorHandlePtr eta; { AbstractTensorHandle* x_raw = nullptr; - Status s = TestScalarTensorHandle(ctx.get(), 1.5f, &x_raw); + Status s = ScalarTensorHandle(ctx.get(), 1.5f, &x_raw); ASSERT_EQ(errors::OK, s.code()) << s.error_message(); eta.reset(x_raw); } @@ -681,6 +617,12 @@ TEST_P(CppGradients, TestScalarMul) { } TEST_P(CppGradients, TestMNIST_Training) { + bool use_function = !std::get<2>(GetParam()); + if (use_function) { + // TODO(b/168850692): Enable this. 
+ GTEST_SKIP() << "Can't take gradient of " + "SparseSoftmaxCrossEntropyWithLogits in tracing mode."; + } std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); @@ -733,7 +675,7 @@ TEST_P(CppGradients, TestMNIST_Training) { // Set learning rate to be 1e-1 AbstractTensorHandle* learning_rate = nullptr; - s = TestScalarTensorHandle(ctx.get(), 1e-1, &learning_rate); + s = ScalarTensorHandle(ctx.get(), 1e-1, &learning_rate); ASSERT_EQ(errors::OK, s.code()) << s.error_message(); // Train @@ -765,13 +707,13 @@ TEST_P(CppGradients, TestMNIST_Training) { #ifdef PLATFORM_GOOGLE INSTANTIATE_TEST_SUITE_P( UnifiedCAPI, CppGradients, - ::testing::Combine(::testing::Values("graphdef"), + ::testing::Combine(::testing::Values("graphdef", "mlir"), /*tfrt*/ ::testing::Values(false), /*executing_eagerly*/ ::testing::Values(true, false))); #else INSTANTIATE_TEST_SUITE_P( UnifiedCAPI, CppGradients, - ::testing::Combine(::testing::Values("graphdef"), + ::testing::Combine(::testing::Values("graphdef", "mlir"), /*tfrt*/ ::testing::Values(false), /*executing_eagerly*/ ::testing::Values(true, false))); #endif diff --git a/tensorflow/c/eager/mnist_gradients_testutil.cc b/tensorflow/c/eager/mnist_gradients_testutil.cc index 4b2c87c678d..6688d9d4e75 100644 --- a/tensorflow/c/eager/mnist_gradients_testutil.cc +++ b/tensorflow/c/eager/mnist_gradients_testutil.cc @@ -24,136 +24,19 @@ limitations under the License. #include "tensorflow/c/eager/c_api_unified_experimental_internal.h" #include "tensorflow/c/eager/gradients.h" #include "tensorflow/c/eager/gradients_internal.h" +#include "tensorflow/c/eager/gradients_util.h" +#include "tensorflow/c/experimental/gradients/tape/tape_context.h" #include "tensorflow/c/experimental/ops/array_ops.h" #include "tensorflow/c/experimental/ops/math_ops.h" #include "tensorflow/c/experimental/ops/nn_ops.h" -#include "tensorflow/c/tf_status_helper.h" -#include "tensorflow/c/tf_tensor.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" + +namespace tensorflow { +namespace gradients { +namespace internal { + using std::vector; -using tracing::TracingOperation; - -// ========================== Tape Ops ============================== - -// Computes `inputs[0] + inputs[1]` and records it on the tape. -Status Add(AbstractContext* ctx, Tape* tape, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry) { - AbstractOperationPtr add_op(ctx->CreateOperation()); - ForwardOperation forward_op; - forward_op.ctx = ctx; - TF_RETURN_IF_ERROR( - Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op)); - if (isa(add_op.get())) { - TF_RETURN_IF_ERROR( - dyn_cast(add_op.get())->SetOpName("my_add")); - } - TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op)); - TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op)); - int num_retvals = 1; - return Execute(add_op.get(), ctx, outputs, &num_retvals, &forward_op, tape, - registry); -} - -// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape. 
-Status MatMul(AbstractContext* ctx, Tape* tape, - absl::Span inputs, - absl::Span outputs, const char* name, - bool transpose_a, bool transpose_b, - const GradientRegistry& registry) { - AbstractOperationPtr matmul_op(ctx->CreateOperation()); - ForwardOperation forward_op; - forward_op.ctx = ctx; - TF_RETURN_IF_ERROR(Reset(matmul_op.get(), "MatMul", - /*raw_device_name=*/nullptr, &forward_op)); - if (isa(matmul_op.get())) { - TF_RETURN_IF_ERROR( - dyn_cast(matmul_op.get())->SetOpName(name)); - } - - TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[0], &forward_op)); - TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[1], &forward_op)); - TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool( - matmul_op.get(), "transpose_a", transpose_a, &forward_op)); - TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool( - matmul_op.get(), "transpose_b", transpose_b, &forward_op)); - - int num_retvals = 1; - return Execute(matmul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape, - registry); -} - -Status Mul(AbstractContext* ctx, Tape* tape, - absl::Span inputs, - absl::Span outputs, const char* name, - const GradientRegistry& registry) { - AbstractOperationPtr mul_op(ctx->CreateOperation()); - ForwardOperation forward_op; - forward_op.ctx = ctx; - TF_RETURN_IF_ERROR( - Reset(mul_op.get(), "Mul", /*raw_device_name=*/nullptr, &forward_op)); - if (isa(mul_op.get())) { - TF_RETURN_IF_ERROR( - dyn_cast(mul_op.get())->SetOpName(name)); - } - - TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[0], &forward_op)); - TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[1], &forward_op)); - - int num_retvals = 1; - return Execute(mul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape, - registry); -} - -// Computes `Relu(inputs[0])` and records it on the tape. -Status Relu(AbstractContext* ctx, Tape* tape, - absl::Span inputs, - absl::Span outputs, const char* name, - const GradientRegistry& registry) { - AbstractOperationPtr relu_op(ctx->CreateOperation()); - ForwardOperation forward_op; - forward_op.ctx = ctx; - TF_RETURN_IF_ERROR( - Reset(relu_op.get(), "Relu", /*raw_device_name=*/nullptr, &forward_op)); - if (isa(relu_op.get())) { - TF_RETURN_IF_ERROR( - dyn_cast(relu_op.get())->SetOpName(name)); - } - TF_RETURN_IF_ERROR(AddInput(relu_op.get(), inputs[0], &forward_op)); - int num_retvals = 1; - return Execute(relu_op.get(), ctx, outputs, &num_retvals, &forward_op, tape, - registry); -} - -// Computes `SoftmaxLoss(scores, labels)` for matrices and records it on the -// tape. 
-Status SparseSoftmaxCrossEntropyLoss( - AbstractContext* ctx, Tape* tape, - absl::Span inputs, - absl::Span outputs, const char* name, - const GradientRegistry& registry) { - AbstractTensorHandle* scores = inputs[0]; - AbstractTensorHandle* labels = inputs[1]; - - AbstractOperationPtr sm_op(ctx->CreateOperation()); - ForwardOperation forward_op; - forward_op.ctx = ctx; - TF_RETURN_IF_ERROR(Reset(sm_op.get(), "SparseSoftmaxCrossEntropyWithLogits", - /*raw_device_name=*/nullptr, &forward_op)); - if (isa(sm_op.get())) { - TF_RETURN_IF_ERROR( - dyn_cast(sm_op.get())->SetOpName(name)); - } - - TF_RETURN_IF_ERROR(AddInput(sm_op.get(), scores, &forward_op)); - TF_RETURN_IF_ERROR(AddInput(sm_op.get(), labels, &forward_op)); - - int num_retvals = 2; // returns loss values and backprop - return Execute(sm_op.get(), ctx, outputs, &num_retvals, &forward_op, tape, - registry); -} //===================== Test Models to run ========================= @@ -169,8 +52,9 @@ Status AddGradModel(AbstractContext* ctx, tape->Watch(ToId(inputs[0])); // Watch x. tape->Watch(ToId(inputs[1])); // Watch y. std::vector add_outputs(1); - TF_RETURN_IF_ERROR(Add(ctx, tape, inputs, absl::MakeSpan(add_outputs), - registry)); // Compute x+y. + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR( + ops::Add(tape_ctx.get(), inputs, absl::MakeSpan(add_outputs), "Add")); std::unordered_map source_tensors_that_are_targets; @@ -202,9 +86,11 @@ Status MatMulGradModel(AbstractContext* ctx, tape->Watch(ToId(inputs[0])); // Watch x. tape->Watch(ToId(inputs[1])); // Watch y. vector mm_outputs(1); - TF_RETURN_IF_ERROR(MatMul(ctx, tape, inputs, absl::MakeSpan(mm_outputs), - "matmul0", /*transpose_a=*/false, - /*transpose_b=*/false, registry)); // Compute x*y. + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), inputs, + absl::MakeSpan(mm_outputs), "matmul0", + /*transpose_a=*/false, + /*transpose_b=*/false)); // Compute x*y. std::unordered_map source_tensors_that_are_targets; @@ -238,8 +124,9 @@ Status MNISTForwardModel(AbstractContext* ctx, * hidden_layer = tf.nn.relu(mm_out_1) * scores = tf.matmul(hidden_layer,W2) * softmax = - * tf.nn.sparse_softmax_cross_entropy_with_logits(scores,y_labels) return - * scores, softmax + * tf.nn.sparse_softmax_cross_entropy_with_logits(scores, + * y_labels) + * return scores, softmax * * Use this convention for inputs: * @@ -257,24 +144,27 @@ Status MNISTForwardModel(AbstractContext* ctx, tape->Watch(ToId(W2)); // Watch W2. 
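// Forward pass assembled below, matching the pseudo-code in the comment above:
// scores = matmul(relu(matmul(X, W1)), W2) and
// loss   = sparse_softmax_cross_entropy_with_logits(scores, y_labels),
// with every op now issued through tape_ctx so the tape records it for
// backprop.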
vector temp_outputs(1); - TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs), - "matmul0", /*transpose_a=*/false, - /*transpose_b=*/false, registry)); // Compute X*W1 + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1}, + absl::MakeSpan(temp_outputs), "matmul0", + /*transpose_a=*/false, + /*transpose_b=*/false)); // Compute X*W1 - TF_RETURN_IF_ERROR(Relu(ctx, tape, {temp_outputs[0]}, - absl::MakeSpan(temp_outputs), "relu", - registry)); // Compute Relu(X*W1) + TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {temp_outputs[0]}, + absl::MakeSpan(temp_outputs), + "relu")); // Compute Relu(X*W1) - TF_RETURN_IF_ERROR(MatMul(ctx, tape, {temp_outputs[0], W2}, - absl::MakeSpan(temp_outputs), "matmul1", - /*transpose_a=*/false, /*transpose_b=*/false, - registry)); // Compute W2*Relu(X*W1) + TF_RETURN_IF_ERROR(ops::MatMul( + tape_ctx.get(), {temp_outputs[0], W2}, absl::MakeSpan(temp_outputs), + "matmul1", + /*transpose_a=*/false, /*transpose_b=*/false)); // Compute W2*Relu(X*W1) AbstractTensorHandle* scores = temp_outputs[0]; - TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss( - ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs), - "softmax_loss", registry)); // Compute Softmax(Scores,labels) + temp_outputs.resize(2); + TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits( + tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs), + "softmax_loss")); // Compute Softmax(Scores,labels) AbstractTensorHandle* loss_vals = temp_outputs[0]; @@ -297,9 +187,11 @@ Status MatMulTransposeModel(AbstractContext* ctx, tape->Watch(ToId(W1)); vector temp_outputs(1); - TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs), - "matmul0", /*transpose_a=*/true, - /*transpose_b=*/false, registry)); // Compute X*W1 + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1}, + absl::MakeSpan(temp_outputs), "matmul0", + /*transpose_a=*/true, + /*transpose_b=*/false)); // Compute X*W1 outputs[0] = temp_outputs[0]; @@ -315,8 +207,10 @@ Status ReluGradModel(AbstractContext* ctx, auto tape = new Tape(/*persistent=*/false); tape->Watch(ToId(inputs[0])); // Watch X vector relu_outputs(1); - TF_RETURN_IF_ERROR(Relu(ctx, tape, inputs, absl::MakeSpan(relu_outputs), - "relu0", registry)); // Relu(X) + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), inputs, + absl::MakeSpan(relu_outputs), + "relu0")); // Relu(X) std::unordered_map source_tensors_that_are_targets; @@ -346,8 +240,9 @@ Status SoftmaxLossGradModel(AbstractContext* ctx, tape->Watch(ToId(inputs[0])); // Watch scores. tape->Watch(ToId(inputs[1])); // Watch labels. vector sm_outputs(2); - TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss( - ctx, tape, inputs, absl::MakeSpan(sm_outputs), "softmax0", registry)); + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits( + tape_ctx.get(), inputs, absl::MakeSpan(sm_outputs), "softmax0")); std::unordered_map source_tensors_that_are_targets; @@ -381,29 +276,30 @@ Status MNISTGradModel(AbstractContext* ctx, tape->Watch(ToId(W1)); // Watch W1. tape->Watch(ToId(W2)); // Watch W1. 
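// MNISTGradModel re-runs the same forward computation as MNISTForwardModel,
// but through the tape context below, and (in the unchanged remainder of this
// function, outside this hunk) asks the tape for gradients with respect to W1
// and W2, which is why both weights are watched here. TestMNISTGrad above
// checks the resulting dW1 against {0.0f, 3.2f, 0.0f, 4.8f}.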
vector temp_outputs(1); - TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs), - "matmul0", /*transpose_a=*/false, - /*transpose_b=*/false, registry)); // Compute X*W1 + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1}, + absl::MakeSpan(temp_outputs), "matmul0", + /*transpose_a=*/false, + /*transpose_b=*/false)); // Compute X*W1 AbstractTensorHandle* mm = temp_outputs[0]; - TF_RETURN_IF_ERROR(Relu(ctx, tape, {mm}, - absl::MakeSpan(temp_outputs), // Relu(X*W1) - "relu0", registry)); + TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {mm}, + absl::MakeSpan(temp_outputs), // Relu(X*W1) + "relu0")); AbstractTensorHandle* hidden = temp_outputs[0]; - TF_RETURN_IF_ERROR(MatMul(ctx, tape, {hidden, W2}, - absl::MakeSpan(temp_outputs), "matmul1", - /*transpose_a=*/false, /*transpose_b=*/false, - registry)); // W2*Relu(X*W1) + TF_RETURN_IF_ERROR(ops::MatMul( + tape_ctx.get(), {hidden, W2}, absl::MakeSpan(temp_outputs), "matmul1", + /*transpose_a=*/false, /*transpose_b=*/false)); // W2*Relu(X*W1) AbstractTensorHandle* scores = temp_outputs[0]; temp_outputs.resize(2); - TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss( - ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs), - "softmaxloss", registry)); // W2*Relu(X*W1) + TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits( + tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs), + "softmaxloss")); // W2*Relu(X*W1) AbstractTensorHandle* loss = temp_outputs[0]; @@ -440,8 +336,10 @@ Status ScalarMulModel(AbstractContext* ctx, auto tape = new Tape(/*persistent=*/false); vector temp_outputs(1); - TF_RETURN_IF_ERROR(Mul(ctx, tape, {eta, A}, absl::MakeSpan(temp_outputs), - "scalarMul0", registry)); // Compute eta*A + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), {eta, A}, + absl::MakeSpan(temp_outputs), + "scalarMul0")); // Compute eta*A outputs[0] = temp_outputs[0]; @@ -449,146 +347,69 @@ Status ScalarMulModel(AbstractContext* ctx, return Status::OK(); } +Status MatMulModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + AbstractTensorHandle* X = inputs[0]; + AbstractTensorHandle* W1 = inputs[1]; + + TapeVSpace vspace(ctx); + auto tape = new Tape(/*persistent=*/false); + std::vector temp_outputs(1); + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1}, + absl::MakeSpan(temp_outputs), "matmul0", + /*transpose_a=*/false, + /*transpose_b=*/false)); // Compute X*W1 + + outputs[0] = temp_outputs[0]; + delete tape; + return Status::OK(); +} + +Status MulModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + AbstractTensorHandle* x = inputs[0]; + AbstractTensorHandle* y = inputs[1]; + + TapeVSpace vspace(ctx); + auto tape = new Tape(/*persistent=*/false); + std::vector temp_outputs(1); + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), {x, y}, + absl::MakeSpan(temp_outputs), + "mul0")); // Compute x*y + + outputs[0] = temp_outputs[0]; + delete tape; + return Status::OK(); +} + +Status SoftmaxModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + AbstractTensorHandle* x = inputs[0]; + AbstractTensorHandle* labels = inputs[1]; + + TapeVSpace vspace(ctx); + auto tape = new 
Tape(/*persistent=*/false); + std::vector temp_outputs(2); + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits( + tape_ctx.get(), {x, labels}, absl::MakeSpan(temp_outputs), "sm_loss")); + + outputs[0] = temp_outputs[0]; // loss values + + delete tape; + return Status::OK(); +} + // ============================= End Models ================================ -Status UpdateWeights(AbstractContext* ctx, vector& grads, - vector& weights, - AbstractTensorHandle* learning_rate) { - /* Update weights one by one using gradient update rule: - * - * w -= lr*grad[w] - * - * NOTE: assuming learning rate is positive - */ - - Status s; - int num_grads = grads.size(); - vector temp_outputs(1); - std::string update_str; - - // Negate learning rate for gradient descent - TF_RETURN_IF_ERROR(ops::Neg(ctx, {learning_rate}, - absl::MakeSpan(temp_outputs), - "neg_lr")); // Compute -lr - learning_rate = temp_outputs[0]; - - for (int i = 0; i < num_grads; i++) { - // Compute dW = -lr * grad(w[i]) - update_str = "update_mul_" + std::to_string(i); - s = ops::Mul(ctx, {learning_rate, grads[i]}, absl::MakeSpan(temp_outputs), - update_str.c_str()); - - AbstractTensorHandle* dW = temp_outputs[0]; - - // Compute temp = weights[i] + dW - update_str = "update_add_" + std::to_string(i); - s = ops::Add(ctx, {weights[i], dW}, absl::MakeSpan(temp_outputs), - update_str.c_str()); - - // Update the weights - weights[i] = temp_outputs[0]; - } - - return Status::OK(); -} - -AbstractContext* BuildFunction(const char* fn_name) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get()); - return unwrap(graph_ctx); -} - -Status CreateParamsForInputs(AbstractContext* ctx, - absl::Span inputs, - vector* params) { - tracing::TracingTensorHandle* handle = nullptr; - for (auto input : inputs) { - TF_RETURN_IF_ERROR(dyn_cast(ctx)->AddParameter( - input->DataType(), &handle)); - params->emplace_back(handle); - } - return Status::OK(); -} - -Status RunModel(Model model, AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, bool use_function, - const GradientRegistry& registry) { - if (use_function) { - const char* fn_name = "test_fn"; - std::unique_ptr scoped_func; - // Returning null tensors from a tf.function is not supported, so we keep - // track of indices in the model's outputs are nullptr in this set. - // The FunctionDef only outputs the non-null tensors. We later pad the - // function op outputs to have nullptrs at the `null_indices`. 
- absl::flat_hash_set null_indices; - { - AbstractContextPtr func_ctx(BuildFunction(fn_name)); - vector func_inputs; - func_inputs.reserve(inputs.size()); - TF_RETURN_IF_ERROR( - CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs)); - vector model_outputs; - model_outputs.resize(outputs.size()); - TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs), - absl::MakeSpan(model_outputs), registry)); - for (auto func_input : func_inputs) { - func_input->Unref(); - } - AbstractFunction* func = nullptr; - OutputList output_list; - output_list.expected_num_outputs = 0; - output_list.outputs.reserve(outputs.size()); - for (int i = 0; i < model_outputs.size(); i++) { - if (model_outputs[i]) { - output_list.outputs.emplace_back(model_outputs[i]); - output_list.expected_num_outputs += 1; - } else { - null_indices.insert(i); - } - } - TF_RETURN_IF_ERROR(dyn_cast(func_ctx.get()) - ->Finalize(&output_list, &func)); - scoped_func.reset(func); - for (auto output : output_list.outputs) { - output->Unref(); - } - TF_RETURN_IF_ERROR(ctx->RegisterFunction(func)); - } - - AbstractOperationPtr fn_op(ctx->CreateOperation()); - TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr)); - for (auto input : inputs) { - TF_RETURN_IF_ERROR(fn_op->AddInput(input)); - } - int retvals = outputs.size() - null_indices.size(); - vector fn_outputs(retvals); - TF_RETURN_IF_ERROR(fn_op->Execute( - absl::Span(fn_outputs.data(), fn_outputs.size()), - &retvals)); - int skipped_indices = 0; - for (int i = 0; i < outputs.size(); i++) { - if (!null_indices.contains(i)) { - outputs[i] = fn_outputs[i - skipped_indices]; - } else { - skipped_indices += 1; - } - } - TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name)); - return Status::OK(); - } else { - return model(ctx, inputs, outputs, registry); - } -} - -Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_ContextOptionsSetTfrt(opts, use_tfrt); - *ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get())); - TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); - TFE_DeleteContextOptions(opts); - return Status::OK(); -} +} // namespace internal +} // namespace gradients +} // namespace tensorflow diff --git a/tensorflow/c/eager/mnist_gradients_testutil.h b/tensorflow/c/eager/mnist_gradients_testutil.h index b6de8ff6788..b173446ac9b 100644 --- a/tensorflow/c/eager/mnist_gradients_testutil.h +++ b/tensorflow/c/eager/mnist_gradients_testutil.h @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_ +#define TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_ #include #include "absl/types/span.h" @@ -24,50 +26,13 @@ limitations under the License. 
#include "tensorflow/c/experimental/ops/array_ops.h" #include "tensorflow/c/experimental/ops/math_ops.h" #include "tensorflow/c/experimental/ops/nn_ops.h" -#include "tensorflow/c/tf_status_helper.h" -#include "tensorflow/c/tf_tensor.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" +#include "tensorflow/core/platform/status.h" -using namespace tensorflow; -using namespace tensorflow::gradients; -using namespace tensorflow::gradients::internal; -// ========================== Tape Ops ============================== - -// Computes `inputs[0] + inputs[1]` and records it on the tape. -Status Add(AbstractContext* ctx, Tape* tape, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry); - -// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape. -Status MatMul(AbstractContext* ctx, Tape* tape, - absl::Span inputs, - absl::Span outputs, const char* name, - bool transpose_a, bool transpose_b, - const GradientRegistry& registry); - -// Computes `inputs[0] * inputs[1]` and records it on the tape. -Status Mul(AbstractContext* ctx, Tape* tape, - absl::Span inputs, - absl::Span outputs, const char* name, - const GradientRegistry& registry); - -// Computes `Relu(inputs[0])` and records it on the tape. -Status Relu(AbstractContext* ctx, Tape* tape, - absl::Span inputs, - absl::Span outputs, const char* name, - const GradientRegistry& registry); - -// Computes `SoftmaxLoss(scores, labels)` for matrices and records it on the -// tape. -Status SparseSoftmaxCrossEntropyLoss( - AbstractContext* ctx, Tape* tape, - absl::Span inputs, - absl::Span outputs, const char* name, - const GradientRegistry& registry); - -// ====================== End Tape Ops ============================ +namespace tensorflow { +namespace gradients { +namespace internal { // Computes // y = inputs[0] + inputs[1] @@ -121,26 +86,23 @@ Status ScalarMulModel(AbstractContext* ctx, absl::Span outputs, const GradientRegistry& registry); -// Updates the weights for a neural network given incoming grads and learning -// rate -Status UpdateWeights(AbstractContext* ctx, - std::vector& grads, - std::vector& weights, - AbstractTensorHandle* learning_rate); +Status MatMulModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry); -AbstractContext* BuildFunction(const char* fn_name); - -Status CreateParamsForInputs(AbstractContext* ctx, - absl::Span inputs, - std::vector* params); - -using Model = std::function, - absl::Span, const GradientRegistry&)>; - -Status RunModel(Model model, AbstractContext* ctx, +Status MulModel(AbstractContext* ctx, absl::Span inputs, - absl::Span outputs, bool use_function, + absl::Span outputs, const GradientRegistry& registry); -Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx); +Status SoftmaxModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry); + +} // namespace internal +} // namespace gradients +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_ diff --git a/tensorflow/c/eager/parallel_device/BUILD b/tensorflow/c/eager/parallel_device/BUILD index df5504adce2..473ab503834 100644 --- a/tensorflow/c/eager/parallel_device/BUILD +++ b/tensorflow/c/eager/parallel_device/BUILD @@ -1,3 +1,5 @@ +load("//tensorflow:tensorflow.bzl", "filegroup") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", "tf_cc_test", @@ -103,7 +105,6 @@ cc_library( hdrs = 
["parallel_device_testlib.h"], deps = [ ":parallel_device", - ":parallel_device_ops", "//tensorflow/c:c_api", "//tensorflow/c:c_api_experimental", "//tensorflow/c/eager:c_api", @@ -118,7 +119,6 @@ tf_cc_test( srcs = ["parallel_device_test.cc"], deps = [ ":parallel_device", - ":parallel_device_ops", ":parallel_device_testlib", "//tensorflow/c:c_api", "//tensorflow/c:c_api_experimental", @@ -138,7 +138,6 @@ tf_cc_test( args = ["--heap_check=local"], deps = [ ":parallel_device", - ":parallel_device_ops", ":parallel_device_testlib", "//tensorflow/c:c_api", "//tensorflow/c:c_api_experimental", @@ -150,19 +149,3 @@ tf_cc_test( "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", ], ) - -# Note: ParallelDevice-specific ops are experimental and not currently linked in -# to TensorFlow by default, just used in a few tests. -filegroup( - name = "parallel_device_ops_srcs", - srcs = ["parallel_device_ops.cc"], - visibility = ["//tensorflow/python/distribute/parallel_device:__pkg__"], -) - -cc_library( - name = "parallel_device_ops", - srcs = [":parallel_device_ops_srcs"], - visibility = ["//tensorflow:internal"], - deps = ["//tensorflow/core:framework"], - alwayslink = 1, -) diff --git a/tensorflow/c/eager/parallel_device/parallel_device.cc b/tensorflow/c/eager/parallel_device/parallel_device.cc index d0e9f351478..41bde23448b 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device.cc @@ -136,13 +136,6 @@ absl::optional> ExecuteWithSpecialOps( } result.emplace(std::move(outputs)); return result; - } else if (operation_name == std::string("DeviceID")) { - std::vector result_content; - result_content.reserve(1); - result_content.push_back(parallel_device.DeviceIDs(context, status)); - if (TF_GetCode(status) != TF_OK) return result; - result.emplace(std::move(result_content)); - return result; } std::vector parallel_inputs; std::vector> implicitly_broadcast_tensors; @@ -255,28 +248,44 @@ TFE_TensorHandle* CopyTensorFromParallelDevice(TFE_Context* context, // Since this function is used to satisfy the TFE_CustomDevice C API, // device_info is passed in using a C-style generic. It must always be a // ParallelDevice. -void ParallelDeviceExecute(TFE_Context* context, int num_inputs, - TFE_TensorHandle** inputs, - const char* operation_name, - const TFE_OpAttrs* attributes, int* num_outputs, +void ParallelDeviceExecute(const TFE_Op* original_op, int* num_outputs, TFE_TensorHandle** outputs, TF_Status* status, void* device_info) { + const char* requested_placement = TFE_OpGetDevice(original_op, status); + if (*requested_placement == '\0') { + TF_SetStatus( + status, TF_INTERNAL, + "Ops must be placed on the parallel device explicitly, or their inputs " + "first un-packed. 
Got an un-placed op with an input placed on the " + "parallel device."); + return; + } + TFE_Context* context = TFE_OpGetContext(original_op, status); + if (TF_GetCode(status) != TF_OK) return; + const char* operation_name = TFE_OpGetName(original_op, status); + if (TF_GetCode(status) != TF_OK) return; + const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op); + NamedParallelDevice* named_device = reinterpret_cast(device_info); std::vector typed_inputs; + int num_inputs = TFE_OpGetFlatInputCount(original_op, status); + if (TF_GetCode(status) != TF_OK) return; typed_inputs.reserve(num_inputs); for (int i = 0; i < num_inputs; ++i) { + TFE_TensorHandle* input = TFE_OpGetFlatInput(original_op, i, status); + if (TF_GetCode(status) != TF_OK) return; const char* tensor_handle_device = - TFE_TensorHandleDeviceName(inputs[i], status); + TFE_TensorHandleDeviceName(input, status); if (TF_GetCode(status) != TF_OK) return; if (named_device->name() == tensor_handle_device) { // We assume that any tensors already placed on this device are // ParallelTensors. typed_inputs.emplace_back(reinterpret_cast( - TFE_TensorHandleDevicePointer(inputs[i], status))); + TFE_TensorHandleDevicePointer(input, status))); if (TF_GetCode(status) != TF_OK) return; } else { - typed_inputs.emplace_back(inputs[i]); + typed_inputs.emplace_back(input); } } diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc index e270bfcbb80..095f33ff303 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc @@ -58,7 +58,7 @@ using ExecutorPtr = std::unique_ptr; class DeviceThread { public: // Starts a background thread waiting for `StartExecute`. - explicit DeviceThread(const std::string& device) + explicit DeviceThread(const std::string& device, const bool is_async) : status_(TF_NewStatus()), device_(device), // If the context's default exector is set to async, re-using that in @@ -67,7 +67,7 @@ class DeviceThread { // // TODO(allenl): We should have an async API that works with the // parallel device. - executor_(TFE_NewExecutor(/*is_async=*/false)), + executor_(TFE_NewExecutor(is_async)), op_(nullptr), thread_(tensorflow::Env::Default()->StartThread( tensorflow::ThreadOptions(), "parallel_device_execute", @@ -236,12 +236,13 @@ void DeviceThread::Execute(TFE_Context* context, const char* operation_name, } } -ParallelDevice::ParallelDevice(const std::vector& devices) +ParallelDevice::ParallelDevice(const std::vector& devices, + const bool is_async) : underlying_devices_(devices) { device_threads_.reserve(devices.size()); for (int device_index = 0; device_index < devices.size(); ++device_index) { device_threads_.emplace_back( - new DeviceThread(devices[device_index].c_str())); + new DeviceThread(devices[device_index].c_str(), is_async)); } } diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib.h b/tensorflow/c/eager/parallel_device/parallel_device_lib.h index b3dc47ab088..1bb9ce0f663 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_lib.h +++ b/tensorflow/c/eager/parallel_device/parallel_device_lib.h @@ -49,7 +49,10 @@ class DeviceThread; // placed on each underlying device. class ParallelDevice { public: - explicit ParallelDevice(const std::vector& devices); + // Eager async execution is only supported when remote eager is not in use + // (b/157523095). 
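// Threading note for the change above: each DeviceThread owns its own
// TFE_Executor, so the new is_async argument controls whether per-device ops
// are dispatched asynchronously. The default stays false because, per the
// comment above, async eager execution is not currently usable together with
// remote eager (b/157523095).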
+ explicit ParallelDevice(const std::vector& devices, + const bool is_async = false); ~ParallelDevice(); diff --git a/tensorflow/c/eager/parallel_device/parallel_device_testlib.cc b/tensorflow/c/eager/parallel_device/parallel_device_testlib.cc index 828dcbae093..67bc596b180 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_testlib.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device_testlib.cc @@ -279,30 +279,4 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device, TFE_TensorHandleBackingDeviceName(components[1].get(), status.get()); ASSERT_EQ(underlying_devices[1], second_device); } - // Compute the device ID twice and verify the result - for (int i = 0; i < 2; ++i) { - std::unique_ptr op( - TFE_NewOp(context, "DeviceID", status.get()), TFE_DeleteOp); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - TFE_OpSetDevice(op.get(), device_name, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - - TFE_TensorHandle* result_handle; - int num_retvals = 1; - TFE_Execute(op.get(), &result_handle, &num_retvals, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - std::array components; - ExtractPerDeviceValues(context, result_handle, &components, status.get()); - TFE_DeleteTensorHandle(result_handle); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - - ExpectScalarEq(components[0].get(), 0); - ExpectScalarEq(components[1].get(), 1); - std::string first_device = - TFE_TensorHandleBackingDeviceName(components[0].get(), status.get()); - ASSERT_EQ(underlying_devices[0], first_device); - std::string second_device = - TFE_TensorHandleBackingDeviceName(components[1].get(), status.get()); - ASSERT_EQ(underlying_devices[1], second_device); - } } diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index fcebe973500..efab4dfbeb2 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -98,6 +99,10 @@ class VSpace { gtl::ArraySlice output_gradients, std::vector* result) const = 0; + // Builds a tensor filled with ones with the same shape and dtype as `t`. + virtual Status BuildOnesLike(const TapeTensor& t, + Gradient** result) const = 0; + // Looks up the ID of a Gradient. virtual int64 TensorId(Gradient* tensor) const = 0; @@ -121,7 +126,7 @@ class GradientTape { // functions (and hence the tensors they keep alive). Instead, everything // is deleted in ~GradientTape. Persistent GradientTapes are useful when // users want to compute multiple gradients over the same tape. 
- GradientTape(bool persistent) : persistent_(persistent) {} + explicit GradientTape(bool persistent) : persistent_(persistent) {} ~GradientTape() { for (const auto& pair : op_tape_) { pair.second.backward_function_deleter(pair.second.backward_function); @@ -595,8 +600,10 @@ Status InitialGradients( for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) { if (op_it->second.output_tensor_info[j].GetID() == id) { found = true; - (*result)[id].push_back( - op_it->second.output_tensor_info[j].OnesLike()); + Gradient* ones_like = nullptr; + TF_RETURN_IF_ERROR(vspace.BuildOnesLike( + op_it->second.output_tensor_info[j], &ones_like)); + (*result)[id].push_back(ones_like); break; } } @@ -611,7 +618,10 @@ Status InitialGradients( // target is also a source. auto source_tensor = sources_that_are_targets.find(id); if (source_tensor != sources_that_are_targets.end()) { - (*result)[id].push_back(source_tensor->second.OnesLike()); + Gradient* ones_like = nullptr; + TF_RETURN_IF_ERROR( + vspace.BuildOnesLike(source_tensor->second, &ones_like)); + (*result)[id].push_back(ones_like); } } } else { @@ -934,7 +944,7 @@ ForwardAccumulator::ForwardpropFromTape( // TODO(allenl): Figure out why using zeros_like everywhere causes issues // for some gradient functions and if there's another way to work around // it (e.g. conds instead of ifs). The value shouldn't really matter. - aid = output_tensor.OnesLike(); + TF_RETURN_IF_ERROR(vspace_.BuildOnesLike(output_tensor, &aid)); } if (TF_PREDICT_FALSE(aid == nullptr)) { return tensorflow::errors::Internal( diff --git a/tensorflow/c/eager/tracing_utils.cc b/tensorflow/c/eager/tracing_utils.cc new file mode 100644 index 00000000000..8eec4bc7d9a --- /dev/null +++ b/tensorflow/c/eager/tracing_utils.cc @@ -0,0 +1,37 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/eager/tracing_utils.h" + +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/experimental/gradients/tape/tape_operation.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" +#include "tensorflow/core/platform/errors.h" + +namespace tensorflow { +namespace tracing { + +Status MaybeSetOpName(AbstractOperation* op, const char* op_name) { + if (isa(op)) { + TF_RETURN_IF_ERROR(dyn_cast(op)->SetOpName(op_name)); + } + if (isa(op)) { + TF_RETURN_IF_ERROR(MaybeSetOpName( + dyn_cast(op)->GetBackingOperation(), + op_name)); + } + return Status::OK(); +} +} // namespace tracing +} // namespace tensorflow diff --git a/tensorflow/c/eager/tracing_utils.h b/tensorflow/c/eager/tracing_utils.h new file mode 100644 index 00000000000..e2c8f9b28ec --- /dev/null +++ b/tensorflow/c/eager/tracing_utils.h @@ -0,0 +1,26 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_UTILS_H_ +#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_UTILS_H_ + +#include "tensorflow/c/eager/abstract_operation.h" + +namespace tensorflow { +namespace tracing { +Status MaybeSetOpName(AbstractOperation*, const char* op_name); +} // namespace tracing +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_UTILS_H_ diff --git a/tensorflow/c/experimental/filesystem/BUILD b/tensorflow/c/experimental/filesystem/BUILD index 061fdbd893b..c05c7dc3f7e 100644 --- a/tensorflow/c/experimental/filesystem/BUILD +++ b/tensorflow/c/experimental/filesystem/BUILD @@ -1,3 +1,5 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + # Experimental filesystem C APIs for TensorFlow. # Will be moved in proper place once all filesystems are converted to the # modular framework. diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD b/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD index 68875d61e47..0fc9f260b21 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD @@ -1,3 +1,5 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + # Experimental gcs filesystem plugin. load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object", "tf_cc_test") @@ -29,6 +31,7 @@ cc_library( ":gcs_helper", ":ram_file_block_cache", "//tensorflow/c:env", + "//tensorflow/c:logging", "//tensorflow/c:tf_status", "//tensorflow/c/experimental/filesystem:filesystem_interface", "@com_github_googlecloudplatform_google_cloud_cpp//:storage_client", @@ -59,6 +62,7 @@ cc_library( deps = [ ":cleanup", "//tensorflow/c:env", + "//tensorflow/c:logging", "//tensorflow/c:tf_status", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/synchronization", diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc index e01af918100..8cd8ad7ca81 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc @@ -23,6 +23,7 @@ limitations under the License. #include "google/cloud/storage/client.h" #include "tensorflow/c/env.h" #include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h" +#include "tensorflow/c/logging.h" #include "tensorflow/c/tf_status.h" // Implementation of a filesystem for GCS environments. @@ -120,20 +121,20 @@ static int64_t LoadBufferFromGCS(const std::string& path, size_t offset, return -1; } int64_t read; - if (!absl::SimpleAtoi(stream.headers().find("content-length")->second, - &read)) { + auto content_length = stream.headers().find("content-length"); + if (content_length == stream.headers().end()) { // When we read a file with offset that is bigger than the actual file size. // GCS will return an empty header (e.g no `content-length` header). In this // case, we will set read to `0` and continue. 
- if (TF_GetCode(status) == TF_OUT_OF_RANGE) { - read = 0; - } else { - TF_SetStatus(status, TF_UNKNOWN, "Could not get content-length header"); - return -1; - } + read = 0; + } else if (!absl::SimpleAtoi(content_length->second, &read)) { + TF_SetStatus(status, TF_UNKNOWN, "Could not get content-length header"); + return -1; } // `TF_OUT_OF_RANGE` isn't considered as an error. So we clear it here. TF_SetStatus(status, TF_OK, ""); + TF_VLog(1, "Successful read of %s @ %u of size: %u", path.c_str(), offset, + read); stream.read(buffer, read); read = stream.gcount(); if (read < buffer_size) { @@ -146,6 +147,8 @@ static int64_t LoadBufferFromGCS(const std::string& path, size_t offset, path, " @ ", offset) .c_str()); } + TF_VLog(2, "Successful integrity check for: %s @ %u", path.c_str(), + offset); } } return read; @@ -259,7 +262,8 @@ static void SyncImpl(const std::string& bucket, const std::string& object, if (*offset == -1 || *offset == 0) { // UploadFile will automatically switch to resumable upload based on Client // configuration. - auto metadata = gcs_client->UploadFile(outfile->getName(), bucket, object); + auto metadata = gcs_client->UploadFile(outfile->getName(), bucket, object, + gcs::Fields("size")); if (!metadata) { TF_SetStatusFromGCSStatus(metadata.status(), status); return; @@ -278,15 +282,18 @@ static void SyncImpl(const std::string& bucket, const std::string& object, } else { std::string temporary_object = gcs::CreateRandomPrefixName("tf_writable_file_gcs"); - auto metadata = - gcs_client->UploadFile(outfile->getName(), bucket, temporary_object); + auto metadata = gcs_client->UploadFile(outfile->getName(), bucket, + temporary_object, gcs::Fields("")); if (!metadata) { TF_SetStatusFromGCSStatus(metadata.status(), status); return; } + TF_VLog(3, "AppendObject: gs://%s/%s to gs://%s/%s", bucket.c_str(), + temporary_object.c_str(), bucket.c_str(), object.c_str()); const std::vector source_objects = { {object, {}, {}}, {temporary_object, {}, {}}}; - metadata = gcs_client->ComposeObject(bucket, source_objects, object); + metadata = gcs_client->ComposeObject(bucket, source_objects, object, + gcs::Fields("size")); if (!metadata) { TF_SetStatusFromGCSStatus(metadata.status(), status); return; @@ -321,6 +328,8 @@ void Append(const TF_WritableFile* file, const char* buffer, size_t n, "The internal temporary file is not writable."); return; } + TF_VLog(3, "Append: gs://%s/%s size %u", gcs_file->bucket.c_str(), + gcs_file->object.c_str(), n); gcs_file->sync_need = true; gcs_file->outfile.write(buffer, n); if (!gcs_file->outfile) @@ -346,6 +355,8 @@ int64_t Tell(const TF_WritableFile* file, TF_Status* status) { void Flush(const TF_WritableFile* file, TF_Status* status) { auto gcs_file = static_cast(file->plugin_file); if (gcs_file->sync_need) { + TF_VLog(3, "Flush started: gs://%s/%s", gcs_file->bucket.c_str(), + gcs_file->object.c_str()); if (!gcs_file->outfile) { TF_SetStatus(status, TF_INTERNAL, "Could not append to the internal temporary file."); @@ -353,6 +364,8 @@ void Flush(const TF_WritableFile* file, TF_Status* status) { } SyncImpl(gcs_file->bucket, gcs_file->object, &gcs_file->offset, &gcs_file->outfile, gcs_file->gcs_client, status); + TF_VLog(3, "Flush finished: gs://%s/%s", gcs_file->bucket.c_str(), + gcs_file->object.c_str()); if (TF_GetCode(status) != TF_OK) return; gcs_file->sync_need = false; } else { @@ -361,11 +374,16 @@ void Flush(const TF_WritableFile* file, TF_Status* status) { } void Sync(const TF_WritableFile* file, TF_Status* status) { + auto gcs_file = 
static_cast(file->plugin_file); + TF_VLog(3, "Sync: gs://%s/%s", gcs_file->bucket.c_str(), + gcs_file->object.c_str()); Flush(file, status); } void Close(const TF_WritableFile* file, TF_Status* status) { auto gcs_file = static_cast(file->plugin_file); + TF_VLog(3, "Close: gs://%s/%s", gcs_file->bucket.c_str(), + gcs_file->object.c_str()); if (gcs_file->sync_need) { Flush(file, status); } @@ -428,6 +446,8 @@ GCSFile::GCSFile(google::cloud::storage::Client&& gcs_client) if (absl::SimpleAtoi(std::getenv(kMaxStaleness), &value)) { max_staleness = value; } + TF_VLog(1, "GCS cache max size = %u ; block size = %u ; max staleness = %u", + max_bytes, block_size, max_staleness); file_block_cache = std::make_unique( block_size, max_bytes, max_staleness, @@ -504,13 +524,18 @@ void Cleanup(TF_Filesystem* filesystem) { static void UncachedStatForObject(const std::string& bucket, const std::string& object, GcsFileStat* stat, gcs::Client* gcs_client, TF_Status* status) { - auto metadata = gcs_client->GetObjectMetadata(bucket, object); + auto metadata = gcs_client->GetObjectMetadata( + bucket, object, gcs::Fields("generation,size,timeStorageClassUpdated")); if (!metadata) return TF_SetStatusFromGCSStatus(metadata.status(), status); stat->generation_number = metadata->generation(); stat->base.length = metadata->size(); stat->base.mtime_nsec = metadata->time_storage_class_updated().time_since_epoch().count(); stat->base.is_directory = object.back() == '/'; + TF_VLog(1, + "Stat of: gs://%s/%s -- length: %u generation: %u; mtime_nsec: %u;", + bucket.c_str(), object.c_str(), stat->base.length, + stat->generation_number, stat->base.mtime_nsec); return TF_SetStatus(status, TF_OK, ""); } @@ -545,9 +570,10 @@ void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path, if (TF_GetCode(status) != TF_OK) return -1; if (!gcs_file->file_block_cache->ValidateAndUpdateFileSignature( path, stat.generation_number)) { - std::cout - << "File signature has been changed. Refreshing the cache. Path: " - << path; + TF_VLog( + 1, + "File signature has been changed. Refreshing the cache. Path: %s", + path.c_str()); } read = gcs_file->file_block_cache->Read(path, offset, n, buffer, status); } else { @@ -579,6 +605,7 @@ void NewWritableFile(const TF_Filesystem* filesystem, const char* path, (gcs_file->compose ? 0 : -1)}); // We are responsible for freeing the pointer returned by TF_GetTempFileName free(temp_file_name); + TF_VLog(3, "GcsWritableFile: %s", path); TF_SetStatus(status, TF_OK, ""); } @@ -608,7 +635,8 @@ void NewAppendableFile(const TF_Filesystem* filesystem, const char* path, } else { // If compose is true, we do not download anything. // Instead we only check if this file exists on server or not. 
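Aside: the "GCS cache max size / block size / max staleness" values logged above are read from environment variables with std::getenv and parsed with absl::SimpleAtoi. A minimal sketch of that parsing pattern; the variable name "EXAMPLE_CACHE_BLOCK_SIZE_MB" is a placeholder, the real constants live in the plugin.

#include <cstdint>
#include <cstdlib>

#include "absl/strings/numbers.h"

// Returns `default_value` unless the environment variable is set to a
// parseable unsigned integer.
static uint64_t ReadCacheSettingFromEnv(const char* env_var,
                                        uint64_t default_value) {
  const char* raw = std::getenv(env_var);
  if (raw == nullptr) return default_value;
  uint64_t parsed;
  if (!absl::SimpleAtoi(raw, &parsed)) return default_value;
  return parsed;
}

// Usage:
//   uint64_t block_size_mb =
//       ReadCacheSettingFromEnv("EXAMPLE_CACHE_BLOCK_SIZE_MB", 64);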
- auto metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object); + auto metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object, + gcs::Fields("size")); TF_SetStatusFromGCSStatus(metadata.status(), status); if (TF_GetCode(status) == TF_OK) { file->plugin_file = new tf_writable_file::GCSFile( @@ -624,7 +652,8 @@ void NewAppendableFile(const TF_Filesystem* filesystem, const char* path, return; } } - + TF_VLog(3, "GcsWritableFile: %s with existing file %s", path, + temp_file_name.c_str()); TF_SetStatus(status, TF_OK, ""); } @@ -639,7 +668,8 @@ void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem, if (TF_GetCode(status) != TF_OK) return; auto gcs_file = static_cast(filesystem->plugin_filesystem); - auto metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object); + auto metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object, + gcs::Fields("size")); if (!metadata) { TF_SetStatusFromGCSStatus(metadata.status(), status); return; @@ -670,7 +700,8 @@ static void StatForObject(GCSFile* gcs_file, const std::string& path, if (object.empty()) return TF_SetStatus( status, TF_INVALID_ARGUMENT, - ("'object' must be a non-empty string. (File: " + path + ")").c_str()); + absl::StrCat("'object' must be a non-empty string. (File: ", path, ")") + .c_str()); TF_SetStatus(status, TF_OK, ""); gcs_file->stat_cache->LookupOrCompute( path, stat, @@ -698,7 +729,8 @@ static bool ObjectExists(GCSFile* gcs_file, const std::string& path, static bool BucketExists(GCSFile* gcs_file, const std::string& bucket, TF_Status* status) { - auto metadata = gcs_file->gcs_client.GetBucketMetadata(bucket); + auto metadata = + gcs_file->gcs_client.GetBucketMetadata(bucket, gcs::Fields("")); TF_SetStatusFromGCSStatus(metadata.status(), status); if (TF_GetCode(status) != TF_OK && TF_GetCode(status) != TF_NOT_FOUND) return false; @@ -721,7 +753,8 @@ static std::vector GetChildrenBounded( std::string delimiter = recursive ? "" : "/"; for (auto&& item : gcs_file->gcs_client.ListObjectsAndPrefixes( - bucket, gcs::Prefix(prefix), gcs::Delimiter(delimiter))) { + bucket, gcs::Prefix(prefix), gcs::Delimiter(delimiter), + gcs::Fields("items(name),prefixes"))) { if (count == max_results) { TF_SetStatus(status, TF_OK, ""); return result; @@ -737,8 +770,8 @@ static std::vector GetChildrenBounded( auto pos = children.find(prefix); if (pos != 0) { TF_SetStatus(status, TF_INTERNAL, - ("Unexpected response: the returned file name " + children + - " doesn't match the prefix " + prefix) + absl::StrCat("Unexpected response: the returned file name ", + children, " doesn't match the prefix ", prefix) .c_str()); return result; } @@ -812,6 +845,10 @@ void CreateDir(const TF_Filesystem* filesystem, const char* path, TF_Status* status) { std::string dir = path; MaybeAppendSlash(&dir); + TF_VLog(3, + "CreateDir: creating directory with path: %s and " + "path_with_slash: %s", + path, dir.c_str()); std::string bucket, object; ParseGCSPath(dir, true, &bucket, &object, status); if (TF_GetCode(status) != TF_OK) return; @@ -821,19 +858,23 @@ void CreateDir(const TF_Filesystem* filesystem, const char* path, if (TF_GetCode(status) != TF_OK) return; if (!is_directory) TF_SetStatus(status, TF_NOT_FOUND, - ("The specified bucket " + dir + " was not found.").c_str()); + absl::StrCat("The specified bucket ", dir, " was not found.") + .c_str()); return; } PathExists(filesystem, dir.c_str(), status); - if (TF_GetCode(status) == TF_OK) + if (TF_GetCode(status) == TF_OK) { + // Use the original name for a correct error here. 
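Aside: most client calls in this file now pass gcs::Fields(...) so the server returns only the metadata fields the plugin actually reads (a partial-response request). A minimal sketch of the same pattern against the google-cloud-cpp storage client; the bucket and object names are placeholders.

#include <iostream>

#include "google/cloud/storage/client.h"

namespace gcs = google::cloud::storage;

// Fetch only the object size instead of the full metadata document.
void PrintObjectSize(gcs::Client& client) {
  auto metadata = client.GetObjectMetadata("example-bucket", "example-object",
                                           gcs::Fields("size"));
  if (!metadata) {
    // metadata.status() carries the error reported by GCS.
    std::cerr << metadata.status().message() << "\n";
    return;
  }
  std::cout << "size: " << metadata->size() << "\n";
}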
+ TF_VLog(3, "CreateDir: directory already exists, not uploading %s", path); return TF_SetStatus(status, TF_ALREADY_EXISTS, path); + } auto metadata = gcs_file->gcs_client.InsertObject( bucket, object, "", // Adding this parameter means HTTP_CODE_PRECONDITION_FAILED // will be returned if the object already exists, so avoid reuploading. - gcs::IfGenerationMatch(0)); + gcs::IfGenerationMatch(0), gcs::Fields("")); TF_SetStatusFromGCSStatus(metadata.status(), status); if (TF_GetCode(status) == TF_FAILED_PRECONDITION) TF_SetStatus(status, TF_ALREADY_EXISTS, path); @@ -891,7 +932,8 @@ void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst, auto gcs_file = static_cast(filesystem->plugin_filesystem); auto metadata = gcs_file->gcs_client.RewriteObjectBlocking( - bucket_src, object_src, bucket_dst, object_dst); + bucket_src, object_src, bucket_dst, object_dst, + gcs::Fields("done,rewriteToken")); TF_SetStatusFromGCSStatus(metadata.status(), status); } @@ -908,7 +950,8 @@ bool IsDirectory(const TF_Filesystem* filesystem, const char* path, if (!result) TF_SetStatus( status, TF_NOT_FOUND, - ("The specified bucket gs://" + bucket + " was not found.").c_str()); + absl::StrCat("The specified bucket gs://", bucket, " was not found.") + .c_str()); return result; } @@ -933,6 +976,7 @@ bool IsDirectory(const TF_Filesystem* filesystem, const char* path, static void RenameObject(const TF_Filesystem* filesystem, const std::string& src, const std::string& dst, TF_Status* status) { + TF_VLog(3, "RenameObject: started %s to %s", src.c_str(), dst.c_str()); std::string bucket_src, object_src; ParseGCSPath(src, false, &bucket_src, &object_src, status); if (TF_GetCode(status) != TF_OK) return; @@ -943,9 +987,11 @@ static void RenameObject(const TF_Filesystem* filesystem, auto gcs_file = static_cast(filesystem->plugin_filesystem); auto metadata = gcs_file->gcs_client.RewriteObjectBlocking( - bucket_src, object_src, bucket_dst, object_dst); + bucket_src, object_src, bucket_dst, object_dst, + gcs::Fields("done,rewriteToken")); TF_SetStatusFromGCSStatus(metadata.status(), status); if (TF_GetCode(status) != TF_OK) return; + TF_VLog(3, "RenameObject: finished %s to %s", src.c_str(), dst.c_str()); ClearFileCaches(gcs_file, dst); DeleteFile(filesystem, src.c_str(), status); @@ -954,8 +1000,10 @@ static void RenameObject(const TF_Filesystem* filesystem, void RenameFile(const TF_Filesystem* filesystem, const char* src, const char* dst, TF_Status* status) { if (!IsDirectory(filesystem, src, status)) { - if (TF_GetCode(status) == TF_FAILED_PRECONDITION) + if (TF_GetCode(status) == TF_FAILED_PRECONDITION) { + TF_SetStatus(status, TF_OK, ""); RenameObject(filesystem, src, dst, status); + } return; } @@ -1032,7 +1080,8 @@ void Stat(const TF_Filesystem* filesystem, const char* path, auto gcs_file = static_cast(filesystem->plugin_filesystem); if (object.empty()) { - auto bucket_metadata = gcs_file->gcs_client.GetBucketMetadata(bucket); + auto bucket_metadata = + gcs_file->gcs_client.GetBucketMetadata(bucket, gcs::Fields("")); TF_SetStatusFromGCSStatus(bucket_metadata.status(), status); if (TF_GetCode(status) == TF_OK) { stats->is_directory = true; @@ -1047,8 +1096,9 @@ void Stat(const TF_Filesystem* filesystem, const char* path, stats->mtime_nsec = 0; return TF_SetStatus(status, TF_OK, ""); } - if (TF_GetCode(status) == TF_OK) { - auto metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object); + if (TF_GetCode(status) == TF_FAILED_PRECONDITION) { + auto metadata = 
gcs_file->gcs_client.GetObjectMetadata( + bucket, object, gcs::Fields("size,timeStorageClassUpdated")); if (metadata) { stats->is_directory = false; stats->length = metadata.value().size(); @@ -1061,6 +1111,18 @@ void Stat(const TF_Filesystem* filesystem, const char* path, } } +int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path, + TF_Status* status) { + // Only validate the name. + std::string bucket, object; + ParseGCSPath(path, false, &bucket, &object, status); + if (TF_GetCode(status) != TF_OK) return -1; + + TF_FileStatistics stat; + Stat(filesystem, path, &stat, status); + return stat.length; +} + static char* TranslateName(const TF_Filesystem* filesystem, const char* uri) { return strdup(uri); } diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h index 973ce9e9dc2..5612d004d82 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h @@ -87,6 +87,24 @@ void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem, const char* path, TF_ReadOnlyMemoryRegion* region, TF_Status* status); +int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); +void PathExists(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); +void CreateDir(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); +int GetChildren(const TF_Filesystem* filesystem, const char* path, + char*** entries, TF_Status* status); +void DeleteFile(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); +void Stat(const TF_Filesystem* filesystem, const char* path, + TF_FileStatistics* stats, TF_Status* status); +void DeleteDir(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); +void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst, + TF_Status* status); +void RenameFile(const TF_Filesystem* filesystem, const char* src, + const char* dst, TF_Status* status); } // namespace tf_gcs_filesystem #endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_ diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem_test.cc b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem_test.cc index 82c4e4b8705..e15921335ab 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem_test.cc +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/platform/test.h" #define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x)) << TF_Message(x) +#define EXPECT_TF_OK(x) EXPECT_EQ(TF_OK, TF_GetCode(x)) << TF_Message(x) static const char* content = "abcdefghijklmnopqrstuvwxyz1234567890"; // We will work with content_view instead of content. 
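Aside: the new tf_gcs_filesystem::GetFileSize above only validates the path and then defers to Stat, so a caller just needs the usual TF_Status round trip. A minimal usage sketch, assuming `filesystem` was already initialized with tf_gcs_filesystem::Init; the gs:// path is a placeholder.

#include <cstdint>

#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h"
#include "tensorflow/c/tf_status.h"

// Returns the object length reported by Stat, or -1 on error.
int64_t ExampleGetSize(TF_Filesystem* filesystem) {
  TF_Status* status = TF_NewStatus();
  int64_t size = tf_gcs_filesystem::GetFileSize(
      filesystem, "gs://example-bucket/example-object", status);
  if (TF_GetCode(status) != TF_OK) size = -1;
  TF_DeleteStatus(status);
  return size;
}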
@@ -94,6 +95,70 @@ class GCSFilesystemTest : public ::testing::Test { return translated_name; } + std::unique_ptr + GetWriter() { + std::unique_ptr writer( + new TF_WritableFile, [](TF_WritableFile* file) { + if (file != nullptr) { + if (file->plugin_file != nullptr) tf_writable_file::Cleanup(file); + delete file; + } + }); + writer->plugin_file = nullptr; + return writer; + } + + std::unique_ptr + GetReader() { + std::unique_ptr + reader(new TF_RandomAccessFile, [](TF_RandomAccessFile* file) { + if (file != nullptr) { + if (file->plugin_file != nullptr) + tf_random_access_file::Cleanup(file); + delete file; + } + }); + reader->plugin_file = nullptr; + return reader; + } + + void WriteString(const std::string& path, const std::string& content) { + auto writer = GetWriter(); + tf_gcs_filesystem::NewWritableFile(filesystem_, path.c_str(), writer.get(), + status_); + if (TF_GetCode(status_) != TF_OK) return; + tf_writable_file::Append(writer.get(), content.c_str(), content.length(), + status_); + if (TF_GetCode(status_) != TF_OK) return; + tf_writable_file::Close(writer.get(), status_); + if (TF_GetCode(status_) != TF_OK) return; + } + + std::string ReadAll(const std::string& path) { + auto reader = GetReader(); + tf_gcs_filesystem::NewRandomAccessFile(filesystem_, path.c_str(), + reader.get(), status_); + if (TF_GetCode(status_) != TF_OK) return ""; + + auto file_size = + tf_gcs_filesystem::GetFileSize(filesystem_, path.c_str(), status_); + if (TF_GetCode(status_) != TF_OK) return ""; + + std::string content; + content.resize(file_size); + auto read = tf_random_access_file::Read(reader.get(), 0, file_size, + &content[0], status_); + if (TF_GetCode(status_) != TF_OK) return ""; + if (read >= 0) content.resize(read); + if (file_size != content.size()) + TF_SetStatus( + status_, TF_DATA_LOSS, + std::string("expected " + std::to_string(file_size) + " got " + + std::to_string(content.size()) + " bytes") + .c_str()); + return content; + } + protected: TF_Filesystem* filesystem_; TF_Status* status_; @@ -326,6 +391,145 @@ TEST_F(GCSFilesystemTest, ReadOnlyMemoryRegion) { delete region; } +TEST_F(GCSFilesystemTest, PathExists) { + tf_gcs_filesystem::Init(filesystem_, status_); + ASSERT_TF_OK(status_); + const std::string path = GetURIForPath("PathExists"); + tf_gcs_filesystem::PathExists(filesystem_, path.c_str(), status_); + EXPECT_EQ(TF_NOT_FOUND, TF_GetCode(status_)) << TF_Message(status_); + TF_SetStatus(status_, TF_OK, ""); + WriteString(path, "test"); + ASSERT_TF_OK(status_); + tf_gcs_filesystem::PathExists(filesystem_, path.c_str(), status_); + EXPECT_TF_OK(status_); +} + +TEST_F(GCSFilesystemTest, GetChildren) { + tf_gcs_filesystem::Init(filesystem_, status_); + ASSERT_TF_OK(status_); + const std::string base = GetURIForPath("GetChildren"); + tf_gcs_filesystem::CreateDir(filesystem_, base.c_str(), status_); + EXPECT_TF_OK(status_); + + const std::string file = io::JoinPath(base, "TestFile.csv"); + WriteString(file, "test"); + EXPECT_TF_OK(status_); + + const std::string subdir = io::JoinPath(base, "SubDir"); + tf_gcs_filesystem::CreateDir(filesystem_, subdir.c_str(), status_); + EXPECT_TF_OK(status_); + const std::string subfile = io::JoinPath(subdir, "TestSubFile.csv"); + WriteString(subfile, "test"); + EXPECT_TF_OK(status_); + + char** entries; + auto num_entries = tf_gcs_filesystem::GetChildren(filesystem_, base.c_str(), + &entries, status_); + EXPECT_TF_OK(status_); + + std::vector childrens; + for (int i = 0; i < num_entries; ++i) { + childrens.push_back(entries[i]); + } + 
std::sort(childrens.begin(), childrens.end()); + EXPECT_EQ(std::vector({"SubDir/", "TestFile.csv"}), childrens); +} + +TEST_F(GCSFilesystemTest, DeleteFile) { + tf_gcs_filesystem::Init(filesystem_, status_); + ASSERT_TF_OK(status_); + const std::string path = GetURIForPath("DeleteFile"); + WriteString(path, "test"); + ASSERT_TF_OK(status_); + tf_gcs_filesystem::DeleteFile(filesystem_, path.c_str(), status_); + EXPECT_TF_OK(status_); + tf_gcs_filesystem::PathExists(filesystem_, path.c_str(), status_); + EXPECT_EQ(TF_GetCode(status_), TF_NOT_FOUND); +} + +TEST_F(GCSFilesystemTest, CreateDir) { + tf_gcs_filesystem::Init(filesystem_, status_); + ASSERT_TF_OK(status_); + const std::string dir = GetURIForPath("CreateDir"); + tf_gcs_filesystem::CreateDir(filesystem_, dir.c_str(), status_); + EXPECT_TF_OK(status_); + + TF_FileStatistics stat; + tf_gcs_filesystem::Stat(filesystem_, dir.c_str(), &stat, status_); + EXPECT_TF_OK(status_); + EXPECT_TRUE(stat.is_directory); +} + +TEST_F(GCSFilesystemTest, DeleteDir) { + tf_gcs_filesystem::Init(filesystem_, status_); + ASSERT_TF_OK(status_); + const std::string dir = GetURIForPath("DeleteDir"); + const std::string file = io::JoinPath(dir, "DeleteDirFile.csv"); + WriteString(file, "test"); + ASSERT_TF_OK(status_); + tf_gcs_filesystem::DeleteDir(filesystem_, dir.c_str(), status_); + EXPECT_EQ(TF_GetCode(status_), TF_FAILED_PRECONDITION); + + TF_SetStatus(status_, TF_OK, ""); + tf_gcs_filesystem::DeleteFile(filesystem_, file.c_str(), status_); + EXPECT_TF_OK(status_); + tf_gcs_filesystem::DeleteDir(filesystem_, dir.c_str(), status_); + EXPECT_TF_OK(status_); + TF_FileStatistics stat; + tf_gcs_filesystem::Stat(filesystem_, dir.c_str(), &stat, status_); + EXPECT_EQ(TF_GetCode(status_), TF_NOT_FOUND) << TF_Message(status_); +} + +TEST_F(GCSFilesystemTest, StatFile) { + tf_gcs_filesystem::Init(filesystem_, status_); + ASSERT_TF_OK(status_); + const std::string path = GetURIForPath("StatFile"); + WriteString(path, "test"); + ASSERT_TF_OK(status_); + + TF_FileStatistics stat; + tf_gcs_filesystem::Stat(filesystem_, path.c_str(), &stat, status_); + EXPECT_TF_OK(status_); + EXPECT_EQ(4, stat.length); + EXPECT_FALSE(stat.is_directory); +} + +TEST_F(GCSFilesystemTest, RenameFile) { + tf_gcs_filesystem::Init(filesystem_, status_); + ASSERT_TF_OK(status_); + const std::string src = GetURIForPath("RenameFileSrc"); + const std::string dst = GetURIForPath("RenameFileDst"); + WriteString(src, "test"); + ASSERT_TF_OK(status_); + + tf_gcs_filesystem::RenameFile(filesystem_, src.c_str(), dst.c_str(), status_); + EXPECT_TF_OK(status_); + auto result = ReadAll(dst); + EXPECT_TF_OK(status_); + EXPECT_EQ("test", result); +} + +TEST_F(GCSFilesystemTest, RenameFileOverwrite) { + tf_gcs_filesystem::Init(filesystem_, status_); + ASSERT_TF_OK(status_); + const std::string src = GetURIForPath("RenameFileOverwriteSrc"); + const std::string dst = GetURIForPath("RenameFileOverwriteDst"); + + WriteString(src, "test_old"); + ASSERT_TF_OK(status_); + WriteString(dst, "test_new"); + ASSERT_TF_OK(status_); + + tf_gcs_filesystem::PathExists(filesystem_, dst.c_str(), status_); + EXPECT_TF_OK(status_); + tf_gcs_filesystem::RenameFile(filesystem_, src.c_str(), dst.c_str(), status_); + EXPECT_TF_OK(status_); + + auto result = ReadAll(dst); + EXPECT_TF_OK(status_); + EXPECT_EQ("test_old", result); +} + // These tests below are ported from // `//tensorflow/core/platform/cloud:gcs_file_system_test` TEST_F(GCSFilesystemTest, NewRandomAccessFile_NoBlockCache) { diff --git 
a/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h b/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h index 2abfb6f924b..72659a97d42 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h @@ -28,6 +28,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/synchronization/notification.h" #include "tensorflow/c/env.h" +#include "tensorflow/c/logging.h" #include "tensorflow/c/tf_status.h" namespace tf_gcs_filesystem { @@ -65,8 +66,8 @@ class RamFileBlockCache { pruning_thread_.reset( TF_StartThread(&thread_options, "TF_prune_FBC", PruneThread, this)); } - std::cout << "GCS file block cache is " - << (IsCacheEnabled() ? "enabled" : "disabled") << ".\n"; + TF_VLog(1, "GCS file block cache is %s.\n", + (IsCacheEnabled() ? "enabled" : "disabled")); } ~RamFileBlockCache() { diff --git a/tensorflow/c/experimental/filesystem/plugins/hadoop/BUILD b/tensorflow/c/experimental/filesystem/plugins/hadoop/BUILD index 51ffd709f3d..765c4e5f06e 100644 --- a/tensorflow/c/experimental/filesystem/plugins/hadoop/BUILD +++ b/tensorflow/c/experimental/filesystem/plugins/hadoop/BUILD @@ -1,5 +1,7 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + # Experimental hadoop filesystem plugin. -load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object") +load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object", "tf_cc_test") package( licenses = ["notice"], # Apache 2.0 @@ -20,12 +22,14 @@ cc_library( name = "hadoop_filesystem_impl", srcs = ["hadoop_filesystem.cc"], hdrs = ["hadoop_filesystem.h"], + compatible_with = [], copts = select({ "//conditions:default": [], "//tensorflow:windows": get_win_copts(), }), deps = [ "//tensorflow/c:env", + "//tensorflow/c:logging", "//tensorflow/c:tf_status", "//tensorflow/c/experimental/filesystem:filesystem_interface", "//third_party/hadoop:hdfs", @@ -33,3 +37,38 @@ cc_library( "@com_google_absl//absl/synchronization", ], ) + +# This test is set to manual because it requires downloading the Hadoop +# distribution to run. To run this test: +# 1. Ensure $JAVA_HOME is set to the location of a JDK 8 installation. +# 2. Download the binary Hadoop distribution from: +# http://hadoop.apache.org/releases.html +# 3. Extract the Hadoop distribution and run: +# source libexec/hadoop-config.sh +# 4. Optionally set up HDFS cluster configurations (optionally Kerberos) within +# $HADOOP_HDFS_HOME/etc/hadoop if you want to test against real +# distributed HDFS cluster +# 5. 
bazel test \ +# --test_env=LD_LIBRARY_PATH=$JAVA_HOME/jre/lib/amd64/server \ +# --test_env=HADOOP_HDFS_HOME=$HADOOP_HDFS_HOME \ +# --test_env=CLASSPATH=$($HADOOP_HDFS_HOME/bin/hadoop classpath --glob) \ +# :hadoop_file_system_test +# To test against the real distributed cluster, add the following option for +# bazel test: +# --test_env=HADOOP_TEST_TMPDIR=hdfs://cluster/test/tmp/dir +tf_cc_test( + name = "hadoop_filesystem_test", + srcs = [ + "hadoop_filesystem_test.cc", + ], + tags = [ + "manual", + "notap", + ], + deps = [ + ":hadoop_filesystem_impl", + "//tensorflow/core/platform:path", + "//tensorflow/core/platform:stacktrace_handler", + "//tensorflow/core/platform:test", + ], +) diff --git a/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc index e53e3d0bcc5..5ff28e4229a 100644 --- a/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc +++ b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc @@ -22,11 +22,10 @@ limitations under the License. #include #include -#include "absl/synchronization/mutex.h" #include "tensorflow/c/env.h" #include "tensorflow/c/experimental/filesystem/filesystem_interface.h" +#include "tensorflow/c/logging.h" #include "tensorflow/c/tf_status.h" -#include "third_party/hadoop/hdfs.h" // Implementation of a filesystem for HADOOP environments. // This filesystem will support `hdfs://`, `viewfs://` and `har://` URI schemes. @@ -37,11 +36,17 @@ static void plugin_memory_free(void* ptr) { free(ptr); } void ParseHadoopPath(const std::string& fname, std::string* scheme, std::string* namenode, std::string* path) { size_t scheme_end = fname.find("://") + 2; - *scheme = fname.substr(0, scheme_end + 1); + // We don't want `://` in scheme. + *scheme = fname.substr(0, scheme_end - 2); size_t nn_end = fname.find("/", scheme_end + 1); - if (nn_end == std::string::npos) return; + if (nn_end == std::string::npos) { + *namenode = fname.substr(scheme_end + 1); + *path = ""; + return; + } *namenode = fname.substr(scheme_end + 1, nn_end - scheme_end - 1); - *path = fname.substr(nn_end + 1); + // We keep `/` in path. + *path = fname.substr(nn_end); } void SplitArchiveNameAndPath(std::string* path, std::string* nn, @@ -54,7 +59,7 @@ void SplitArchiveNameAndPath(std::string* path, std::string* nn, } // Case of hadoop archive. Namenode is the path to the archive. std::ostringstream namenodestream; - namenodestream << "har://" << nn + namenodestream << "har://" << *nn << path->substr(0, index_end_archive_name + 4); *nn = namenodestream.str(); path->erase(0, index_end_archive_name + 4); @@ -143,15 +148,20 @@ class LibHDFS { char* hdfs_home = getenv("HADOOP_HDFS_HOME"); if (hdfs_home != nullptr) { auto JoinPath = [](std::string home, std::string lib) { +#if defined(_WIN32) + if (home.back() != '\\') home.push_back('\\'); + return home + "lib\\native\\" + lib; +#else if (home.back() != '/') home.push_back('/'); return home + "lib/native/" + lib; +#endif }; std::string path = JoinPath(hdfs_home, kLibHdfsDso); TryLoadAndBind(path.c_str(), &handle_, status); if (TF_GetCode(status) == TF_OK) { return; } else { - std::cerr << "HadoopFileSystem load error: " << TF_Message(status); + TF_Log(TF_FATAL, "HadoopFileSystem load error: %s", TF_Message(status)); } } @@ -163,13 +173,15 @@ class LibHDFS { void* handle_; }; -// We rely on HDFS connection caching here. 
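Aside: with the parser fix above, the scheme no longer includes "://" and the path keeps its leading slash. A small usage sketch of the splits the new ParseHadoopPath produces; the URI is a placeholder.

#include <string>

#include "tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.h"

void ExampleParse() {
  std::string scheme, namenode, path;
  ParseHadoopPath("hdfs://namenode:8020/tmp/data.csv", &scheme, &namenode,
                  &path);
  // scheme   == "hdfs"
  // namenode == "namenode:8020"
  // path     == "/tmp/data.csv"

  ParseHadoopPath("hdfs://namenode:8020", &scheme, &namenode, &path);
  // scheme == "hdfs", namenode == "namenode:8020", path == ""
}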
The HDFS client calls -// org.apache.hadoop.fs.FileSystem.get(), which caches the connection -// internally. -hdfsFS Connect(LibHDFS* libhdfs, const std::string& path, TF_Status* status) { +// We implement connection caching in Tensorflow, which can significantly +// improve performance. Fixes #43187 +hdfsFS Connect(tf_hadoop_filesystem::HadoopFile* hadoop_file, + const std::string& path, TF_Status* status) { + auto libhdfs = hadoop_file->libhdfs; std::string scheme, namenode, hdfs_path; ParseHadoopPath(path, &scheme, &namenode, &hdfs_path); + std::string cacheKey(scheme); hdfsBuilder* builder = libhdfs->hdfsNewBuilder(); if (scheme == "file") { libhdfs->hdfsBuilderSetNameNode(builder, nullptr); @@ -194,15 +206,24 @@ hdfsFS Connect(LibHDFS* libhdfs, const std::string& path, TF_Status* status) { SplitArchiveNameAndPath(&path_har, &namenode, status); if (TF_GetCode(status) != TF_OK) return nullptr; libhdfs->hdfsBuilderSetNameNode(builder, namenode.c_str()); + cacheKey += namenode; } else { libhdfs->hdfsBuilderSetNameNode( builder, namenode.empty() ? "default" : namenode.c_str()); + cacheKey += namenode; } - auto fs = libhdfs->hdfsBuilderConnect(builder); - if (fs == nullptr) - TF_SetStatusFromIOError(status, TF_NOT_FOUND, strerror(errno)); - else - TF_SetStatus(status, TF_OK, ""); + absl::MutexLock l(&hadoop_file->connection_cache_lock); + if (hadoop_file->connection_cache.find(cacheKey) == + hadoop_file->connection_cache.end()) { + auto cacheFs = libhdfs->hdfsBuilderConnect(builder); + if (cacheFs == nullptr) { + TF_SetStatusFromIOError(status, TF_NOT_FOUND, strerror(errno)); + return cacheFs; + } + hadoop_file->connection_cache[cacheKey] = cacheFs; + } + auto fs = hadoop_file->connection_cache[cacheKey]; + TF_SetStatus(status, TF_OK, ""); return fs; } @@ -216,6 +237,7 @@ typedef struct HDFSFile { LibHDFS* libhdfs; absl::Mutex mu; hdfsFile handle ABSL_GUARDED_BY(mu); + bool disable_eof_retried; HDFSFile(std::string path, std::string hdfs_path, hdfsFS fs, LibHDFS* libhdfs, hdfsFile handle) : path(std::move(path)), @@ -223,7 +245,15 @@ typedef struct HDFSFile { fs(fs), libhdfs(libhdfs), mu(), - handle(handle) {} + handle(handle) { + const char* disable_eof_retried_str = + getenv("HDFS_DISABLE_READ_EOF_RETRIED"); + if (disable_eof_retried_str && disable_eof_retried_str[0] == '1') { + disable_eof_retried = true; + } else { + disable_eof_retried = false; + } + } } HDFSFile; void Cleanup(TF_RandomAccessFile* file) { @@ -247,8 +277,12 @@ int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n, char* dst = buffer; bool eof_retried = false; - int64_t r = 0; - while (TF_GetCode(status) == TF_OK && !eof_retried) { + if (hdfs_file->disable_eof_retried) { + // eof_retried = true, avoid calling hdfsOpenFile in Read, Fixes #42597 + eof_retried = true; + } + int64_t read = 0; + while (TF_GetCode(status) == TF_OK && n > 0) { // We lock inside the loop rather than outside so we don't block other // concurrent readers. absl::MutexLock l(&hdfs_file->mu); @@ -257,12 +291,13 @@ int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n, // of int32. -2 offset can avoid JVM OutOfMemoryError. 
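Aside: the Connect() change above caches one hdfsFS handle per scheme-plus-namenode key behind a mutex instead of rebuilding the connection on every call. A minimal generic sketch of that lookup-or-create pattern; the Connection type and OpenConnection() factory are placeholders, not the plugin's API.

#include <map>
#include <string>

#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"

struct Connection {};  // Stand-in for a real handle such as hdfsFS.
// Stand-in factory for whatever actually opens the connection.
Connection* OpenConnection(const std::string& key) { return new Connection; }

class ConnectionCache {
 public:
  // Returns the cached connection for `key`, creating it on first use.
  Connection* GetOrCreate(const std::string& key) {
    absl::MutexLock l(&lock_);
    auto it = connections_.find(key);
    if (it == connections_.end()) {
      it = connections_.emplace(key, OpenConnection(key)).first;
    }
    return it->second;
  }

 private:
  absl::Mutex lock_;
  std::map<std::string, Connection*> connections_ ABSL_GUARDED_BY(lock_);
};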
size_t read_n = (std::min)(n, static_cast(std::numeric_limits::max() - 2)); - r = libhdfs->hdfsPread(fs, handle, static_cast(offset), dst, - static_cast(read_n)); + int64_t r = libhdfs->hdfsPread(fs, handle, static_cast(offset), + dst, static_cast(read_n)); if (r > 0) { dst += r; n -= r; offset += r; + read += r; } else if (!eof_retried && r == 0) { // Always reopen the file upon reaching EOF to see if there's more data. // If writers are streaming contents while others are concurrently @@ -274,11 +309,13 @@ int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n, TF_SetStatusFromIOError(status, errno, path); return -1; } - handle = libhdfs->hdfsOpenFile(fs, hdfs_path, O_RDONLY, 0, 0, 0); - if (handle == nullptr) { + hdfs_file->handle = + libhdfs->hdfsOpenFile(fs, hdfs_path, O_RDONLY, 0, 0, 0); + if (hdfs_file->handle == nullptr) { TF_SetStatusFromIOError(status, errno, path); return -1; } + handle = hdfs_file->handle; eof_retried = true; } else if (eof_retried && r == 0) { TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested"); @@ -288,7 +325,7 @@ int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n, TF_SetStatusFromIOError(status, errno, path); } } - return r; + return read; } } // namespace tf_random_access_file @@ -308,7 +345,7 @@ typedef struct HDFSFile { handle(handle) {} } HDFSFile; -static void Cleanup(TF_WritableFile* file) { +void Cleanup(TF_WritableFile* file) { auto hdfs_file = static_cast(file->plugin_file); hdfs_file->libhdfs->hdfsCloseFile(hdfs_file->fs, hdfs_file->handle); hdfs_file->fs = nullptr; @@ -387,30 +424,36 @@ void Close(const TF_WritableFile* file, TF_Status* status) { // SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion` // ---------------------------------------------------------------------------- namespace tf_read_only_memory_region { - -// TODO(vnvo2409): Implement later - +// Hadoop doesn't support Readonly Memory Region } // namespace tf_read_only_memory_region // SECTION 4. 
Implementation for `TF_Filesystem`, the actual filesystem // ---------------------------------------------------------------------------- namespace tf_hadoop_filesystem { +HadoopFile::HadoopFile(TF_Status* status) + : libhdfs(new LibHDFS(status)), + connection_cache_lock(), + connection_cache() {} + void Init(TF_Filesystem* filesystem, TF_Status* status) { - filesystem->plugin_filesystem = new LibHDFS(status); + filesystem->plugin_filesystem = new HadoopFile(status); if (TF_GetCode(status) != TF_OK) return; TF_SetStatus(status, TF_OK, ""); } void Cleanup(TF_Filesystem* filesystem) { - auto libhdfs = static_cast(filesystem->plugin_filesystem); + auto hadoop_file = static_cast(filesystem->plugin_filesystem); + auto libhdfs = hadoop_file->libhdfs; delete libhdfs; + delete hadoop_file; } void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path, TF_RandomAccessFile* file, TF_Status* status) { - auto libhdfs = static_cast(filesystem->plugin_filesystem); - auto fs = Connect(libhdfs, path, status); + auto hadoop_file = static_cast(filesystem->plugin_filesystem); + auto libhdfs = hadoop_file->libhdfs; + auto fs = Connect(hadoop_file, path, status); if (TF_GetCode(status) != TF_OK) return; std::string scheme, namenode, hdfs_path; @@ -426,8 +469,27 @@ void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path, void NewWritableFile(const TF_Filesystem* filesystem, const char* path, TF_WritableFile* file, TF_Status* status) { - auto libhdfs = static_cast(filesystem->plugin_filesystem); - auto fs = Connect(libhdfs, path, status); + auto hadoop_file = static_cast(filesystem->plugin_filesystem); + auto libhdfs = hadoop_file->libhdfs; + auto fs = Connect(hadoop_file, path, status); + if (TF_GetCode(status) != TF_OK) return; + + std::string scheme, namenode, hdfs_path; + ParseHadoopPath(path, &scheme, &namenode, &hdfs_path); + + auto handle = libhdfs->hdfsOpenFile(fs, hdfs_path.c_str(), O_WRONLY, 0, 0, 0); + if (handle == nullptr) return TF_SetStatusFromIOError(status, errno, path); + + file->plugin_file = + new tf_writable_file::HDFSFile(hdfs_path, fs, libhdfs, handle); + TF_SetStatus(status, TF_OK, ""); +} + +void NewAppendableFile(const TF_Filesystem* filesystem, const char* path, + TF_WritableFile* file, TF_Status* status) { + auto hadoop_file = static_cast(filesystem->plugin_filesystem); + auto libhdfs = hadoop_file->libhdfs; + auto fs = Connect(hadoop_file, path, status); if (TF_GetCode(status) != TF_OK) return; std::string scheme, namenode, hdfs_path; @@ -458,8 +520,9 @@ void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem, void PathExists(const TF_Filesystem* filesystem, const char* path, TF_Status* status) { - auto libhdfs = static_cast(filesystem->plugin_filesystem); - auto fs = Connect(libhdfs, path, status); + auto hadoop_file = static_cast(filesystem->plugin_filesystem); + auto libhdfs = hadoop_file->libhdfs; + auto fs = Connect(hadoop_file, path, status); if (TF_GetCode(status) != TF_OK) return; std::string scheme, namenode, hdfs_path; @@ -474,8 +537,9 @@ void PathExists(const TF_Filesystem* filesystem, const char* path, void Stat(const TF_Filesystem* filesystem, const char* path, TF_FileStatistics* stats, TF_Status* status) { - auto libhdfs = static_cast(filesystem->plugin_filesystem); - auto fs = Connect(libhdfs, path, status); + auto hadoop_file = static_cast(filesystem->plugin_filesystem); + auto libhdfs = hadoop_file->libhdfs; + auto fs = Connect(hadoop_file, path, status); if (TF_GetCode(status) != TF_OK) return; std::string scheme, 
namenode, hdfs_path; @@ -493,8 +557,9 @@ void Stat(const TF_Filesystem* filesystem, const char* path, int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path, TF_Status* status) { - auto libhdfs = static_cast(filesystem->plugin_filesystem); - auto fs = Connect(libhdfs, path, status); + auto hadoop_file = static_cast(filesystem->plugin_filesystem); + auto libhdfs = hadoop_file->libhdfs; + auto fs = Connect(hadoop_file, path, status); if (TF_GetCode(status) != TF_OK) return -1; std::string scheme, namenode, hdfs_path; @@ -514,8 +579,9 @@ int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path, void DeleteFile(const TF_Filesystem* filesystem, const char* path, TF_Status* status) { - auto libhdfs = static_cast(filesystem->plugin_filesystem); - auto fs = Connect(libhdfs, path, status); + auto hadoop_file = static_cast(filesystem->plugin_filesystem); + auto libhdfs = hadoop_file->libhdfs; + auto fs = Connect(hadoop_file, path, status); if (TF_GetCode(status) != TF_OK) return; std::string scheme, namenode, hdfs_path; @@ -529,8 +595,9 @@ void DeleteFile(const TF_Filesystem* filesystem, const char* path, void CreateDir(const TF_Filesystem* filesystem, const char* path, TF_Status* status) { - auto libhdfs = static_cast(filesystem->plugin_filesystem); - auto fs = Connect(libhdfs, path, status); + auto hadoop_file = static_cast(filesystem->plugin_filesystem); + auto libhdfs = hadoop_file->libhdfs; + auto fs = Connect(hadoop_file, path, status); if (TF_GetCode(status) != TF_OK) return; std::string scheme, namenode, hdfs_path; @@ -544,8 +611,9 @@ void CreateDir(const TF_Filesystem* filesystem, const char* path, void DeleteDir(const TF_Filesystem* filesystem, const char* path, TF_Status* status) { - auto libhdfs = static_cast(filesystem->plugin_filesystem); - auto fs = Connect(libhdfs, path, status); + auto hadoop_file = static_cast(filesystem->plugin_filesystem); + auto libhdfs = hadoop_file->libhdfs; + auto fs = Connect(hadoop_file, path, status); if (TF_GetCode(status) != TF_OK) return; std::string scheme, namenode, hdfs_path; @@ -580,8 +648,9 @@ void DeleteDir(const TF_Filesystem* filesystem, const char* path, void RenameFile(const TF_Filesystem* filesystem, const char* src, const char* dst, TF_Status* status) { - auto libhdfs = static_cast(filesystem->plugin_filesystem); - auto fs = Connect(libhdfs, src, status); + auto hadoop_file = static_cast(filesystem->plugin_filesystem); + auto libhdfs = hadoop_file->libhdfs; + auto fs = Connect(hadoop_file, src, status); if (TF_GetCode(status) != TF_OK) return; std::string scheme, namenode, hdfs_path_src, hdfs_path_dst; @@ -601,8 +670,9 @@ void RenameFile(const TF_Filesystem* filesystem, const char* src, int GetChildren(const TF_Filesystem* filesystem, const char* path, char*** entries, TF_Status* status) { - auto libhdfs = static_cast(filesystem->plugin_filesystem); - auto fs = Connect(libhdfs, path, status); + auto hadoop_file = static_cast(filesystem->plugin_filesystem); + auto libhdfs = hadoop_file->libhdfs; + auto fs = Connect(hadoop_file, path, status); if (TF_GetCode(status) != TF_OK) return -1; std::string scheme, namenode, hdfs_path; @@ -638,7 +708,9 @@ int GetChildren(const TF_Filesystem* filesystem, const char* path, return num_entries; } -// TODO(vnvo2409): Implement later +static char* TranslateName(const TF_Filesystem* filesystem, const char* uri) { + return strdup(uri); +} } // namespace tf_hadoop_filesystem @@ -646,6 +718,42 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops, const char* uri) { 
TF_SetFilesystemVersionMetadata(ops); ops->scheme = strdup(uri); + + ops->random_access_file_ops = static_cast( + plugin_memory_allocate(TF_RANDOM_ACCESS_FILE_OPS_SIZE)); + ops->random_access_file_ops->cleanup = tf_random_access_file::Cleanup; + ops->random_access_file_ops->read = tf_random_access_file::Read; + + ops->writable_file_ops = static_cast( + plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE)); + ops->writable_file_ops->cleanup = tf_writable_file::Cleanup; + ops->writable_file_ops->append = tf_writable_file::Append; + ops->writable_file_ops->tell = tf_writable_file::Tell; + ops->writable_file_ops->flush = tf_writable_file::Flush; + ops->writable_file_ops->sync = tf_writable_file::Sync; + ops->writable_file_ops->close = tf_writable_file::Close; + + ops->filesystem_ops = static_cast( + plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE)); + ops->filesystem_ops->init = tf_hadoop_filesystem::Init; + ops->filesystem_ops->cleanup = tf_hadoop_filesystem::Cleanup; + ops->filesystem_ops->new_random_access_file = + tf_hadoop_filesystem::NewRandomAccessFile; + ops->filesystem_ops->new_writable_file = + tf_hadoop_filesystem::NewWritableFile; + ops->filesystem_ops->new_appendable_file = + tf_hadoop_filesystem::NewAppendableFile; + ops->filesystem_ops->new_read_only_memory_region_from_file = + tf_hadoop_filesystem::NewReadOnlyMemoryRegionFromFile; + ops->filesystem_ops->path_exists = tf_hadoop_filesystem::PathExists; + ops->filesystem_ops->stat = tf_hadoop_filesystem::Stat; + ops->filesystem_ops->get_file_size = tf_hadoop_filesystem::GetFileSize; + ops->filesystem_ops->delete_file = tf_hadoop_filesystem::DeleteFile; + ops->filesystem_ops->create_dir = tf_hadoop_filesystem::CreateDir; + ops->filesystem_ops->delete_dir = tf_hadoop_filesystem::DeleteDir; + ops->filesystem_ops->rename_file = tf_hadoop_filesystem::RenameFile; + ops->filesystem_ops->get_children = tf_hadoop_filesystem::GetChildren; + ops->filesystem_ops->translate_name = tf_hadoop_filesystem::TranslateName; } void TF_InitPlugin(TF_FilesystemPluginInfo* info) { diff --git a/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.h b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.h index 850cefe0231..06b91a68123 100644 --- a/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.h +++ b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.h @@ -15,7 +15,73 @@ limitations under the License. 
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_ #define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_ +#include +#include + +#include "absl/synchronization/mutex.h" #include "tensorflow/c/experimental/filesystem/filesystem_interface.h" #include "tensorflow/c/tf_status.h" +#include "third_party/hadoop/hdfs.h" + +void ParseHadoopPath(const std::string& fname, std::string* scheme, + std::string* namenode, std::string* path); +void SplitArchiveNameAndPath(std::string* path, std::string* nn, + TF_Status* status); +class LibHDFS; + +namespace tf_random_access_file { +void Cleanup(TF_RandomAccessFile* file); +int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n, + char* buffer, TF_Status* status); +} // namespace tf_random_access_file + +namespace tf_writable_file { +void Cleanup(TF_WritableFile* file); +void Append(const TF_WritableFile* file, const char* buffer, size_t n, + TF_Status* status); +int64_t Tell(const TF_WritableFile* file, TF_Status* status); +void Sync(const TF_WritableFile* file, TF_Status* status); +void Flush(const TF_WritableFile* file, TF_Status* status); +void Close(const TF_WritableFile* file, TF_Status* status); +} // namespace tf_writable_file + +namespace tf_hadoop_filesystem { +typedef struct HadoopFile { + LibHDFS* libhdfs; + absl::Mutex connection_cache_lock; + std::map connection_cache + ABSL_GUARDED_BY(connection_cache_lock); + HadoopFile(TF_Status* status); +} HadoopFile; + +void Init(TF_Filesystem* filesystem, TF_Status* status); +void Cleanup(TF_Filesystem* filesystem); +void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path, + TF_RandomAccessFile* file, TF_Status* status); +void NewWritableFile(const TF_Filesystem* filesystem, const char* path, + TF_WritableFile* file, TF_Status* status); +void NewAppendableFile(const TF_Filesystem* filesystem, const char* path, + TF_WritableFile* file, TF_Status* status); +void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem, + const char* path, + TF_ReadOnlyMemoryRegion* region, + TF_Status* status); +void PathExists(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); +void Stat(const TF_Filesystem* filesystem, const char* path, + TF_FileStatistics* stats, TF_Status* status); +int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); +void DeleteFile(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); +void CreateDir(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); +void DeleteDir(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); +void RenameFile(const TF_Filesystem* filesystem, const char* src, + const char* dst, TF_Status* status); +int GetChildren(const TF_Filesystem* filesystem, const char* path, + char*** entries, TF_Status* status); +} // namespace tf_hadoop_filesystem #endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_ diff --git a/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem_test.cc b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem_test.cc new file mode 100644 index 00000000000..df85ba9e4dd --- /dev/null +++ b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem_test.cc @@ -0,0 +1,460 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.h" + +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/stacktrace_handler.h" +#include "tensorflow/core/platform/test.h" +#include "third_party/hadoop/hdfs.h" + +#define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x)) << TF_Message(x) +#define EXPECT_TF_OK(x) EXPECT_EQ(TF_OK, TF_GetCode(x)) << TF_Message(x) + +namespace tensorflow { +namespace { + +class HadoopFileSystemTest : public ::testing::Test { + public: + void SetUp() override { + status_ = TF_NewStatus(); + filesystem_ = new TF_Filesystem; + tf_hadoop_filesystem::Init(filesystem_, status_); + ASSERT_TF_OK(status_) << "Could not initialize filesystem. " + << TF_Message(status_); + } + void TearDown() override { + TF_DeleteStatus(status_); + tf_hadoop_filesystem::Cleanup(filesystem_); + delete filesystem_; + } + + std::string TmpDir(const std::string& path) { + char* test_dir = getenv("HADOOP_TEST_TMPDIR"); + if (test_dir != nullptr) { + return io::JoinPath(std::string(test_dir), path); + } else { + return "file://" + io::JoinPath(testing::TmpDir(), path); + } + } + + std::unique_ptr + GetWriter() { + std::unique_ptr writer( + new TF_WritableFile, [](TF_WritableFile* file) { + if (file != nullptr) { + if (file->plugin_file != nullptr) tf_writable_file::Cleanup(file); + delete file; + } + }); + writer->plugin_file = nullptr; + return writer; + } + + std::unique_ptr + GetReader() { + std::unique_ptr + reader(new TF_RandomAccessFile, [](TF_RandomAccessFile* file) { + if (file != nullptr) { + if (file->plugin_file != nullptr) + tf_random_access_file::Cleanup(file); + delete file; + } + }); + reader->plugin_file = nullptr; + return reader; + } + + void WriteString(const std::string& path, const std::string& content) { + auto writer = GetWriter(); + tf_hadoop_filesystem::NewWritableFile(filesystem_, path.c_str(), + writer.get(), status_); + if (TF_GetCode(status_) != TF_OK) return; + tf_writable_file::Append(writer.get(), content.c_str(), content.length(), + status_); + if (TF_GetCode(status_) != TF_OK) return; + tf_writable_file::Close(writer.get(), status_); + if (TF_GetCode(status_) != TF_OK) return; + } + + std::string ReadAll(const std::string& path) { + auto reader = GetReader(); + tf_hadoop_filesystem::NewRandomAccessFile(filesystem_, path.c_str(), + reader.get(), status_); + if (TF_GetCode(status_) != TF_OK) return ""; + + auto file_size = + tf_hadoop_filesystem::GetFileSize(filesystem_, path.c_str(), status_); + if (TF_GetCode(status_) != TF_OK) return ""; + + std::string content; + content.resize(file_size); + auto read = tf_random_access_file::Read(reader.get(), 0, file_size, + &content[0], status_); + if (TF_GetCode(status_) != TF_OK) return ""; + if (read >= 0) content.resize(read); + if (file_size != content.size()) + TF_SetStatus( + status_, TF_DATA_LOSS, + std::string("expected " + std::to_string(file_size) + " got " + + std::to_string(content.size()) + " bytes") + .c_str()); + return content; + } + + protected: + TF_Filesystem* 
filesystem_; + TF_Status* status_; +}; + +TEST_F(HadoopFileSystemTest, RandomAccessFile) { + const std::string path = TmpDir("RandomAccessFile"); + const std::string content = "abcdefghijklmn"; + + WriteString(path, content); + ASSERT_TF_OK(status_); + + auto reader = GetReader(); + tf_hadoop_filesystem::NewRandomAccessFile(filesystem_, path.c_str(), + reader.get(), status_); + EXPECT_TF_OK(status_); + + std::string result; + result.resize(content.size()); + auto read = tf_random_access_file::Read(reader.get(), 0, content.size(), + &result[0], status_); + result.resize(read); + EXPECT_TF_OK(status_); + EXPECT_EQ(content.size(), result.size()); + EXPECT_EQ(content, result); + + result.clear(); + result.resize(4); + read = tf_random_access_file::Read(reader.get(), 2, 4, &result[0], status_); + result.resize(read); + EXPECT_TF_OK(status_); + EXPECT_EQ(4, result.size()); + EXPECT_EQ(content.substr(2, 4), result); +} + +TEST_F(HadoopFileSystemTest, WritableFile) { + auto writer = GetWriter(); + const std::string path = TmpDir("WritableFile"); + tf_hadoop_filesystem::NewWritableFile(filesystem_, path.c_str(), writer.get(), + status_); + EXPECT_TF_OK(status_); + tf_writable_file::Append(writer.get(), "content1,", strlen("content1,"), + status_); + EXPECT_TF_OK(status_); + auto pos = tf_writable_file::Tell(writer.get(), status_); + EXPECT_TF_OK(status_); + EXPECT_EQ(pos, 9); + + tf_writable_file::Append(writer.get(), "content2", strlen("content2"), + status_); + EXPECT_TF_OK(status_); + tf_writable_file::Flush(writer.get(), status_); + EXPECT_TF_OK(status_); + tf_writable_file::Sync(writer.get(), status_); + EXPECT_TF_OK(status_); + tf_writable_file::Close(writer.get(), status_); + EXPECT_TF_OK(status_); + + auto content = ReadAll(path); + EXPECT_TF_OK(status_); + EXPECT_EQ("content1,content2", content); +} + +TEST_F(HadoopFileSystemTest, PathExists) { + const std::string path = TmpDir("PathExists"); + tf_hadoop_filesystem::PathExists(filesystem_, path.c_str(), status_); + EXPECT_EQ(TF_NOT_FOUND, TF_GetCode(status_)) << TF_Message(status_); + TF_SetStatus(status_, TF_OK, ""); + WriteString(path, "test"); + ASSERT_TF_OK(status_); + tf_hadoop_filesystem::PathExists(filesystem_, path.c_str(), status_); + EXPECT_TF_OK(status_); +} + +TEST_F(HadoopFileSystemTest, GetChildren) { + const std::string base = TmpDir("GetChildren"); + tf_hadoop_filesystem::CreateDir(filesystem_, base.c_str(), status_); + EXPECT_TF_OK(status_); + + const std::string file = io::JoinPath(base, "TestFile.csv"); + WriteString(file, "test"); + EXPECT_TF_OK(status_); + + const std::string subdir = io::JoinPath(base, "SubDir"); + tf_hadoop_filesystem::CreateDir(filesystem_, subdir.c_str(), status_); + EXPECT_TF_OK(status_); + const std::string subfile = io::JoinPath(subdir, "TestSubFile.csv"); + WriteString(subfile, "test"); + EXPECT_TF_OK(status_); + + char** entries; + auto num_entries = tf_hadoop_filesystem::GetChildren( + filesystem_, base.c_str(), &entries, status_); + EXPECT_TF_OK(status_); + + std::vector childrens; + for (int i = 0; i < num_entries; ++i) { + childrens.push_back(entries[i]); + } + std::sort(childrens.begin(), childrens.end()); + EXPECT_EQ(std::vector({"SubDir", "TestFile.csv"}), childrens); +} + +TEST_F(HadoopFileSystemTest, DeleteFile) { + const std::string path = TmpDir("DeleteFile"); + WriteString(path, "test"); + ASSERT_TF_OK(status_); + tf_hadoop_filesystem::DeleteFile(filesystem_, path.c_str(), status_); + EXPECT_TF_OK(status_); +} + +TEST_F(HadoopFileSystemTest, GetFileSize) { + const std::string path 
= TmpDir("GetFileSize"); + WriteString(path, "test"); + ASSERT_TF_OK(status_); + auto file_size = + tf_hadoop_filesystem::GetFileSize(filesystem_, path.c_str(), status_); + EXPECT_TF_OK(status_); + EXPECT_EQ(4, file_size); +} + +TEST_F(HadoopFileSystemTest, CreateDirStat) { + const std::string path = TmpDir("CreateDirStat"); + tf_hadoop_filesystem::CreateDir(filesystem_, path.c_str(), status_); + EXPECT_TF_OK(status_); + TF_FileStatistics stat; + tf_hadoop_filesystem::Stat(filesystem_, path.c_str(), &stat, status_); + EXPECT_TF_OK(status_); + EXPECT_TRUE(stat.is_directory); +} + +TEST_F(HadoopFileSystemTest, DeleteDir) { + const std::string path = TmpDir("DeleteDir"); + tf_hadoop_filesystem::DeleteDir(filesystem_, path.c_str(), status_); + EXPECT_NE(TF_GetCode(status_), TF_OK); + tf_hadoop_filesystem::CreateDir(filesystem_, path.c_str(), status_); + EXPECT_TF_OK(status_); + tf_hadoop_filesystem::DeleteDir(filesystem_, path.c_str(), status_); + EXPECT_TF_OK(status_); + TF_FileStatistics stat; + tf_hadoop_filesystem::Stat(filesystem_, path.c_str(), &stat, status_); + EXPECT_NE(TF_GetCode(status_), TF_OK); +} + +TEST_F(HadoopFileSystemTest, RenameFile) { + const std::string src = TmpDir("RenameFileSrc"); + const std::string dst = TmpDir("RenameFileDst"); + WriteString(src, "test"); + ASSERT_TF_OK(status_); + + tf_hadoop_filesystem::RenameFile(filesystem_, src.c_str(), dst.c_str(), + status_); + EXPECT_TF_OK(status_); + auto result = ReadAll(dst); + EXPECT_TF_OK(status_); + EXPECT_EQ("test", result); +} + +TEST_F(HadoopFileSystemTest, RenameFileOverwrite) { + const std::string src = TmpDir("RenameFileOverwriteSrc"); + const std::string dst = TmpDir("RenameFileOverwriteDst"); + + WriteString(src, "test_old"); + ASSERT_TF_OK(status_); + WriteString(dst, "test_new"); + ASSERT_TF_OK(status_); + + tf_hadoop_filesystem::PathExists(filesystem_, dst.c_str(), status_); + EXPECT_TF_OK(status_); + tf_hadoop_filesystem::RenameFile(filesystem_, src.c_str(), dst.c_str(), + status_); + EXPECT_TF_OK(status_); + + auto result = ReadAll(dst); + EXPECT_TF_OK(status_); + EXPECT_EQ("test_old", result); +} + +TEST_F(HadoopFileSystemTest, StatFile) { + const std::string path = TmpDir("StatFile"); + WriteString(path, "test"); + ASSERT_TF_OK(status_); + TF_FileStatistics stat; + tf_hadoop_filesystem::Stat(filesystem_, path.c_str(), &stat, status_); + EXPECT_TF_OK(status_); + EXPECT_EQ(4, stat.length); + EXPECT_FALSE(stat.is_directory); +} + +TEST_F(HadoopFileSystemTest, WriteWhileReading) { + const std::string path = TmpDir("WriteWhileReading"); + // Skip the test if we're not testing on HDFS. Hadoop's local filesystem + // implementation makes no guarantees that writable files are readable while + // being written. 
+ if (path.find_first_of("hdfs://") != 0) GTEST_SKIP(); + + auto writer = GetWriter(); + tf_hadoop_filesystem::NewWritableFile(filesystem_, path.c_str(), writer.get(), + status_); + EXPECT_TF_OK(status_); + + const std::string content1 = "content1"; + tf_writable_file::Append(writer.get(), content1.c_str(), content1.size(), + status_); + EXPECT_TF_OK(status_); + tf_writable_file::Flush(writer.get(), status_); + EXPECT_TF_OK(status_); + + auto reader = GetReader(); + tf_hadoop_filesystem::NewRandomAccessFile(filesystem_, path.c_str(), + reader.get(), status_); + EXPECT_TF_OK(status_); + + std::string result; + result.resize(content1.size()); + auto read = tf_random_access_file::Read(reader.get(), 0, content1.size(), + &result[0], status_); + result.resize(read); + EXPECT_TF_OK(status_); + EXPECT_EQ(content1, result); + + const std::string content2 = "content2"; + tf_writable_file::Append(writer.get(), content2.c_str(), content2.size(), + status_); + EXPECT_TF_OK(status_); + tf_writable_file::Flush(writer.get(), status_); + EXPECT_TF_OK(status_); + + result.resize(content2.size()); + read = tf_random_access_file::Read(reader.get(), content1.size(), + content2.size(), &result[0], status_); + result.resize(read); + EXPECT_TF_OK(status_); + EXPECT_EQ(content2, result); + + tf_writable_file::Close(writer.get(), status_); + EXPECT_TF_OK(status_); +} + +TEST_F(HadoopFileSystemTest, ReadWhileOverwriting) { + static char set_disable_var[] = "HDFS_DISABLE_READ_EOF_RETRIED=1"; + putenv(set_disable_var); + + const std::string path = TmpDir("ReadWhileOverwriting"); + if (path.find_first_of("hdfs://") != 0) GTEST_SKIP(); + + const string content1 = "content1"; + WriteString(path, content1); + ASSERT_TF_OK(status_); + + auto reader = GetReader(); + tf_hadoop_filesystem::NewRandomAccessFile(filesystem_, path.c_str(), + reader.get(), status_); + EXPECT_TF_OK(status_); + + std::string result; + result.resize(content1.size()); + auto read = tf_random_access_file::Read(reader.get(), 0, content1.size(), + &result[0], status_); + result.resize(read); + EXPECT_TF_OK(status_); + EXPECT_EQ(content1, result); + + tf_hadoop_filesystem::DeleteFile(filesystem_, path.c_str(), status_); + EXPECT_TF_OK(status_); + + string content2 = "overwrite"; + WriteString(path, content1 + content2); + ASSERT_TF_OK(status_); + + result.resize(content2.size()); + read = tf_random_access_file::Read(reader.get(), content1.size(), + content2.size(), &result[0], status_); + result.resize(read); + EXPECT_TF_OK(status_); + EXPECT_EQ(0, result.size()); + + static char set_enable_var[] = "HDFS_DISABLE_READ_EOF_RETRIED=0"; + putenv(set_enable_var); +} + +TEST_F(HadoopFileSystemTest, HarSplit) { + const std::string har_path = + "har://hdfs-root/user/j.doe/my_archive.har/dir0/dir1/file.txt"; + std::string scheme, namenode, path; + ParseHadoopPath(har_path, &scheme, &namenode, &path); + EXPECT_EQ("har", scheme); + EXPECT_EQ("hdfs-root", namenode); + EXPECT_EQ("/user/j.doe/my_archive.har/dir0/dir1/file.txt", path); + SplitArchiveNameAndPath(&path, &namenode, status_); + EXPECT_TF_OK(status_); + EXPECT_EQ("har://hdfs-root/user/j.doe/my_archive.har", namenode); + EXPECT_EQ("/dir0/dir1/file.txt", path); +} + +TEST_F(HadoopFileSystemTest, NoHarExtension) { + const std::string har_path = + "har://hdfs-root/user/j.doe/my_archive/dir0/dir1/file.txt"; + std::string scheme, namenode, path; + ParseHadoopPath(har_path, &scheme, &namenode, &path); + EXPECT_EQ("har", scheme); + EXPECT_EQ("hdfs-root", namenode); + 
EXPECT_EQ("/user/j.doe/my_archive/dir0/dir1/file.txt", path); + SplitArchiveNameAndPath(&path, &namenode, status_); + EXPECT_EQ(TF_GetCode(status_), TF_INVALID_ARGUMENT) << TF_Message(status_); +} + +TEST_F(HadoopFileSystemTest, HarRootPath) { + const std::string har_path = "har://hdfs-root/user/j.doe/my_archive.har"; + std::string scheme, namenode, path; + ParseHadoopPath(har_path, &scheme, &namenode, &path); + EXPECT_EQ("har", scheme); + EXPECT_EQ("hdfs-root", namenode); + EXPECT_EQ("/user/j.doe/my_archive.har", path); + SplitArchiveNameAndPath(&path, &namenode, status_); + EXPECT_TF_OK(status_); + EXPECT_EQ("har://hdfs-root/user/j.doe/my_archive.har", namenode); + EXPECT_EQ("/", path); +} + +TEST_F(HadoopFileSystemTest, WriteLargeFile) { + if (std::getenv("HADOOP_TEST_LARGE_FILE") != "1") GTEST_SKIP(); + const std::string path = TmpDir("WriteLargeFile"); + const size_t file_size = + static_cast(std::numeric_limits::max()) + 1024; + // Fake a test string. + std::string source(file_size, {}); + for (size_t i = 0; i < file_size; ++i) source[i] = (i % 128); + WriteString(path, source); + ASSERT_TF_OK(status_); + auto result = ReadAll(path); + EXPECT_TF_OK(status_); + EXPECT_EQ(source, result); +} +// NewAppendableFile() is not testable. Local filesystem maps to +// ChecksumFileSystem in Hadoop, where appending is an unsupported operation. + +} // namespace +} // namespace tensorflow + +GTEST_API_ int main(int argc, char** argv) { + tensorflow::testing::InstallStacktraceHandler(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/c/experimental/filesystem/plugins/posix/BUILD b/tensorflow/c/experimental/filesystem/plugins/posix/BUILD index 3afe114b5a6..b87dddb96a1 100644 --- a/tensorflow/c/experimental/filesystem/plugins/posix/BUILD +++ b/tensorflow/c/experimental/filesystem/plugins/posix/BUILD @@ -1,3 +1,5 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + # Experimental posix filesystem plugin. load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object") diff --git a/tensorflow/c/experimental/filesystem/plugins/windows/BUILD b/tensorflow/c/experimental/filesystem/plugins/windows/BUILD index b845d1e3616..02f7b5ba706 100644 --- a/tensorflow/c/experimental/filesystem/plugins/windows/BUILD +++ b/tensorflow/c/experimental/filesystem/plugins/windows/BUILD @@ -1,3 +1,5 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + # Experimental windows filesystem plugin. load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object") diff --git a/tensorflow/c/experimental/gradients/BUILD b/tensorflow/c/experimental/gradients/BUILD index 36a3251def7..e8a50e32216 100644 --- a/tensorflow/c/experimental/gradients/BUILD +++ b/tensorflow/c/experimental/gradients/BUILD @@ -1,3 +1,6 @@ +load("//tensorflow:tensorflow.bzl", "filegroup") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + # Library of gradient functions. 
package( licenses = ["notice"], # Apache 2.0 @@ -16,7 +19,7 @@ cc_library( "//tensorflow/c/eager:abstract_operation", "//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:c_api_unified_internal", - "//tensorflow/c/eager:gradients", + "//tensorflow/c/eager:gradients_internal", "//tensorflow/core/lib/llvm_rtti", ], ) @@ -31,14 +34,11 @@ cc_library( "//tensorflow:internal", ], deps = [ - "//tensorflow/c/eager:abstract_operation", "//tensorflow/c/eager:abstract_tensor_handle", - "//tensorflow/c/eager:c_api_unified_internal", - "//tensorflow/c/eager:gradients", + "//tensorflow/c/eager:gradients_internal", "//tensorflow/c/experimental/ops:array_ops", "//tensorflow/c/experimental/ops:math_ops", "//tensorflow/c/experimental/ops:nn_ops", - "//tensorflow/core/lib/llvm_rtti", ], ) @@ -52,13 +52,46 @@ cc_library( "//tensorflow:internal", ], deps = [ - "//tensorflow/c/eager:abstract_operation", "//tensorflow/c/eager:abstract_tensor_handle", - "//tensorflow/c/eager:c_api_unified_internal", - "//tensorflow/c/eager:gradients", + "//tensorflow/c/eager:gradients_internal", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_tensor_handle", "//tensorflow/c/experimental/ops:array_ops", "//tensorflow/c/experimental/ops:math_ops", "//tensorflow/c/experimental/ops:nn_ops", "//tensorflow/core/lib/llvm_rtti", + "//tensorflow/core/platform:errors", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "gradients", + hdrs = [ + "array_grad.h", + "math_grad.h", + "nn_grad.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":array_grad", + ":math_grad", + ":nn_grad", + "//tensorflow/c/eager:gradients_internal", + ], +) + +filegroup( + name = "pywrap_required_hdrs", + srcs = [ + "array_grad.h", + "math_grad.h", + "nn_grad.h", + ], + visibility = [ + "//tensorflow/core:__pkg__", + "//tensorflow/python:__pkg__", ], ) diff --git a/tensorflow/c/experimental/gradients/math_grad.cc b/tensorflow/c/experimental/gradients/math_grad.cc index 3537b30c597..5551642127d 100644 --- a/tensorflow/c/experimental/gradients/math_grad.cc +++ b/tensorflow/c/experimental/gradients/math_grad.cc @@ -22,10 +22,10 @@ limitations under the License. using std::vector; using tensorflow::ops::Conj; -using tensorflow::ops::Identity; using tensorflow::ops::MatMul; using tensorflow::ops::Mul; -using tensorflow::ops::ZerosLike; +using tensorflow::ops::Neg; +using tensorflow::ops::SqrtGrad; namespace tensorflow { namespace gradients { @@ -36,21 +36,14 @@ class AddGradientFunction : public GradientFunction { Status Compute(Context* ctx, const IncomingGradients& grad_inputs, vector* grad_outputs) override { grad_outputs->resize(2); - vector identity_outputs(1); - // TODO(b/145674566): Handle name unification in tracing code. // TODO(b/161805092): Support broadcasting. 
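Editor's note on the Add gradient rewrite just below: grad_inputs handles are only borrowed by Compute, while anything placed into grad_outputs is handed back to the caller, which will eventually Unref it. Forwarding the upstream gradient directly therefore requires taking an extra reference per returned slot, as the new code does:

  // Pass-through gradient of Add: both inputs receive the upstream gradient.
  (*grad_outputs)[0] = grad_inputs[0];
  (*grad_outputs)[1] = grad_inputs[0];
  (*grad_outputs)[0]->Ref();  // balance the caller's eventual Unref
  (*grad_outputs)[1]->Ref();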
- std::string name = "Identity_A"; - TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]}, - absl::MakeSpan(identity_outputs), - name.c_str())); - (*grad_outputs)[0] = identity_outputs[0]; + DCHECK(grad_inputs[0]); + (*grad_outputs)[0] = grad_inputs[0]; + (*grad_outputs)[1] = grad_inputs[0]; - name = "Identity_B"; - TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]}, - absl::MakeSpan(identity_outputs), - name.c_str())); - (*grad_outputs)[1] = identity_outputs[0]; + (*grad_outputs)[0]->Ref(); + (*grad_outputs)[1]->Ref(); return Status::OK(); } ~AddGradientFunction() override {} @@ -81,6 +74,25 @@ class ExpGradientFunction : public GradientFunction { AbstractTensorHandlePtr exp_; }; +class SqrtGradientFunction : public GradientFunction { + public: + explicit SqrtGradientFunction(AbstractTensorHandle* sqrt) : sqrt_(sqrt) { + sqrt->Ref(); + } + Status Compute(Context* ctx, const IncomingGradients& grad_inputs, + vector* grad_outputs) override { + std::string name = "Sqrt_Grad"; + grad_outputs->resize(1); + TF_RETURN_IF_ERROR(SqrtGrad(ctx->ctx, {sqrt_.get(), grad_inputs[0]}, + absl::MakeSpan(*grad_outputs), name.c_str())); + return Status::OK(); + } + ~SqrtGradientFunction() override {} + + private: + AbstractTensorHandlePtr sqrt_; +}; + class MatMulGradientFunction : public GradientFunction { public: explicit MatMulGradientFunction(vector f_inputs, @@ -190,6 +202,56 @@ class MatMulGradientFunction : public GradientFunction { AttrBuilder forward_attrs; }; +class NegGradientFunction : public GradientFunction { + public: + Status Compute(Context* ctx, const IncomingGradients& grad_inputs, + vector* grad_outputs) override { + /* Given upstream grad U and a Neg op Y = -X, the gradients are: + * + * dX = -U + * + */ + + grad_outputs->resize(1); + std::string name = "Neg_Grad"; + TF_RETURN_IF_ERROR(ops::Neg(ctx->ctx, {grad_inputs[0]}, + absl::MakeSpan(*grad_outputs), name.c_str())); + return Status::OK(); + } + ~NegGradientFunction() override {} +}; + +class SubGradientFunction : public GradientFunction { + public: + Status Compute(Context* ctx, const IncomingGradients& grad_inputs, + vector* grad_outputs) override { + /* Given upstream grad U and a Sub op A-B, the gradients are: + * + * dA = U + * dB = -U + * + */ + + grad_outputs->resize(2); + + // Grad for A + DCHECK(grad_inputs[0]); + (*grad_outputs)[0] = grad_inputs[0]; + (*grad_outputs)[0]->Ref(); + + // Grad for B + // negate the upstream grad + std::vector neg_outputs(1); + std::string name = "Neg_Sub_Grad_B"; + TF_RETURN_IF_ERROR(ops::Neg(ctx->ctx, {grad_inputs[0]}, + absl::MakeSpan(neg_outputs), name.c_str())); + (*grad_outputs)[1] = neg_outputs[0]; + + return Status::OK(); + } + ~SubGradientFunction() override {} +}; + } // namespace BackwardFunction* AddRegisterer(const ForwardOperation& op) { @@ -219,5 +281,32 @@ BackwardFunction* MatMulRegisterer(const ForwardOperation& op) { return new BackwardFunction(gradient_function, default_gradients); } +BackwardFunction* SqrtRegisterer(const ForwardOperation& op) { + auto gradient_function = new SqrtGradientFunction(op.outputs[0]); + // For ops with a single output, the gradient function is not called if there + // is no incoming gradient. So we do not need to worry about creating zeros + // grads in this case. 
+ auto default_gradients = new PassThroughDefaultGradients(op); + return new BackwardFunction(gradient_function, default_gradients); +} + +BackwardFunction* NegRegisterer(const ForwardOperation& op) { + auto gradient_function = new NegGradientFunction; + // For ops with a single output, the gradient function is not called if there + // is no incoming gradient. So we do not need to worry about creating zeros + // grads in this case. + auto default_gradients = new PassThroughDefaultGradients(op); + return new BackwardFunction(gradient_function, default_gradients); +} + +BackwardFunction* SubRegisterer(const ForwardOperation& op) { + // For ops with a single output, the gradient function is not called if there + // is no incoming gradient. So we do not need to worry about creating zeros + // grads in this case. + auto gradient_function = new SubGradientFunction; + auto default_gradients = new PassThroughDefaultGradients(op); + return new BackwardFunction(gradient_function, default_gradients); +} + } // namespace gradients } // namespace tensorflow diff --git a/tensorflow/c/experimental/gradients/math_grad.h b/tensorflow/c/experimental/gradients/math_grad.h index 205419e1201..756c5f84153 100644 --- a/tensorflow/c/experimental/gradients/math_grad.h +++ b/tensorflow/c/experimental/gradients/math_grad.h @@ -19,10 +19,15 @@ limitations under the License. namespace tensorflow { namespace gradients { + BackwardFunction* AddRegisterer(const ForwardOperation& op); BackwardFunction* ExpRegisterer(const ForwardOperation& op); BackwardFunction* MatMulRegisterer(const ForwardOperation& op); +BackwardFunction* SqrtRegisterer(const ForwardOperation& op); +BackwardFunction* NegRegisterer(const ForwardOperation& op); +BackwardFunction* SubRegisterer(const ForwardOperation& op); + } // namespace gradients } // namespace tensorflow -#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_ \ No newline at end of file +#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_ diff --git a/tensorflow/c/experimental/gradients/nn_grad.cc b/tensorflow/c/experimental/gradients/nn_grad.cc index 3da1e0dc153..64532c8ffc0 100644 --- a/tensorflow/c/experimental/gradients/nn_grad.cc +++ b/tensorflow/c/experimental/gradients/nn_grad.cc @@ -14,17 +14,19 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/gradients/nn_grad.h" +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/experimental/ops/array_ops.h" #include "tensorflow/c/experimental/ops/math_ops.h" #include "tensorflow/c/experimental/ops/nn_ops.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" +#include "tensorflow/core/platform/errors.h" using std::vector; -using tensorflow::ops::Conj; -using tensorflow::ops::Identity; using tensorflow::ops::Mul; using tensorflow::ops::ReluGrad; -using tensorflow::ops::SparseSoftmaxCrossEntropyLoss; -using tensorflow::ops::ZerosLike; namespace tensorflow { namespace gradients { @@ -58,9 +60,31 @@ class ReluGradientFunction : public GradientFunction { vector forward_outputs; }; -class SparseSoftmaxCrossEntropyLossGradientFunction : public GradientFunction { +Status BroadcastMul(AbstractContext* ctx, AbstractTensorHandle* vec, + AbstractTensorHandle* mat, + absl::Span outputs) { + if (!isa(ctx)) { + // TODO(b/168850692): Fix this. 
+ return errors::Unimplemented( + "BroadcastMul is not supported in tracing mode yet."); + } + auto imm_ctx = dyn_cast(ctx); + AbstractTensorPtr minus_1(imm_ctx->CreateInt32Scalar(-1)); + ImmediateTensorHandlePtr dim(imm_ctx->CreateLocalHandle(minus_1.get())); + vector expand_dims_outputs(1); + TF_RETURN_IF_ERROR(ops::ExpandDims(ctx, {vec, dim.get()}, + absl::MakeSpan(expand_dims_outputs), + "ExpandDims")); + TF_RETURN_IF_ERROR( + ops::Mul(ctx, {expand_dims_outputs[0], mat}, outputs, "Mul")); + expand_dims_outputs[0]->Unref(); + return Status::OK(); +} + +class SparseSoftmaxCrossEntropyWithLogitsGradientFunction + : public GradientFunction { public: - explicit SparseSoftmaxCrossEntropyLossGradientFunction( + explicit SparseSoftmaxCrossEntropyWithLogitsGradientFunction( vector f_outputs) : forward_outputs(f_outputs) {} @@ -69,12 +93,10 @@ class SparseSoftmaxCrossEntropyLossGradientFunction : public GradientFunction { grad_outputs->resize(2); // Grad for Softmax Input - std::string name = "Mul_Softmax_Grad"; vector mul_outputs(1); - TF_RETURN_IF_ERROR( - ops::Mul(ctx->ctx, {grad_inputs[0], forward_outputs[1]}, - absl::MakeSpan(mul_outputs), - name.c_str())); // upstream_grad * local softmax grad + TF_RETURN_IF_ERROR(BroadcastMul( + ctx->ctx, grad_inputs[0], forward_outputs[1], + absl::MakeSpan(mul_outputs))); // upstream_grad * local softmax grad (*grad_outputs)[0] = mul_outputs[0]; // Grad for labels is null @@ -82,7 +104,7 @@ class SparseSoftmaxCrossEntropyLossGradientFunction : public GradientFunction { return Status::OK(); } - ~SparseSoftmaxCrossEntropyLossGradientFunction() override {} + ~SparseSoftmaxCrossEntropyWithLogitsGradientFunction() override {} private: vector forward_outputs; @@ -99,10 +121,10 @@ BackwardFunction* ReluRegisterer(const ForwardOperation& op) { return new BackwardFunction(gradient_function, default_gradients); } -BackwardFunction* SparseSoftmaxCrossEntropyLossRegisterer( +BackwardFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer( const ForwardOperation& op) { auto gradient_function = - new SparseSoftmaxCrossEntropyLossGradientFunction(op.outputs); + new SparseSoftmaxCrossEntropyWithLogitsGradientFunction(op.outputs); auto default_gradients = new PassThroughDefaultGradients(op); return new BackwardFunction(gradient_function, default_gradients); } diff --git a/tensorflow/c/experimental/gradients/nn_grad.h b/tensorflow/c/experimental/gradients/nn_grad.h index d002725847f..034f20d7325 100644 --- a/tensorflow/c/experimental/gradients/nn_grad.h +++ b/tensorflow/c/experimental/gradients/nn_grad.h @@ -20,9 +20,9 @@ limitations under the License. namespace tensorflow { namespace gradients { BackwardFunction* ReluRegisterer(const ForwardOperation& op); -BackwardFunction* SparseSoftmaxCrossEntropyLossRegisterer( +BackwardFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer( const ForwardOperation& op); } // namespace gradients } // namespace tensorflow -#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_H_ \ No newline at end of file +#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_H_ diff --git a/tensorflow/c/experimental/gradients/tape/BUILD b/tensorflow/c/experimental/gradients/tape/BUILD new file mode 100644 index 00000000000..bada49ea669 --- /dev/null +++ b/tensorflow/c/experimental/gradients/tape/BUILD @@ -0,0 +1,66 @@ +# A tape built on top of unified execution APIs. 
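Editor's note before the tape targets: the registerers added above in math_grad.cc and nn_grad.cc are meant to be installed into a GradientRegistry keyed by the forward op name the tape records. A usage sketch, assuming the GradientRegistry::Register(op_name, registerer) API from tensorflow/c/eager/gradients.h and op-name strings matching the forward ops ("AddV2" for ops::Add, etc.):

Status RegisterGradients(GradientRegistry* registry) {
  TF_RETURN_IF_ERROR(registry->Register("AddV2", AddRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("Sqrt", SqrtRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("Neg", NegRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("Sub", SubRegisterer));
  TF_RETURN_IF_ERROR(registry->Register(
      "SparseSoftmaxCrossEntropyWithLogits",
      SparseSoftmaxCrossEntropyWithLogitsRegisterer));
  return Status::OK();
}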
+load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + +package( + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "tape_context", + srcs = ["tape_context.cc"], + hdrs = [ + "tape_context.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":tape_operation", + "//tensorflow/c/eager:abstract_context", + "//tensorflow/c/eager:abstract_function", + "//tensorflow/c/eager:abstract_operation", + ], +) + +cc_library( + name = "tape_operation", + srcs = ["tape_operation.cc"], + hdrs = [ + "tape_operation.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + "//tensorflow/c/eager:abstract_context", + "//tensorflow/c/eager:abstract_function", + "//tensorflow/c/eager:abstract_operation", + "//tensorflow/c/eager:gradients_internal", + ], +) + +cc_library( + name = "tape", + hdrs = [ + "tape_context.h", + "tape_operation.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":tape_context", + ":tape_operation", + ], +) + +filegroup( + name = "pywrap_required_hdrs", + srcs = [ + "tape_context.h", + "tape_operation.h", + ], + visibility = [ + "//tensorflow:internal", + ], +) diff --git a/tensorflow/c/experimental/gradients/tape/tape_context.cc b/tensorflow/c/experimental/gradients/tape/tape_context.cc new file mode 100644 index 00000000000..1fa1a3f24f1 --- /dev/null +++ b/tensorflow/c/experimental/gradients/tape/tape_context.cc @@ -0,0 +1,47 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/experimental/gradients/tape/tape_context.h" + +#include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/experimental/gradients/tape/tape_operation.h" + +namespace tensorflow { +namespace gradients { +TapeContext::TapeContext(AbstractContext* c, Tape* tape, + const GradientRegistry& registry) + : AbstractContext(kTape), parent_ctx_(c), tape_(tape), registry_(registry) { + // TODO(srbs): Make AbstractContext ref counted. + // parent_ctx_->Ref(); +} +void TapeContext::Release() { + // TODO(srbs): Change to Unref() + delete this; +} +TapeContext::~TapeContext() { + // TODO(srbs): Make AbstractContext ref counted. + // parent_ctx_->Unref(); +} +TapeOperation* TapeContext::CreateOperation() { + return new TapeOperation(parent_ctx_->CreateOperation(), tape_, registry_); +} +Status TapeContext::RegisterFunction(AbstractFunction* f) { + return parent_ctx_->RegisterFunction(f); +} +Status TapeContext::RemoveFunction(const string& func) { + return parent_ctx_->RemoveFunction(func); +} + +} // namespace gradients +} // namespace tensorflow diff --git a/tensorflow/c/experimental/gradients/tape/tape_context.h b/tensorflow/c/experimental/gradients/tape/tape_context.h new file mode 100644 index 00000000000..291053226fb --- /dev/null +++ b/tensorflow/c/experimental/gradients/tape/tape_context.h @@ -0,0 +1,44 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_CONTEXT_H_ +#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_CONTEXT_H_ + +#include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/experimental/gradients/tape/tape_operation.h" + +namespace tensorflow { +namespace gradients { +class TapeContext : public AbstractContext { + public: + explicit TapeContext(AbstractContext*, Tape*, const GradientRegistry&); + void Release() override; + TapeOperation* CreateOperation() override; + Status RegisterFunction(AbstractFunction*) override; + Status RemoveFunction(const string& func) override; + // For LLVM style RTTI. + static bool classof(const AbstractContext* ptr) { + return ptr->getKind() == kTape; + } + ~TapeContext() override; + + private: + AbstractContext* parent_ctx_; // Not owned. + Tape* tape_; + const GradientRegistry& registry_; +}; +} // namespace gradients +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_CONTEXT_H_ diff --git a/tensorflow/c/experimental/gradients/tape/tape_operation.cc b/tensorflow/c/experimental/gradients/tape/tape_operation.cc new file mode 100644 index 00000000000..0b247d08f6c --- /dev/null +++ b/tensorflow/c/experimental/gradients/tape/tape_operation.cc @@ -0,0 +1,238 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/experimental/gradients/tape/tape_operation.h" + +#include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/eager/gradients.h" + +namespace tensorflow { +namespace gradients { +TapeOperation::TapeOperation(AbstractOperation* parent_op, Tape* tape, + const GradientRegistry& registry) + : AbstractOperation(kTape), + parent_op_(parent_op), + tape_(tape), + registry_(registry) { + // TODO(srbs): Make AbstractOperation RefCounted. + // parent_op_->Ref(); +} +void TapeOperation::Release() { + // TODO(srbs): Change to Unref(). + delete this; +} +TapeOperation::~TapeOperation() { + // TODO(srbs): Make AbstractOperation RefCounted. 
+ // parent_op->Unref(); +} +Status TapeOperation::Reset(const char* op, const char* raw_device_name) { + forward_op_.op_name = op; + forward_op_.attrs.Reset(op); + forward_op_.inputs.clear(); + forward_op_.outputs.clear(); + return parent_op_->Reset(op, raw_device_name); +} +const string& TapeOperation::Name() const { return parent_op_->Name(); } +const string& TapeOperation::DeviceName() const { + return parent_op_->DeviceName(); +} +Status TapeOperation::SetDeviceName(const char* name) { + return parent_op_->SetDeviceName(name); +} +Status TapeOperation::AddInput(AbstractTensorHandle* input) { + TF_RETURN_IF_ERROR(parent_op_->AddInput(input)); + forward_op_.inputs.push_back(input); + return Status::OK(); +} +Status TapeOperation::AddInputList( + absl::Span inputs) { + TF_RETURN_IF_ERROR(parent_op_->AddInputList(inputs)); + for (auto input : inputs) { + forward_op_.inputs.push_back(input); + } + return Status::OK(); +} +Status TapeOperation::SetAttrString(const char* attr_name, const char* data, + size_t length) { + forward_op_.attrs.Set(attr_name, StringPiece(data, length)); + return parent_op_->SetAttrString(attr_name, data, length); +} +Status TapeOperation::SetAttrInt(const char* attr_name, int64_t value) { + forward_op_.attrs.Set(attr_name, static_cast(value)); + return parent_op_->SetAttrInt(attr_name, value); +} +Status TapeOperation::SetAttrFloat(const char* attr_name, float value) { + forward_op_.attrs.Set(attr_name, value); + return parent_op_->SetAttrFloat(attr_name, value); +} +Status TapeOperation::SetAttrBool(const char* attr_name, bool value) { + forward_op_.attrs.Set(attr_name, value); + return parent_op_->SetAttrBool(attr_name, value); +} +Status TapeOperation::SetAttrType(const char* attr_name, DataType value) { + forward_op_.attrs.Set(attr_name, value); + return parent_op_->SetAttrType(attr_name, value); +} +Status TapeOperation::SetAttrShape(const char* attr_name, const int64_t* dims, + const int num_dims) { + if (num_dims > TensorShape::MaxDimensions()) { + return errors::InvalidArgument("Value specified for `", attr_name, "` has ", + num_dims, + " dimensions which is over the limit of ", + TensorShape::MaxDimensions(), "."); + } + TensorShapeProto proto; + if (num_dims < 0) { + proto.set_unknown_rank(true); + } else { + for (int d = 0; d < num_dims; ++d) { + proto.add_dim()->set_size(dims[d]); + } + } + + forward_op_.attrs.Set(attr_name, proto); + return parent_op_->SetAttrShape(attr_name, dims, num_dims); +} +Status TapeOperation::SetAttrFunction(const char* attr_name, + const AbstractOperation* value) { + return tensorflow::errors::Unimplemented( + "SetAttrFunction has not been implemented yet."); +} +Status TapeOperation::SetAttrFunctionName(const char* attr_name, + const char* value, size_t length) { + return tensorflow::errors::Unimplemented( + "SetAttrFunctionName has not been implemented " + "yet."); +} +Status TapeOperation::SetAttrTensor(const char* attr_name, + AbstractTensorInterface* tensor) { + return tensorflow::errors::Unimplemented( + "SetAttrTensor has not been implemented yet."); +} +Status TapeOperation::SetAttrStringList(const char* attr_name, + const void* const* values, + const size_t* lengths, int num_values) { + std::vector v(num_values); + for (int i = 0; i < num_values; ++i) { + v[i] = StringPiece(static_cast(values[i]), lengths[i]); + } + forward_op_.attrs.Set(attr_name, v); + return parent_op_->SetAttrStringList(attr_name, values, lengths, num_values); +} +Status TapeOperation::SetAttrFloatList(const char* attr_name, + const float* 
values, int num_values) { + forward_op_.attrs.Set(attr_name, + gtl::ArraySlice(values, num_values)); + return parent_op_->SetAttrFloatList(attr_name, values, num_values); +} +Status TapeOperation::SetAttrIntList(const char* attr_name, + const int64_t* values, int num_values) { + forward_op_.attrs.Set( + attr_name, gtl::ArraySlice( + reinterpret_cast(values), num_values)); + return parent_op_->SetAttrIntList(attr_name, values, num_values); +} +Status TapeOperation::SetAttrTypeList(const char* attr_name, + const DataType* values, int num_values) { + forward_op_.attrs.Set(attr_name, + gtl::ArraySlice(values, num_values)); + return parent_op_->SetAttrTypeList(attr_name, values, num_values); +} +Status TapeOperation::SetAttrBoolList(const char* attr_name, + const unsigned char* values, + int num_values) { + std::unique_ptr b(new bool[num_values]); + for (int i = 0; i < num_values; ++i) { + b[i] = values[i]; + } + forward_op_.attrs.Set(attr_name, + gtl::ArraySlice(b.get(), num_values)); + return parent_op_->SetAttrBoolList(attr_name, values, num_values); +} +Status TapeOperation::SetAttrShapeList(const char* attr_name, + const int64_t** dims, + const int* num_dims, int num_values) { + std::unique_ptr proto(new TensorShapeProto[num_values]); + for (int i = 0; i < num_values; ++i) { + const auto num_dims_i = num_dims[i]; + + if (num_dims_i > TensorShape::MaxDimensions()) { + return errors::InvalidArgument( + strings::StrCat("Value specified for `", attr_name, "` has ", + num_dims_i, " dimensions which is over the limit of ", + TensorShape::MaxDimensions(), ".")); + } + if (num_dims_i < 0) { + proto[i].set_unknown_rank(true); + } else { + const int64_t* dims_i = dims[i]; + auto proto_i = &proto[i]; + for (int d = 0; d < num_dims_i; ++d) { + proto_i->add_dim()->set_size(dims_i[d]); + } + } + } + forward_op_.attrs.Set( + attr_name, gtl::ArraySlice(proto.get(), num_values)); + return parent_op_->SetAttrShapeList(attr_name, dims, num_dims, num_values); +} +Status TapeOperation::SetAttrFunctionList( + const char* attr_name, absl::Span values) { + return tensorflow::errors::Unimplemented( + "SetAttrFunctionList has not been " + "implemented yet."); +} +AbstractOperation* TapeOperation::GetBackingOperation() { return parent_op_; } +Status TapeOperation::Execute(absl::Span retvals, + int* num_retvals) { + TF_RETURN_IF_ERROR(parent_op_->Execute(retvals, num_retvals)); + std::vector input_ids(forward_op_.inputs.size()); + std::vector input_dtypes(forward_op_.inputs.size()); + for (int i = 0; i < forward_op_.inputs.size(); i++) { + input_ids[i] = ToId(forward_op_.inputs[i]); + input_dtypes[i] = forward_op_.inputs[i]->DataType(); + } + for (int i = 0; i < *num_retvals; i++) { + // TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs. + forward_op_.outputs.push_back(retvals[i]); + } + // TODO(b/166669239): This is needed to support AttrBuilder::Get for string + // attributes. Number type attrs and DataType attrs work fine without this. + // Consider getting rid of this and making the behavior between number types + // and string consistent. 
+ forward_op_.attrs.BuildNodeDef(); + std::vector tape_tensors; + for (auto t : retvals) { + tape_tensors.push_back(TapeTensor(t)); + } + tape_->RecordOperation( + parent_op_->Name(), tape_tensors, input_ids, input_dtypes, + [this]() -> BackwardFunction* { + std::unique_ptr backward_fn; + Status s = registry_.Lookup(forward_op_, &backward_fn); + if (!s.ok()) { + return nullptr; + } + return backward_fn.release(); + }, + [](BackwardFunction* ptr) { + if (ptr) { + delete ptr; + } + }); + return Status::OK(); +} + +} // namespace gradients +} // namespace tensorflow diff --git a/tensorflow/c/experimental/gradients/tape/tape_operation.h b/tensorflow/c/experimental/gradients/tape/tape_operation.h new file mode 100644 index 00000000000..b971176d9e7 --- /dev/null +++ b/tensorflow/c/experimental/gradients/tape/tape_operation.h @@ -0,0 +1,80 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_OPERATION_H_ +#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_OPERATION_H_ + +#include "tensorflow/c/eager/abstract_operation.h" +#include "tensorflow/c/eager/gradients.h" + +namespace tensorflow { +namespace gradients { +class TapeOperation : public AbstractOperation { + public: + explicit TapeOperation(AbstractOperation*, Tape*, const GradientRegistry&); + void Release() override; + Status Reset(const char* op, const char* raw_device_name) override; + const string& Name() const override; + const string& DeviceName() const override; + Status SetDeviceName(const char* name) override; + Status AddInput(AbstractTensorHandle* input) override; + Status AddInputList(absl::Span inputs) override; + Status Execute(absl::Span retvals, + int* num_retvals) override; + Status SetAttrString(const char* attr_name, const char* data, + size_t length) override; + Status SetAttrInt(const char* attr_name, int64_t value) override; + Status SetAttrFloat(const char* attr_name, float value) override; + Status SetAttrBool(const char* attr_name, bool value) override; + Status SetAttrType(const char* attr_name, DataType value) override; + Status SetAttrShape(const char* attr_name, const int64_t* dims, + const int num_dims) override; + Status SetAttrFunction(const char* attr_name, + const AbstractOperation* value) override; + Status SetAttrFunctionName(const char* attr_name, const char* value, + size_t length) override; + Status SetAttrTensor(const char* attr_name, + AbstractTensorInterface* tensor) override; + Status SetAttrStringList(const char* attr_name, const void* const* values, + const size_t* lengths, int num_values) override; + Status SetAttrFloatList(const char* attr_name, const float* values, + int num_values) override; + Status SetAttrIntList(const char* attr_name, const int64_t* values, + int num_values) override; + Status SetAttrTypeList(const char* attr_name, const DataType* values, + int num_values) override; + Status SetAttrBoolList(const 
char* attr_name, const unsigned char* values, + int num_values) override; + Status SetAttrShapeList(const char* attr_name, const int64_t** dims, + const int* num_dims, int num_values) override; + Status SetAttrFunctionList( + const char* attr_name, + absl::Span values) override; + AbstractOperation* GetBackingOperation(); + // For LLVM style RTTI. + static bool classof(const AbstractOperation* ptr) { + return ptr->getKind() == kTape; + } + ~TapeOperation() override; + + private: + AbstractOperation* parent_op_; + ForwardOperation forward_op_; + Tape* tape_; + const GradientRegistry& registry_; +}; + +} // namespace gradients +} // namespace tensorflow +#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_OPERATION_H_ diff --git a/tensorflow/c/experimental/ops/BUILD b/tensorflow/c/experimental/ops/BUILD index c5810bffa48..d311a0c26db 100644 --- a/tensorflow/c/experimental/ops/BUILD +++ b/tensorflow/c/experimental/ops/BUILD @@ -1,3 +1,6 @@ +load("//tensorflow:tensorflow.bzl", "filegroup") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + # Experimental ops. These will eventually be replaced by machine-generated versions. package( licenses = ["notice"], # Apache 2.0 @@ -19,7 +22,7 @@ cc_library( "//tensorflow/c/eager:abstract_operation", "//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:c_api_unified_internal", - "//tensorflow/core/lib/llvm_rtti", + "//tensorflow/c/eager:tracing_utils", "//tensorflow/core/platform:errors", ], ) @@ -40,8 +43,8 @@ cc_library( "//tensorflow/c/eager:abstract_context", "//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:c_api_unified_internal", + "//tensorflow/c/eager:tracing_utils", "//tensorflow/core:framework", - "//tensorflow/core/lib/llvm_rtti", "//tensorflow/core/platform:errors", ], ) @@ -61,7 +64,41 @@ cc_library( "//tensorflow/c/eager:abstract_operation", "//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:c_api_unified_internal", - "//tensorflow/core/lib/llvm_rtti", + "//tensorflow/c/eager:tracing_utils", "//tensorflow/core/platform:errors", ], ) + +cc_library( + name = "ops", + hdrs = [ + "array_ops.h", + "math_ops.h", + "nn_ops.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":array_ops", + ":math_ops", + ":nn_ops", + "//tensorflow/c/eager:abstract_context", + "//tensorflow/c/eager:abstract_operation", + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:c_api_unified_internal", + ], +) + +filegroup( + name = "pywrap_required_hdrs", + srcs = [ + "array_ops.h", + "math_ops.h", + "nn_ops.h", + ], + visibility = [ + "//tensorflow/core:__pkg__", + "//tensorflow/python:__pkg__", + ], +) diff --git a/tensorflow/c/experimental/ops/array_ops.cc b/tensorflow/c/experimental/ops/array_ops.cc index df0f4639fbd..debeba18edf 100644 --- a/tensorflow/c/experimental/ops/array_ops.cc +++ b/tensorflow/c/experimental/ops/array_ops.cc @@ -14,9 +14,11 @@ limitations under the License. 
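Editor's sketch of how the TapeContext/TapeOperation wrappers above combine with the ops:: helpers changed below: wrap an existing context so that every op built through it is forwarded to the parent context and recorded on a Tape. Construction details (Tape's constructor, AbstractContextPtr) follow tensorflow/c/eager/gradients.h and abstract_context.h and should be treated as assumptions:

Status RunAddUnderTape(AbstractContext* ctx, const GradientRegistry& registry,
                       AbstractTensorHandle* a, AbstractTensorHandle* b,
                       absl::Span<AbstractTensorHandle*> outputs) {
  Tape tape(/*persistent=*/false);
  AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
  // ops::Add calls tape_ctx->CreateOperation(), so it receives a TapeOperation
  // whose Execute() records the forward op and its inputs on `tape`.
  return ops::Add(tape_ctx.get(), {a, b}, outputs, "Add");
}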
==============================================================================*/ #include "tensorflow/c/experimental/ops/array_ops.h" -#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/eager/tracing_utils.h" #include "tensorflow/core/platform/errors.h" +using tensorflow::tracing::MaybeSetOpName; + namespace tensorflow { namespace ops { @@ -26,28 +28,58 @@ Status Identity(AbstractContext* ctx, AbstractOperationPtr identity_op(ctx->CreateOperation()); TF_RETURN_IF_ERROR( identity_op->Reset("Identity", /*raw_device_name=*/nullptr)); - if (isa(identity_op.get())) { - TF_RETURN_IF_ERROR(dyn_cast(identity_op.get()) - ->SetOpName(name)); - } + TF_RETURN_IF_ERROR(MaybeSetOpName(identity_op.get(), name)); TF_RETURN_IF_ERROR(identity_op->AddInput(inputs[0])); int num_retvals = 1; return identity_op->Execute(outputs, &num_retvals); } +Status IdentityN(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr identity_n_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR( + identity_n_op->Reset("IdentityN", /*raw_device_name=*/nullptr)); + TF_RETURN_IF_ERROR(MaybeSetOpName(identity_n_op.get(), name)); + TF_RETURN_IF_ERROR(identity_n_op->AddInputList(inputs)); + int num_retvals = inputs.size(); + return identity_n_op->Execute(outputs, &num_retvals); +} + Status ZerosLike(AbstractContext* ctx, absl::Span inputs, absl::Span outputs, const char* name) { AbstractOperationPtr z_op(ctx->CreateOperation()); TF_RETURN_IF_ERROR(z_op->Reset("ZerosLike", /*raw_device_name=*/nullptr)); - if (isa(z_op.get())) { - TF_RETURN_IF_ERROR( - dyn_cast(z_op.get())->SetOpName(name)); - } + TF_RETURN_IF_ERROR(MaybeSetOpName(z_op.get(), name)); TF_RETURN_IF_ERROR(z_op->AddInput(inputs[0])); int num_retvals = 1; return z_op->Execute(outputs, &num_retvals); } +Status Shape(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr shape_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(shape_op->Reset("Shape", /*raw_device_name=*/nullptr)); + TF_RETURN_IF_ERROR(MaybeSetOpName(shape_op.get(), name)); + TF_RETURN_IF_ERROR(shape_op->AddInput(inputs[0])); // input + int num_retvals = 1; + TF_RETURN_IF_ERROR(shape_op->Execute(outputs, &num_retvals)); + return Status::OK(); +} + +Status ExpandDims(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(op->Reset("ExpandDims", /*raw_device_name=*/nullptr)); + TF_RETURN_IF_ERROR(MaybeSetOpName(op.get(), name)); + TF_RETURN_IF_ERROR(op->AddInput(inputs[0])); + TF_RETURN_IF_ERROR(op->AddInput(inputs[1])); + int num_retvals = 1; + return op->Execute(outputs, &num_retvals); +} + } // namespace ops } // namespace tensorflow diff --git a/tensorflow/c/experimental/ops/array_ops.h b/tensorflow/c/experimental/ops/array_ops.h index 8dc68db673f..f63412ed248 100644 --- a/tensorflow/c/experimental/ops/array_ops.h +++ b/tensorflow/c/experimental/ops/array_ops.h @@ -18,7 +18,6 @@ limitations under the License. 
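Editor's note on the repeated change above: MaybeSetOpName replaces the per-op isa<TracingOperation>/dyn_cast boilerplate that this patch deletes. A sketch of the shared helper, inferred from the removed blocks (the real definition lives behind //tensorflow/c/eager:tracing_utils, added to the deps above):

Status MaybeSetOpName(AbstractOperation* op, const char* op_name) {
  // Only tracing (graph-building) operations carry user-visible node names;
  // eager operations simply ignore the request.
  if (isa<tracing::TracingOperation>(op)) {
    TF_RETURN_IF_ERROR(
        dyn_cast<tracing::TracingOperation>(op)->SetOpName(op_name));
  }
  return Status::OK();
}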
#include "tensorflow/c/eager/abstract_context.h" #include "tensorflow/c/eager/abstract_operation.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" -#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" namespace tensorflow { namespace ops { @@ -27,10 +26,22 @@ Status Identity(AbstractContext* ctx, absl::Span inputs, absl::Span outputs, const char* name); +Status IdentityN(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name); + Status ZerosLike(AbstractContext* ctx, absl::Span inputs, absl::Span outputs, const char* name); +Status Shape(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name); + +Status ExpandDims(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name); + } // namespace ops } // namespace tensorflow diff --git a/tensorflow/c/experimental/ops/math_ops.cc b/tensorflow/c/experimental/ops/math_ops.cc index 82c2f0e8169..20aab8a77d3 100644 --- a/tensorflow/c/experimental/ops/math_ops.cc +++ b/tensorflow/c/experimental/ops/math_ops.cc @@ -16,22 +16,21 @@ limitations under the License. #include "tensorflow/c/eager/abstract_context.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" -#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/eager/tracing_utils.h" #include "tensorflow/c/experimental/ops/array_ops.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/errors.h" + +using tensorflow::tracing::MaybeSetOpName; + namespace tensorflow { namespace ops { -using tensorflow::tracing::TracingOperation; Status Mul(AbstractContext* ctx, absl::Span inputs, absl::Span outputs, const char* name) { AbstractOperationPtr mul_op(ctx->CreateOperation()); TF_RETURN_IF_ERROR(mul_op->Reset("Mul", /*raw_device_name=*/nullptr)); - if (isa(mul_op.get())) { - TF_RETURN_IF_ERROR( - dyn_cast(mul_op.get())->SetOpName(name)); - } + TF_RETURN_IF_ERROR(MaybeSetOpName(mul_op.get(), name)); TF_RETURN_IF_ERROR(mul_op->AddInput(inputs[0])); TF_RETURN_IF_ERROR(mul_op->AddInput(inputs[1])); int num_retvals = 1; @@ -55,12 +54,7 @@ Status Add(AbstractContext* ctx, absl::Span inputs, absl::Span outputs, const char* name) { AbstractOperationPtr add_op(ctx->CreateOperation()); TF_RETURN_IF_ERROR(add_op->Reset("AddV2", /*raw_device_name=*/nullptr)); - - if (isa(add_op.get())) { - TF_RETURN_IF_ERROR( - dyn_cast(add_op.get())->SetOpName(name)); - } - + TF_RETURN_IF_ERROR(MaybeSetOpName(add_op.get(), name)); TF_RETURN_IF_ERROR(add_op->AddInput(inputs[0])); TF_RETURN_IF_ERROR(add_op->AddInput(inputs[1])); @@ -69,18 +63,26 @@ Status Add(AbstractContext* ctx, absl::Span inputs, return Status::OK(); } +Status Sub(AbstractContext* ctx, absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr sub_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(sub_op->Reset("Sub", /*raw_device_name=*/nullptr)); + TF_RETURN_IF_ERROR(MaybeSetOpName(sub_op.get(), name)); + TF_RETURN_IF_ERROR(sub_op->AddInput(inputs[0])); + TF_RETURN_IF_ERROR(sub_op->AddInput(inputs[1])); + + int num_retvals = 1; + TF_RETURN_IF_ERROR(sub_op->Execute(outputs, &num_retvals)); + return Status::OK(); +} + Status MatMul(AbstractContext* ctx, absl::Span inputs, absl::Span outputs, const char* name, bool transpose_a = false, bool transpose_b = false) { AbstractOperationPtr matmul_op(ctx->CreateOperation()); TF_RETURN_IF_ERROR(matmul_op->Reset("MatMul", /*raw_device_name=*/nullptr)); - - if (isa(matmul_op.get())) { - TF_RETURN_IF_ERROR( - dyn_cast(matmul_op.get())->SetOpName(name)); - 
} - + TF_RETURN_IF_ERROR(MaybeSetOpName(matmul_op.get(), name)); TF_RETURN_IF_ERROR(matmul_op->AddInput(inputs[0])); TF_RETURN_IF_ERROR(matmul_op->AddInput(inputs[1])); @@ -96,15 +98,79 @@ Status Neg(AbstractContext* ctx, absl::Span inputs, absl::Span outputs, const char* name) { AbstractOperationPtr neg_op(ctx->CreateOperation()); TF_RETURN_IF_ERROR(neg_op->Reset("Neg", /*raw_device_name=*/nullptr)); - if (isa(neg_op.get())) { - TF_RETURN_IF_ERROR( - dyn_cast(neg_op.get())->SetOpName(name)); - } + TF_RETURN_IF_ERROR(MaybeSetOpName(neg_op.get(), name)); TF_RETURN_IF_ERROR(neg_op->AddInput(inputs[0])); int num_retvals = 1; return neg_op->Execute(outputs, &num_retvals); } +Status Sum(AbstractContext* ctx, absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr sum_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(sum_op->Reset("Sum", /*raw_device_name=*/nullptr)); + TF_RETURN_IF_ERROR(MaybeSetOpName(sum_op.get(), name)); + TF_RETURN_IF_ERROR(sum_op->AddInput(inputs[0])); // input_vals + TF_RETURN_IF_ERROR(sum_op->AddInput(inputs[1])); // reduction_indices + + int num_retvals = 1; + TF_RETURN_IF_ERROR(sum_op->Execute(outputs, &num_retvals)); + return Status::OK(); +} + +Status DivNoNan(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr div_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(div_op->Reset("DivNoNan", /*raw_device_name=*/nullptr)); + TF_RETURN_IF_ERROR(MaybeSetOpName(div_op.get(), name)); + TF_RETURN_IF_ERROR(div_op->AddInput(inputs[0])); // x + TF_RETURN_IF_ERROR(div_op->AddInput(inputs[1])); // y + + int num_retvals = 1; + TF_RETURN_IF_ERROR(div_op->Execute( + outputs, &num_retvals)); // z = x / y, (z_i = 0 if y_i = 0) + return Status::OK(); +} + +Status Exp(AbstractContext* ctx, absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr exp_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(exp_op->Reset("Exp", /*raw_device_name=*/nullptr)); + TF_RETURN_IF_ERROR(MaybeSetOpName(exp_op.get(), name)); + TF_RETURN_IF_ERROR(exp_op->AddInput(inputs[0])); + + int num_retvals = 1; + return exp_op->Execute(outputs, &num_retvals); +} + +Status Sqrt(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr sqrt_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(sqrt_op->Reset("Sqrt", /*raw_device_name=*/nullptr)); + TF_RETURN_IF_ERROR(MaybeSetOpName(sqrt_op.get(), name)); + TF_RETURN_IF_ERROR(sqrt_op->AddInput(inputs[0])); + + int num_retvals = 1; + Status s = sqrt_op->Execute(outputs, &num_retvals); + return s; +} + +Status SqrtGrad(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr sqrt_grad_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR( + sqrt_grad_op->Reset("SqrtGrad", /*raw_device_name=*/nullptr)); + TF_RETURN_IF_ERROR(MaybeSetOpName(sqrt_grad_op.get(), name)); + TF_RETURN_IF_ERROR(sqrt_grad_op->AddInput(inputs[0])); + TF_RETURN_IF_ERROR(sqrt_grad_op->AddInput(inputs[1])); + + int num_retvals = 1; + Status s = sqrt_grad_op->Execute(outputs, &num_retvals); + return s; +} + } // namespace ops } // namespace tensorflow diff --git a/tensorflow/c/experimental/ops/math_ops.h b/tensorflow/c/experimental/ops/math_ops.h index ed1e6c5b3d6..7051e38656f 100644 --- a/tensorflow/c/experimental/ops/math_ops.h +++ b/tensorflow/c/experimental/ops/math_ops.h @@ -22,18 +22,43 @@ namespace tensorflow { namespace ops { Status Mul(AbstractContext* ctx, absl::Span inputs, absl::Span outputs, 
const char* name); + Status Conj(AbstractContext* ctx, absl::Span inputs, absl::Span outputs, const char* name); + Status Add(AbstractContext* ctx, absl::Span inputs, absl::Span outputs, const char* name); + Status MatMul(AbstractContext* ctx, absl::Span inputs, absl::Span outputs, const char* name, bool transpose_a, bool transpose_b); + Status Neg(AbstractContext* ctx, absl::Span inputs, absl::Span outputs, const char* name); +Status Sum(AbstractContext* ctx, absl::Span inputs, + absl::Span outputs, const char* name); + +Status Sub(AbstractContext* ctx, absl::Span inputs, + absl::Span outputs, const char* name); + +Status DivNoNan(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name); + +Status Exp(AbstractContext* ctx, absl::Span inputs, + absl::Span outputs, const char* name); + +Status Sqrt(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name); + +Status SqrtGrad(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name); + } // namespace ops } // namespace tensorflow diff --git a/tensorflow/c/experimental/ops/nn_ops.cc b/tensorflow/c/experimental/ops/nn_ops.cc index 8f5f550bb8b..6a97dbf0939 100644 --- a/tensorflow/c/experimental/ops/nn_ops.cc +++ b/tensorflow/c/experimental/ops/nn_ops.cc @@ -15,24 +15,22 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/ops/nn_ops.h" +#include "tensorflow/c/eager/tracing_utils.h" #include "tensorflow/core/platform/errors.h" +using tensorflow::tracing::MaybeSetOpName; + namespace tensorflow { namespace ops { // Softmax Loss given scores and labels, used by the SoftMaxLossGradient -Status SparseSoftmaxCrossEntropyLoss( +Status SparseSoftmaxCrossEntropyWithLogits( AbstractContext* ctx, absl::Span inputs, absl::Span outputs, const char* name) { AbstractOperationPtr sm_loss_op(ctx->CreateOperation()); TF_RETURN_IF_ERROR(sm_loss_op->Reset("SparseSoftmaxCrossEntropyWithLogits", /*raw_device_name=*/nullptr)); - - if (isa(sm_loss_op.get())) { - TF_RETURN_IF_ERROR( - dyn_cast(sm_loss_op.get())->SetOpName(name)); - } - + TF_RETURN_IF_ERROR(MaybeSetOpName(sm_loss_op.get(), name)); TF_RETURN_IF_ERROR(sm_loss_op->AddInput(inputs[0])); // input scores TF_RETURN_IF_ERROR(sm_loss_op->AddInput(inputs[1])); // labels @@ -49,12 +47,7 @@ Status ReluGrad(AbstractContext* ctx, AbstractOperationPtr relugrad_op(ctx->CreateOperation()); TF_RETURN_IF_ERROR( relugrad_op->Reset("ReluGrad", /*raw_device_name=*/nullptr)); - - if (isa(relugrad_op.get())) { - TF_RETURN_IF_ERROR(dyn_cast(relugrad_op.get()) - ->SetOpName(name)); - } - + TF_RETURN_IF_ERROR(MaybeSetOpName(relugrad_op.get(), name)); TF_RETURN_IF_ERROR(relugrad_op->AddInput(inputs[0])); // upstream grads TF_RETURN_IF_ERROR(relugrad_op->AddInput(inputs[1])); // relu inputs @@ -63,5 +56,18 @@ Status ReluGrad(AbstractContext* ctx, return Status::OK(); } +Status Relu(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr relu_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(relu_op->Reset("Relu", /*raw_device_name=*/nullptr)); + TF_RETURN_IF_ERROR(MaybeSetOpName(relu_op.get(), name)); + TF_RETURN_IF_ERROR(relu_op->AddInput(inputs[0])); + + int num_retvals = 1; + TF_RETURN_IF_ERROR(relu_op->Execute(outputs, &num_retvals)); + return Status::OK(); +} + } // namespace ops } // namespace tensorflow diff --git a/tensorflow/c/experimental/ops/nn_ops.h b/tensorflow/c/experimental/ops/nn_ops.h index 
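As a usage sketch (editorial, not part of this change): these wrappers are thin builders around a single eager op, and callers drive them by supplying an output span and propagating the returned Status. The helper below is hypothetical and assumes the AbstractTensorHandle span convention and the includes already used elsewhere in this API:

Status RunRelu(AbstractContext* ctx, AbstractTensorHandle* features,
               AbstractTensorHandle** activations) {
  // Relu produces a single output handle.
  std::vector<AbstractTensorHandle*> outputs(1);
  TF_RETURN_IF_ERROR(ops::Relu(ctx, /*inputs=*/{features},
                               absl::MakeSpan(outputs), "relu_example"));
  *activations = outputs[0];
  return Status::OK();
}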
3e618b00869..3c0e04579a1 100644 --- a/tensorflow/c/experimental/ops/nn_ops.h +++ b/tensorflow/c/experimental/ops/nn_ops.h @@ -18,12 +18,11 @@ limitations under the License. #include "tensorflow/c/eager/abstract_operation.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/c_api_unified_experimental_internal.h" -#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" namespace tensorflow { namespace ops { -Status SparseSoftmaxCrossEntropyLoss( +Status SparseSoftmaxCrossEntropyWithLogits( AbstractContext* ctx, absl::Span inputs, absl::Span outputs, const char* name); @@ -31,6 +30,10 @@ Status ReluGrad(AbstractContext* ctx, absl::Span inputs, absl::Span outputs, const char* name); +Status Relu(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name); + } // namespace ops } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/BUILD b/tensorflow/c/experimental/saved_model/core/BUILD index 2feb7c1b33e..4cf868e4714 100644 --- a/tensorflow/c/experimental/saved_model/core/BUILD +++ b/tensorflow/c/experimental/saved_model/core/BUILD @@ -1,3 +1,5 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + # Experimental SavedModel C APIs for TensorFlow. See RFC # https://github.com/tensorflow/community/pull/207 # Targets in this directory are pure C++ "Classes" underlying the C API types @@ -62,13 +64,21 @@ cc_library( ":function_metadata", "//tensorflow/c:tf_tensor_internal", "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/experimental/saved_model/core/revived_types:asset", "//tensorflow/c/experimental/saved_model/core/revived_types:constant", + "//tensorflow/c/experimental/saved_model/core/revived_types:partially_revived_objects", + "//tensorflow/c/experimental/saved_model/core/revived_types:restored_resource_revival_state", "//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function", + "//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function_revival_state", + "//tensorflow/c/experimental/saved_model/core/revived_types:tf_signature_def_function_revival_state", "//tensorflow/c/experimental/saved_model/core/revived_types:variable", + "//tensorflow/cc/saved_model:loader_util", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", ], ) @@ -81,15 +91,24 @@ cc_library( ":signature_def_function_metadata", "//tensorflow/c/eager:immediate_execution_operation", "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/types:span", ], ) cc_library( name = "signature_def_function_metadata", + srcs = [ + "signature_def_function_metadata.cc", + ], hdrs = [ "signature_def_function_metadata.h", ], + deps = [ + ":tensor_spec", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], ) cc_library( @@ -138,11 +157,13 @@ cc_library( ":saved_model_api", ":saved_model_utils", ":signature_def_function", - "//tensorflow/c:tensor_interface", "//tensorflow/c/eager:immediate_execution_context", "//tensorflow/c/eager:immediate_execution_tensor_handle", "//tensorflow/c/experimental/saved_model/core/ops:restore_ops", "//tensorflow/c/experimental/saved_model/core/revived_types:constant", + "//tensorflow/c/experimental/saved_model/core/revived_types:flat_tensor_function", + 
"//tensorflow/c/experimental/saved_model/core/revived_types:partially_revived_objects", + "//tensorflow/c/experimental/saved_model/core/revived_types:revived_objects", "//tensorflow/c/experimental/saved_model/core/revived_types:tensorhandle_convertible", "//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function", "//tensorflow/c/experimental/saved_model/core/revived_types:variable", @@ -151,7 +172,6 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/common_runtime/eager:tensor_handle", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -213,6 +233,7 @@ tf_cc_test( "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:core", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -256,6 +277,20 @@ tf_cc_test( ], ) +cc_library( + name = "tensor_spec", + srcs = [ + "tensor_spec.cc", + ], + hdrs = [ + "tensor_spec.h", + ], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + ], +) + tf_cc_test( name = "tf_concrete_function_loading_test", srcs = [ diff --git a/tensorflow/c/experimental/saved_model/core/concrete_function.h b/tensorflow/c/experimental/saved_model/core/concrete_function.h index 934fa6d2bda..48a20ef7768 100644 --- a/tensorflow/c/experimental/saved_model/core/concrete_function.h +++ b/tensorflow/c/experimental/saved_model/core/concrete_function.h @@ -43,8 +43,8 @@ class ConcreteFunction { virtual ~ConcreteFunction() = default; // This method returns the "Call" Op used to execute the function. - virtual Status GetCallOp(absl::Span inputs, - ImmediateOpPtr* out) = 0; + virtual Status MakeCallOp(absl::Span inputs, + ImmediateOpPtr* out) const = 0; virtual const FunctionMetadata& GetFunctionMetadata() const = 0; }; diff --git a/tensorflow/c/experimental/saved_model/core/object_graph_traversal_test.cc b/tensorflow/c/experimental/saved_model/core/object_graph_traversal_test.cc index 1c70d40cada..d179d0de6b7 100644 --- a/tensorflow/c/experimental/saved_model/core/object_graph_traversal_test.cc +++ b/tensorflow/c/experimental/saved_model/core/object_graph_traversal_test.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "absl/strings/string_view.h" +#include "absl/types/optional.h" #include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/stringpiece.h" @@ -300,80 +301,70 @@ nodes { TEST(ObjectGraphTraversalTest, Success) { SavedObjectGraph object_graph = ParseSavedObjectGraph(kSingleChildFoo); - const SavedObject* obj = internal::FindNodeAtPath("foo", object_graph); - ASSERT_NE(nullptr, obj); - EXPECT_EQ(obj->kind_case(), SavedObject::kUserObject); - EXPECT_EQ(obj->user_object().identifier(), "_generic_user_object"); + absl::optional node = internal::FindNodeAtPath("foo", object_graph); + ASSERT_TRUE(node.has_value()); + EXPECT_EQ(*node, 1); } TEST(ObjectGraphTraversalTest, ObjectNotFound) { SavedObjectGraph object_graph = ParseSavedObjectGraph(kSingleChildFoo); - const SavedObject* obj = internal::FindNodeAtPath("bar", object_graph); - EXPECT_EQ(nullptr, obj); + absl::optional node = internal::FindNodeAtPath("bar", object_graph); + EXPECT_FALSE(node.has_value()); } TEST(ObjectGraphTraversalTest, CaseSensitiveMismatch) { SavedObjectGraph object_graph = ParseSavedObjectGraph(kSingleChildFoo); - const SavedObject* obj = internal::FindNodeAtPath("FOO", object_graph); - EXPECT_EQ(nullptr, obj); + absl::optional node = internal::FindNodeAtPath("FOO", object_graph); + EXPECT_FALSE(node.has_value()); } TEST(ObjectGraphTraversalTest, NestedObjectFound) { SavedObjectGraph object_graph = ParseSavedObjectGraph(kSingleChildFooWithFuncBar); - const SavedObject* obj = internal::FindNodeAtPath("foo.bar", object_graph); - ASSERT_NE(nullptr, obj); - EXPECT_EQ(obj->kind_case(), SavedObject::kFunction); - EXPECT_EQ(obj->function().concrete_functions_size(), 1); - EXPECT_EQ(obj->function().concrete_functions(0), "__inference_my_func_5"); + absl::optional node = internal::FindNodeAtPath("foo.bar", object_graph); + ASSERT_TRUE(node.has_value()); + EXPECT_EQ(*node, 2); } TEST(ObjectGraphTraversalTest, MultiplePathsAliasSameObject) { SavedObjectGraph object_graph = ParseSavedObjectGraph(kMultiplePathsToChild); - const SavedObject* foo_baz = + absl::optional foo_baz_node = internal::FindNodeAtPath("foo.baz", object_graph); - ASSERT_NE(nullptr, foo_baz); - EXPECT_EQ(foo_baz->kind_case(), SavedObject::kUserObject); - EXPECT_EQ(foo_baz->user_object().identifier(), "_generic_user_object"); + ASSERT_TRUE(foo_baz_node.has_value()); + EXPECT_EQ(*foo_baz_node, 4); - const SavedObject* bar_wombat = + absl::optional bar_wombat_node = internal::FindNodeAtPath("bar.wombat", object_graph); - ASSERT_NE(nullptr, bar_wombat); - EXPECT_EQ(bar_wombat->kind_case(), SavedObject::kUserObject); - EXPECT_EQ(bar_wombat->user_object().identifier(), "_generic_user_object"); + ASSERT_TRUE(bar_wombat_node.has_value()); + EXPECT_EQ(*bar_wombat_node, 4); - EXPECT_EQ(foo_baz, bar_wombat); + EXPECT_EQ(*foo_baz_node, *bar_wombat_node); } TEST(ObjectGraphTraversalTest, CyclesAreOK) { SavedObjectGraph object_graph = ParseSavedObjectGraph(kCycleBetweenParentAndChild); - const SavedObject* foo = internal::FindNodeAtPath("foo", object_graph); - ASSERT_NE(nullptr, foo); - EXPECT_EQ(foo->kind_case(), SavedObject::kUserObject); - EXPECT_EQ(foo->user_object().identifier(), "_generic_user_object"); + absl::optional foo = internal::FindNodeAtPath("foo", object_graph); + ASSERT_TRUE(foo.has_value()); + EXPECT_EQ(*foo, 1); - const SavedObject* foo_bar = + absl::optional foo_bar = 
internal::FindNodeAtPath("foo.bar", object_graph); - ASSERT_NE(nullptr, foo_bar); - EXPECT_EQ(foo_bar->kind_case(), SavedObject::kUserObject); - EXPECT_EQ(foo_bar->user_object().identifier(), "_generic_user_object"); + ASSERT_TRUE(foo_bar.has_value()); + EXPECT_EQ(*foo_bar, 3); - const SavedObject* foo_bar_parent = + absl::optional foo_bar_parent = internal::FindNodeAtPath("foo.bar.parent", object_graph); - ASSERT_NE(nullptr, foo_bar_parent); - EXPECT_EQ(foo_bar_parent->kind_case(), SavedObject::kUserObject); - EXPECT_EQ(foo_bar_parent->user_object().identifier(), "_generic_user_object"); + ASSERT_TRUE(foo_bar_parent.has_value()); + EXPECT_EQ(*foo_bar_parent, 1); - const SavedObject* foo_bar_parent_bar = + absl::optional foo_bar_parent_bar = internal::FindNodeAtPath("foo.bar.parent.bar", object_graph); - ASSERT_NE(nullptr, foo_bar_parent_bar); - EXPECT_EQ(foo_bar_parent_bar->kind_case(), SavedObject::kUserObject); - EXPECT_EQ(foo_bar_parent_bar->user_object().identifier(), - "_generic_user_object"); + ASSERT_TRUE(foo_bar_parent_bar.has_value()); + EXPECT_EQ(*foo_bar_parent_bar, 3); - EXPECT_EQ(foo, foo_bar_parent); - EXPECT_EQ(foo_bar, foo_bar_parent_bar); + EXPECT_EQ(*foo, *foo_bar_parent); + EXPECT_EQ(*foo_bar, *foo_bar_parent_bar); } } // namespace diff --git a/tensorflow/c/experimental/saved_model/core/ops/BUILD b/tensorflow/c/experimental/saved_model/core/ops/BUILD index 673ea1a80e2..549980b03e9 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/BUILD +++ b/tensorflow/c/experimental/saved_model/core/ops/BUILD @@ -1,3 +1,5 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + # This package contains written convenience helpers for Eager Operations # used by SavedModel. Once we autogenerate C++ Eager Op wrappers, we can remove these. load( diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD index 2b883618c87..ac168830a0e 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD +++ b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD @@ -1,3 +1,5 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + # This package contains classes corresponding to Revived SavedObjectGraph types # used by SavedModel. 
See https://cs.opensource.google/tensorflow/tensorflow/+/c575e2ba93c442121d98d3f125d83fed1339924d:tensorflow/core/protobuf/saved_object_graph.proto;l=56-62 package( @@ -8,6 +10,25 @@ package( licenses = ["notice"], # Apache 2.0 ) +cc_library( + name = "asset", + srcs = [ + "asset.cc", + ], + hdrs = [ + "asset.h", + ], + deps = [ + ":tensorhandle_convertible", + "//tensorflow/c:tensor_interface", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/cc/saved_model:constants", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + cc_library( name = "constant", srcs = [ @@ -28,6 +49,106 @@ cc_library( ], ) +cc_library( + name = "flat_tensor_function", + srcs = [ + "flat_tensor_function.cc", + ], + hdrs = [ + "flat_tensor_function.h", + ], + deps = [ + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_operation", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/common_runtime/eager:context", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "partially_revived_objects", + srcs = [ + "partially_revived_objects.cc", + ], + hdrs = [ + "partially_revived_objects.h", + ], + deps = [ + ":asset", + ":constant", + ":restored_resource", + ":restored_resource_revival_state", + ":revived_objects", + ":tf_concrete_function", + ":tf_concrete_function_revival_state", + ":tf_signature_def_function", + ":tf_signature_def_function_revival_state", + ":variable", + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_operation", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata", + "//tensorflow/c/experimental/saved_model/core:tensor_spec", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/lib/llvm_rtti", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "restored_resource", + srcs = [ + "restored_resource.cc", + ], + hdrs = [ + "restored_resource.h", + ], + deps = [ + ":tensorhandle_convertible", + ":tf_concrete_function", + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_operation", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/core:lib", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "restored_resource_revival_state", + hdrs = [ + "restored_resource_revival_state.h", + ], + deps = [ + ":tf_concrete_function_revival_state", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + ], +) + +cc_library( + name = "revived_objects", + hdrs = [ + "revived_objects.h", + ], + deps = [ + ":asset", + ":constant", + ":restored_resource", + ":tf_concrete_function", + ":tf_signature_def_function", + ":variable", + "//tensorflow/core:lib", + ], +) + cc_library( name = "variable", srcs = [ @@ -45,6 +166,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/common_runtime/eager:context", + "//tensorflow/core/common_runtime/eager:tensor_handle", + "//tensorflow/core/lib/llvm_rtti", "@com_google_absl//absl/types:optional", ], ) @@ -68,7 +191,7 @@ 
cc_library( "tf_concrete_function.h", ], deps = [ - ":tensorhandle_convertible", + ":flat_tensor_function", "//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:immediate_execution_context", "//tensorflow/c/eager:immediate_execution_operation", @@ -81,3 +204,55 @@ cc_library( "@com_google_absl//absl/types:span", ], ) + +cc_library( + name = "tf_concrete_function_revival_state", + hdrs = [ + "tf_concrete_function_revival_state.h", + ], + deps = [ + ":tf_concrete_function", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/types:optional", + ], +) + +cc_library( + name = "tf_signature_def_function", + srcs = [ + "tf_signature_def_function.cc", + ], + hdrs = [ + "tf_signature_def_function.h", + ], + deps = [ + ":flat_tensor_function", + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_operation", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/c/experimental/saved_model/core:signature_def_function", + "//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/common_runtime/eager:context", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "tf_signature_def_function_revival_state", + hdrs = [ + "tf_signature_def_function_revival_state.h", + ], + deps = [ + ":tf_signature_def_function", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/types:optional", + ], +) diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/asset.cc b/tensorflow/c/experimental/saved_model/core/revived_types/asset.cc new file mode 100644 index 00000000000..5cc14d615f5 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/asset.cc @@ -0,0 +1,49 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h" + +#include + +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/tensor_interface.h" +#include "tensorflow/cc/saved_model/constants.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/path.h" + +namespace tensorflow { + +Asset::Asset(ImmediateTensorHandlePtr handle) + : TensorHandleConvertible(std::move(handle)) {} + +Status Asset::Create(ImmediateExecutionContext* ctx, + const std::string& saved_model_dir, + const std::string& asset_filename, + std::unique_ptr* output) { + std::string abs_path = + io::JoinPath(saved_model_dir, kSavedModelAssetsDirectory, asset_filename); + AbstractTensorPtr tensor(ctx->CreateStringScalar(abs_path)); + if (tensor.get() == nullptr) { + return errors::Internal( + "Failed to create scalar string tensor for Asset at path ", abs_path); + } + + ImmediateTensorHandlePtr handle(ctx->CreateLocalHandle(tensor.get())); + output->reset(new Asset(std::move(handle))); + return Status(); +} + +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/asset.h b/tensorflow/c/experimental/saved_model/core/revived_types/asset.h new file mode 100644 index 00000000000..c98bd9b5628 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/asset.h @@ -0,0 +1,50 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_ASSET_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_ASSET_H_ + +#include + +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" +#include "tensorflow/c/tensor_interface.h" +#include "tensorflow/core/framework/tensor.pb.h" + +namespace tensorflow { + +class Asset : public TensorHandleConvertible { + public: + static Status Create(ImmediateExecutionContext* ctx, + const std::string& saved_model_dir, + const std::string& asset_filename, + std::unique_ptr* output); + + // Asset is movable, but not copyable. 
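  //
  // Usage sketch (illustrative only; "vocab.txt" is a hypothetical asset
  // filename): the loader creates an Asset per SavedModel asset file and can
  // read the resulting string scalar back through the inherited
  // TensorHandleConvertible::handle() accessor:
  //
  //   std::unique_ptr<Asset> asset;
  //   TF_RETURN_IF_ERROR(
  //       Asset::Create(ctx, saved_model_dir, "vocab.txt", &asset));
  //   ImmediateExecutionTensorHandle* path_tensor = asset->handle();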
+ Asset(Asset&& other) = default; + Asset& operator=(Asset&& other) = default; + + ~Asset() override = default; + + private: + explicit Asset(ImmediateTensorHandlePtr handle); + Asset(const Asset&) = delete; + Asset& operator=(const Asset&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_ASSET_H_ diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.cc b/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.cc new file mode 100644 index 00000000000..59f7306fedc --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.cc @@ -0,0 +1,91 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h" + +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" +#include "tensorflow/core/protobuf/struct.pb.h" + +namespace tensorflow { + +FlatTensorFunction::FlatTensorFunction( + const std::string& name, std::vector captures, + ImmediateExecutionContext* ctx) + : name_(name), captures_(std::move(captures)), ctx_(ctx) {} + +FlatTensorFunction::~FlatTensorFunction() { + Status status = ctx_->RemoveFunction(name_); + if (!status.ok()) { + LOG(ERROR) << "Failed to remove functiondef " << name_ << ". 
" + << status.error_message(); + } +} + +Status FlatTensorFunction::Create( + const FunctionDef* function_def, + std::vector captures, + ImmediateExecutionContext* ctx, std::unique_ptr* out) { + TF_RETURN_IF_ERROR(ctx->AddFunctionDef(*function_def)); + std::vector owned_captures; + owned_captures.reserve(captures.size()); + for (ImmediateExecutionTensorHandle* capture : captures) { + capture->Ref(); + owned_captures.push_back(ImmediateTensorHandlePtr(capture)); + } + + out->reset(new FlatTensorFunction(function_def->signature().name(), + std::move(owned_captures), ctx)); + return Status(); +} + +Status FlatTensorFunction::MakeCallOp( + absl::Span inputs, ImmediateOpPtr* out) const { + out->reset(ctx_->CreateOperation()); + // In eager mode, TF2 python executes functions by constructing an op with + // the name of the functiondef: + // https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L545 + // In graph mode, we create a PartitionedCallOp instead: + // https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L573 + + // TODO(bmzhao): After discussing with Allen, we should execute this via a + // PartitionedCallOp for compatibility with "tooling that assumes functions in + // graphs are PartitionedCallOps". + TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr)); + + // Adding the user-provided inputs to the function. + TF_RETURN_IF_ERROR((*out)->AddInputList(inputs)); + + absl::Span captures( + reinterpret_cast(captures_.data()), + captures_.size()); + + // Adding the captures of the function. + TF_RETURN_IF_ERROR((*out)->AddInputList(captures)); + return Status(); +} + +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h b/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h new file mode 100644 index 00000000000..a6769d323b4 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h @@ -0,0 +1,86 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_FLAT_TENSOR_FUNCTION_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_FLAT_TENSOR_FUNCTION_H_ + +#include +#include +#include +#include + +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" + +namespace tensorflow { + +// FlatTensorFunction models a TF2 eager runtime view of a callable function, +// taking + returning flat lists of tensors, including any captures. 
+// Effectively, it is a thin wrapper around a FunctionDef owned by the +// EagerContext, and any TensorHandle captures associated with the function. The +// MakeCallOp method handles the logic of marshaling captures after the user +// provided inputs automatically. +// Note(bmzhao): This class is mainly intended to house low-level reusable +// function logic between SignatureDefFunction and ConcreteFunction, which +// present higher level interfaces. This type does *not* hold any "function +// metadata". +class FlatTensorFunction { + public: + // Factory for creating a FlatTensorFunction. + // + // Params: + // function_def - The function_def associated with the created + // FlatTensorFunction. FlatTensorFunction will register this + // function_def with `ctx` on creation, and de-register it on + // destruction. function_def must be non-null, but + // otherwise has no lifetime requirements. + // captures - The captured TensorHandles associated with this + // FlatTensorFunction. FlatTensorFunction will participate in + // ownership of the handles (it explicitly increments the refcount + // of each handle, and will decrement them on destruction). + // ctx - A handle to the Tensorflow runtime. This MUST be non-null and + // outlive TFConcreteFunction. + // out - The output FlatTensorFunction. + static Status Create(const FunctionDef* function_def, + std::vector captures, + ImmediateExecutionContext* ctx, + std::unique_ptr* out); + + // This method creates a "Call" Op used to execute the function. + Status MakeCallOp(absl::Span inputs, + ImmediateOpPtr* out) const; + + ~FlatTensorFunction(); + + private: + FlatTensorFunction(const std::string& name, + std::vector captures, + ImmediateExecutionContext* ctx); + + FlatTensorFunction(const FlatTensorFunction&) = delete; + FlatTensorFunction& operator=(const FlatTensorFunction&) = delete; + + // Name of the FunctionDef corresponding to this TFConcreteFunction + std::string name_; + std::vector captures_; + ImmediateExecutionContext* ctx_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_FLAT_TENSOR_FUNCTION_H_ diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.cc b/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.cc new file mode 100644 index 00000000000..1c615405644 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.cc @@ -0,0 +1,543 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h" + +#include +#include +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource_revival_state.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h" +#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" +#include "tensorflow/core/protobuf/struct.pb.h" + +namespace tensorflow { + +namespace { + +using StructuredValueDictEntry = + protobuf::MapPair; + +using NamedParamMap = + gtl::FlatMap; + +Status AssertAllCreateResourceFunctionsHaveNoCaptures( + const PartiallyRevivedObjects& objects) { + for (const auto& id_and_resource : objects.restored_resources) { + int node_id = id_and_resource.first; + const RestoredResourceRevivalState& resource = id_and_resource.second; + const TFConcreteFunctionRevivalState* create_resource_fn = + resource.create_resource; + if (create_resource_fn == nullptr) { + return errors::FailedPrecondition( + "Resource at node ", node_id, + " did not have a create_resource() function"); + } + const SavedConcreteFunction* saved_create_resource_fn = + create_resource_fn->saved_concrete_func; + if (!saved_create_resource_fn->bound_inputs().empty()) { + // TODO(b/124045874): Support loading resource functions via a top sort + return errors::Unimplemented( + "Create Resource functions with captures are currently unsupported."); + } + } + return Status(); +} + +// Retrieves the TensorHandle associated with `node_id` from `obj_graph`, and +// set `*handle` to point to it. 
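// For example (illustrative; node id 3 is hypothetical), if node 3 of the
// SavedObjectGraph is a variable that has already been revived into
// `objects.variables`, the call below points `capture` at that variable's
// backing tensor handle; constants, assets, and restored resources are
// resolved analogously from their respective maps:
//
//   ImmediateExecutionTensorHandle* capture = nullptr;
//   TF_RETURN_IF_ERROR(
//       TensorHandleFromNode(/*node_id=*/3, obj_graph, objects, &capture));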
+Status TensorHandleFromNode(int node_id, const SavedObjectGraph& obj_graph, + const PartiallyRevivedObjects& objects, + ImmediateExecutionTensorHandle** handle) { + const SavedObject& node = obj_graph.nodes(node_id); + SavedObject::KindCase kind = node.kind_case(); + switch (kind) { + case SavedObject::kVariable: { + const auto& variables_iter = objects.variables.find(node_id); + if (variables_iter == objects.variables.end()) { + return errors::FailedPrecondition( + "Tried to convert node id ", node_id, + " of type variable to tensor but the variable wasn't initialized"); + } + *handle = variables_iter->second->handle(); + return Status(); + } + case SavedObject::kConstant: { + const auto& constants_iter = objects.constants.find(node_id); + if (constants_iter == objects.constants.end()) { + return errors::FailedPrecondition("Tried to convert node id ", node_id, + " of type constant to tensor but the " + "constant wasn't initialized"); + } + *handle = constants_iter->second->handle(); + return Status(); + } + case SavedObject::kAsset: { + const auto& assets_iter = objects.assets.find(node_id); + if (assets_iter == objects.assets.end()) { + return errors::FailedPrecondition( + "Tried to convert node id ", node_id, + " of type asset to tensor but the asset wasn't initialized"); + } + *handle = assets_iter->second->handle(); + return Status(); + } + case SavedObject::kResource: { + const auto& resource_iter = objects.restored_resources.find(node_id); + if (resource_iter == objects.restored_resources.end()) { + return errors::FailedPrecondition( + "Tried to convert node id ", node_id, + " of type Resource to tensor but the Resource wasn't initialized"); + } + const RestoredResourceRevivalState& resource = resource_iter->second; + if (resource.resource_handle == nullptr) { + return errors::FailedPrecondition( + "Resource with node id ", node_id, + " should have its resource_handle created, but was nullptr."); + } + *handle = resource.resource_handle.get(); + return Status(); + } + default: { + return errors::FailedPrecondition( + "Only objects of type variable, constant, asset, and resources have " + "capturable tensorhandles. Encountered object of kind ", + node.kind_case(), " at node id: ", node_id); + } + } +} + +std::vector SignatureDefParamsFromNamedParamMap( + const NamedParamMap& params) { + // The underlying functiondef associated with the SignatureDef has + // nest.flattened inputs and outputs, which are sorted by string key. + std::vector result; + result.reserve(params.size()); + for (const auto& named_param : params) { + result.push_back(SignatureDefParam(std::string(named_param.first), + TensorSpec(*named_param.second))); + } + std::sort(result.begin(), result.end(), + [](const SignatureDefParam& x, const SignatureDefParam& y) { + return x.name() < y.name(); + }); + + return result; +} + +// SignatureDefArgsFromInputs takes the "canonicalized_input_signature" +// field of a SavedConcreteFunction, ensures it conforms to the structure of +// tuple(tuple(), dict()), and "returns" a list of +// SignatureDefParams of the SignatureDefFunction's arguments. +Status SignatureDefArgsFromInputs( + const StructuredValue& canonicalized_input_signature, + std::vector* out) { + // Note(bmzhao): canonicalized_input_signature should be a tuple of + // (args, kwargs), where args is an empty tuple, and kwargs is a dictionary of + // string keys to TensorSpecs. 
+ if (!canonicalized_input_signature.has_tuple_value()) { + return errors::FailedPrecondition( + "SignatureDefFunction's canonicalized_input_signature should be " + "of form tuple(tuple(), dict()), but was instead: \n", + canonicalized_input_signature.DebugString()); + } + + const TupleValue& args_kwargs_tuple = + canonicalized_input_signature.tuple_value(); + if (args_kwargs_tuple.values_size() != 2) { + return errors::FailedPrecondition( + "SignatureDefFunction's canonicalized_input_signature should be " + "a tuple of two elements (args, kwargs), but was instead: \n", + args_kwargs_tuple.DebugString()); + } + + const StructuredValue& args = args_kwargs_tuple.values(0); + if (!args.has_tuple_value() || !args.tuple_value().values().empty()) { + return errors::FailedPrecondition( + "SignatureDefFunction's canonicalized_input_signature's args" + "should be an empty tuple, but instead got: \n", + args.DebugString()); + } + + const StructuredValue& kwargs = args_kwargs_tuple.values(1); + if (!kwargs.has_dict_value()) { + return errors::FailedPrecondition( + "SignatureDefFunction's canonicalized_input_signature's kwargs" + "should be a dictionary, but instead got: \n", + kwargs.DebugString()); + } + + const DictValue& kwargs_dict = kwargs.dict_value(); + NamedParamMap result; + result.reserve(kwargs_dict.fields_size()); + + for (const auto& key_value : kwargs_dict.fields()) { + const std::string& key = key_value.first; + const StructuredValue& value = key_value.second; + if (!value.has_tensor_spec_value()) { + return errors::FailedPrecondition( + "SignatureDefFunction's canonicalized_input_signature's kwargs" + "dictionary contained a non-tensorspec value for key-value pair: \n", + "Key: ", key, "Value: \n", value.DebugString()); + } + result[key] = &value.tensor_spec_value(); + } + + *out = SignatureDefParamsFromNamedParamMap(result); + + return Status(); +} + +// SignatureDefReturnsFromOutputs takes the "output_signature" field of a +// SavedConcreteFunction, ensures it conforms to the structure of +// dict(), and "returns" a list of SignatureDefParams of the +// SignatureDefFunction's returns. 
+Status SignatureDefReturnsFromOutputs(const StructuredValue& output_signature, + std::vector* out) { + if (!output_signature.has_dict_value()) { + return errors::FailedPrecondition( + "SignatureDefFunction's output_signature must be a dictionary, but " + "instead got: ", + output_signature.DebugString()); + } + + const DictValue& output_dict = output_signature.dict_value(); + NamedParamMap result; + result.reserve(output_dict.fields_size()); + + for (const auto& key_value : output_dict.fields()) { + const std::string& key = key_value.first; + const StructuredValue& value = key_value.second; + if (!value.has_tensor_spec_value()) { + return errors::FailedPrecondition( + "SignatureDefFunction's output_signature dictionary contained a " + "non-tensorspec value for key-value pair: \n", + "Key: ", key, "Value: \n", value.DebugString()); + } + result[key] = &value.tensor_spec_value(); + } + *out = SignatureDefParamsFromNamedParamMap(result); + + return Status(); +} + +// The implementation takes advantage of the fact that SignatureDefFunction's +// "traced" Signature wrapper function always has inputs/outputs of dictionaries +// https://github.com/tensorflow/tensorflow/blob/53cdd5e87c423b195f33775753273286fd5a1a65/tensorflow/python/saved_model/signature_serialization.py#L119-L126 +// https://github.com/tensorflow/tensorflow/blob/53cdd5e87c423b195f33775753273286fd5a1a65/tensorflow/python/saved_model/signature_serialization.py#L153-L178 +// Additionally, we take advantage of the fact that the SignatureDefFunction's +// associated functiondef has lexicographically ordered inputs/outputs due to +// nest.flatten. +Status LoadSignatureDefFunctionMetadata( + const SavedConcreteFunction& saved_concrete_function, + SignatureDefFunctionMetadata* out) { + std::vector args; + TF_RETURN_IF_ERROR(SignatureDefArgsFromInputs( + saved_concrete_function.canonicalized_input_signature(), &args)); + + std::vector rets; + TF_RETURN_IF_ERROR(SignatureDefReturnsFromOutputs( + saved_concrete_function.output_signature(), &rets)); + + *out = SignatureDefFunctionMetadata(std::move(args), std::move(rets)); + return Status(); +} + +// This function finds the necessary captures, then forwards to the builder +// method +Status CreateConcreteFunction(ImmediateExecutionContext* ctx, + const TFConcreteFunctionRevivalState& builder, + const SavedObjectGraph& obj_graph, + const PartiallyRevivedObjects& objects, + std::unique_ptr* out) { + const auto& capture_node_ids = builder.saved_concrete_func->bound_inputs(); + std::vector captures; + captures.reserve(capture_node_ids.size()); + for (int capture_node_id : capture_node_ids) { + ImmediateExecutionTensorHandle* capture_handle; + TF_RETURN_IF_ERROR(TensorHandleFromNode(capture_node_id, obj_graph, objects, + &capture_handle)); + captures.push_back(capture_handle); + } + // TODO(bmzhao): Create Metadata here + return TFConcreteFunction::Create(/*function_def=*/builder.fdef, + /*captures=*/std::move(captures), + /*metadata=*/{}, + /*ctx=*/ctx, + /*out=*/out); +} + +Status CreateSignatureDefFunction( + ImmediateExecutionContext* ctx, + const TFSignatureDefFunctionRevivalState& builder, + const SavedObjectGraph& obj_graph, const PartiallyRevivedObjects& objects, + std::unique_ptr* out) { + const auto& capture_node_ids = builder.saved_concrete_func->bound_inputs(); + std::vector captures; + captures.reserve(capture_node_ids.size()); + for (int capture_node_id : capture_node_ids) { + ImmediateExecutionTensorHandle* capture_handle; + TF_RETURN_IF_ERROR(TensorHandleFromNode(capture_node_id, 
obj_graph, objects, + &capture_handle)); + captures.push_back(capture_handle); + } + + SignatureDefFunctionMetadata metadata; + TF_RETURN_IF_ERROR(LoadSignatureDefFunctionMetadata( + *builder.saved_concrete_func, &metadata)); + + return TFSignatureDefFunction::Create(/*function_def=*/builder.fdef, + /*captures=*/std::move(captures), + /*metadata=*/std::move(metadata), + /*ctx=*/ctx, + /*out=*/out); +} + +Status InitializeCreateResourceFunctions(ImmediateExecutionContext* ctx, + const SavedObjectGraph& obj_graph, + const PartiallyRevivedObjects& objects, + RevivedObjects* revived) { + for (const auto& id_and_resource : objects.restored_resources) { + const RestoredResourceRevivalState& resource = id_and_resource.second; + const TFConcreteFunctionRevivalState* create_resource_fn = + resource.create_resource; + + const SavedConcreteFunction* saved_create_resource_fn = + create_resource_fn->saved_concrete_func; + if (!saved_create_resource_fn->bound_inputs().empty()) { + // TODO(b/124045874): Load resource functions via a topological sort + return errors::Unimplemented( + "Create Resource functions with captures are currently unsupported."); + } + std::unique_ptr out; + TF_RETURN_IF_ERROR(CreateConcreteFunction(ctx, *create_resource_fn, + obj_graph, objects, &out)); + revived->concrete_functions[create_resource_fn->node_id] = std::move(out); + } + return Status(); +} + +Status InitializeAllFunctions(ImmediateExecutionContext* ctx, + const SavedObjectGraph& obj_graph, + const PartiallyRevivedObjects& objects, + RevivedObjects* revived) { + gtl::FlatMap>* destination_func_map = + &revived->concrete_functions; + gtl::FlatMap>* + destination_sig_map = &revived->signature_def_functions; + + for (const auto& id_and_func : objects.concrete_functions) { + int node_id = id_and_func.first; + const TFConcreteFunctionRevivalState& func = id_and_func.second; + + if (destination_func_map->find(node_id) != destination_func_map->end()) { + // The function has already been initialized in the destination_map, + // so we can skip this node. This can occur because we initialize + // CreateResource functions before calling this function. 
+ continue; + } + + std::unique_ptr out; + TF_RETURN_IF_ERROR( + CreateConcreteFunction(ctx, func, obj_graph, objects, &out)); + (*destination_func_map)[node_id] = std::move(out); + } + + for (const auto& id_and_func : objects.signature_def_functions) { + int node_id = id_and_func.first; + const TFSignatureDefFunctionRevivalState& func = id_and_func.second; + + if (destination_sig_map->find(node_id) != destination_sig_map->end()) { + continue; + } + + std::unique_ptr out; + TF_RETURN_IF_ERROR( + CreateSignatureDefFunction(ctx, func, obj_graph, objects, &out)); + (*destination_sig_map)[node_id] = std::move(out); + } + + return Status(); +} + +Status CreateAllResourceHandles(ImmediateExecutionContext* ctx, + const SavedObjectGraph& obj_graph, + PartiallyRevivedObjects* objects, + RevivedObjects* revived) { + for (auto& id_and_resource : objects->restored_resources) { + RestoredResourceRevivalState& resource = id_and_resource.second; + int create_resource_fn_node = resource.create_resource->node_id; + const gtl::FlatMap>& + revived_functions = revived->concrete_functions; + + const auto& revived_functions_iter = + revived_functions.find(create_resource_fn_node); + if (revived_functions_iter == revived_functions.end()) { + return errors::FailedPrecondition( + "ConcreteFunction at node ", create_resource_fn_node, + " should have been initialized prior to being called."); + } + const TFConcreteFunction& create_resource_fn = + *revived_functions_iter->second; + ImmediateOpPtr function_op; + TF_RETURN_IF_ERROR(create_resource_fn.MakeCallOp({}, &function_op)); + TF_RETURN_IF_ERROR(function_op->SetDeviceName(resource.device.c_str())); + + AbstractTensorHandle* resource_handle = nullptr; + int num_retvals = 1; + TF_RETURN_IF_ERROR(function_op->Execute( + absl::MakeSpan(&resource_handle, num_retvals), &num_retvals)); + AbstractTensorHandlePtr owned_resource_handle(resource_handle); + if (!tensorflow::isa( + owned_resource_handle.get())) { + return errors::Internal("Unexpected tensor handle kind."); + } + ImmediateTensorHandlePtr result( + reinterpret_cast( + owned_resource_handle.release())); + resource.resource_handle = std::move(result); + } + return Status(); +} + +// Finds a ConcreteFunction with node id `node` in `objects`, and sets *out to +// point to it. If node doesn't exist in `objects`, out is untouched, and an +// error status is returned. 
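// Typical use (as in BuildResources further below), where `node_id` is the
// node being resolved:
//
//   TFConcreteFunction* init_fn = nullptr;
//   TF_RETURN_IF_ERROR(FindConcreteFunction(node_id, revived, &init_fn));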
+Status FindConcreteFunction(int node, RevivedObjects* objects, + TFConcreteFunction** out) { + auto func_iter = objects->concrete_functions.find(node); + if (func_iter == objects->concrete_functions.end()) { + return errors::FailedPrecondition( + "Failed to find ConcreteFunction with node id ", node, + " in revived objects"); + } + *out = func_iter->second.get(); + return Status(); +} + +Status BuildResources(ImmediateExecutionContext* ctx, + const SavedObjectGraph& obj_graph, + PartiallyRevivedObjects* objects, + RevivedObjects* revived) { + for (auto& id_and_resource : objects->restored_resources) { + int node_id = id_and_resource.first; + RestoredResourceRevivalState& resource_revival_state = + id_and_resource.second; + + TFConcreteFunction* create_resource = nullptr; + + // Check all the functions associated with the resource have already been + // initialized in `revived` + if (resource_revival_state.create_resource != nullptr) { + TF_RETURN_IF_ERROR( + FindConcreteFunction(resource_revival_state.create_resource->node_id, + revived, &create_resource)); + } + + TFConcreteFunction* initialize = nullptr; + if (resource_revival_state.initialize != nullptr) { + TF_RETURN_IF_ERROR(FindConcreteFunction( + resource_revival_state.initialize->node_id, revived, &initialize)); + } + + TFConcreteFunction* destroy_resource = nullptr; + if (resource_revival_state.destroy_resource != nullptr) { + TF_RETURN_IF_ERROR( + FindConcreteFunction(resource_revival_state.destroy_resource->node_id, + revived, &destroy_resource)); + } + + if (resource_revival_state.resource_handle == nullptr) { + return errors::FailedPrecondition("Resource at node id ", node_id, + " does not have a resource handle."); + } + + revived->restored_resources.emplace( + node_id, RestoredResource( + /*device=*/resource_revival_state.device, + /*create_resource=*/create_resource, + /*initialize=*/initialize, + /*destroy_resource=*/destroy_resource, + /*resource_handle=*/ + std::move(resource_revival_state.resource_handle))); + } + return Status(); +} + +} // namespace + +Status PartiallyRevivedObjects::Build(ImmediateExecutionContext* ctx, + const SavedObjectGraph& obj_graph, + RevivedObjects* revived) { + // Step 1: We would like to initialize all functions; this requires setting up + // their captured tensorhandles, which may come from variables, assets, + // constants, or resources. The first three are trivial; However, + // tensorhandles that correspond to resources must be created by invoking + // their "create_resource" function. + // https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/saved_model/load.py#L240 + // https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/training/tracking/tracking.py#L233 + // For now, we assert that all create_resource functions must have no + // captures. This aligns with the current behavior in python. + // https://github.com/tensorflow/tensorflow/blob/50eac986bf7a0ad12594e080f083181f277e0b49/tensorflow/python/saved_model/load.py#L152-L155 + // TODO(bmzhao): We should do a topological sort instead. + + // 1a. Make sure all CreateResource functions have no captures. + TF_RETURN_IF_ERROR(AssertAllCreateResourceFunctionsHaveNoCaptures(*this)); + + // 1b. Initialize all CreateResource functions, storing them in `revived` + TF_RETURN_IF_ERROR( + InitializeCreateResourceFunctions(ctx, obj_graph, *this, revived)); + + // 1c. 
Invoke all "CreateResource" functions and store their ResourceHandles + // https://github.com/tensorflow/tensorflow/blob/3b6b41b68a95dc70c26dc816b29d359bfb88c116/tensorflow/python/training/tracking/tracking.py#L241-L247 + // in *this->resources. + // TODO(bmzhao): Maybe store them separately, not in *this? + TF_RETURN_IF_ERROR(CreateAllResourceHandles(ctx, obj_graph, this, revived)); + + // 2. Initialize all the rest of the functions + TF_RETURN_IF_ERROR(InitializeAllFunctions(ctx, obj_graph, *this, revived)); + + // 3a. Move over all non-function, non-resource objects + revived->variables = std::move(variables); + revived->assets = std::move(assets); + revived->constants = std::move(constants); + revived->signatures_map = std::move(signatures_map); + + // 3b. Move over resources. + TF_RETURN_IF_ERROR(BuildResources(ctx, obj_graph, this, revived)); + + return Status(); +} + +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h b/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h new file mode 100644 index 00000000000..78960b8c95f --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h @@ -0,0 +1,62 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_PARTIALLY_REVIVED_OBJECTS_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_PARTIALLY_REVIVED_OBJECTS_H_ + +#include + +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource_revival_state.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" + +namespace tensorflow { + +// Container for objects during the revival step in SavedModel's loading. +// Notably, resources and functions can be in a state where they reference +// other resources/functions that have not been constructed yet. We collect +// *all* objects in a partially valid state here, then properly initialize +// resources and functions. Implementation-wise, PartiallyRevivedObjects +// contains maps keyed by the node number of the SavedObjectGraph, and map to an +// object of the corresponding type. 
So, if node 2 in the object graph is a +// variable, PartiallyRevivedObjects.variables[2] exists, and corresponds to a +// tensorflow::Variable object. The only exception to this is the +// "signatures_map", which is keyed by the "signature" key +// (https://github.com/tensorflow/tensorflow/blob/372918decee7f558b3c194b04f77c20dcc679a31/tensorflow/core/protobuf/meta_graph.proto#L89), +// and maps to the SignatureDefFunction node in the SavedObjectGraph. +struct PartiallyRevivedObjects { + gtl::FlatMap> variables; + gtl::FlatMap> assets; + gtl::FlatMap> constants; + gtl::FlatMap concrete_functions; + gtl::FlatMap signature_def_functions; + gtl::FlatMap restored_resources; + gtl::FlatMap signatures_map; + + Status Build(ImmediateExecutionContext* ctx, + const SavedObjectGraph& obj_graph, RevivedObjects* revived); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_PARTIALLY_REVIVED_OBJECTS_H_ diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.cc b/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.cc new file mode 100644 index 00000000000..47860ce8b39 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.cc @@ -0,0 +1,76 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h" + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +namespace { + +Status ExecuteNoArgDummyReturnFunction(TFConcreteFunction* func) { + ImmediateOpPtr function_op; + TF_RETURN_IF_ERROR(func->MakeCallOp({}, &function_op)); + + AbstractTensorHandle* dummy_output = nullptr; + int num_retvals = 1; + TF_RETURN_IF_ERROR(function_op->Execute( + absl::MakeSpan(&dummy_output, num_retvals), &num_retvals)); + AbstractTensorHandlePtr owned_dummy_output(dummy_output); + return Status(); +} + +} // namespace + +RestoredResource::RestoredResource(const std::string& device, + TFConcreteFunction* create_resource, + TFConcreteFunction* initialize, + TFConcreteFunction* destroy_resource, + ImmediateTensorHandlePtr resource_handle) + : TensorHandleConvertible(std::move(resource_handle)), + device_(device), + create_resource_(create_resource), + initialize_(initialize), + destroy_resource_(destroy_resource) {} + +Status RestoredResource::Initialize() const { + return ExecuteNoArgDummyReturnFunction(initialize_); +} + +RestoredResource::~RestoredResource() { + // Note(bmzhao): SavedModels saved before + // https://github.com/tensorflow/tensorflow/commit/3c806101f57768e479f8646e7518bbdff1632ca3 + // did not have their destroy_resource function saved, meaning they will + // leak resources. + if (destroy_resource_ != nullptr) { + Status status = ExecuteNoArgDummyReturnFunction(destroy_resource_); + if (!status.ok()) { + LOG(WARNING) + << "Failed executing destroy_resource function for RestoredResource: " + << status.error_message(); + } + } +} + +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h b/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h new file mode 100644 index 00000000000..7adbd563a6b --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h @@ -0,0 +1,87 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_H_ + +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" + +namespace tensorflow { + +// RestoredResource represents a TF2 "Resource" object loaded from a savedmodel, +// analogous to the Python _RestoredResource object: +// https://github.com/tensorflow/tensorflow/blob/fda326e542ca67534e8411edb180e8760a4828b7/tensorflow/python/saved_model/load.py#L481 +// TF2 resource objects typically extend TrackableResource: +// https://github.com/tensorflow/tensorflow/blob/fda326e542ca67534e8411edb180e8760a4828b7/tensorflow/python/training/tracking/tracking.py#L285 +// and are expected to implement "_create_resource", "_initialize", and +// "_destroy_resource" functions: +// https://github.com/tensorflow/tensorflow/blob/139ba9c5284799beafdd1d7f895127cf00e7c48f/tensorflow/python/training/tracking/tracking.py#L262-L281 +class RestoredResource : TensorHandleConvertible { + public: + // Note(bmzhao): RestoredResource stores non-owning pointers to its associated + // functions because SavedModel internally owns all functions and objects in + // the RevivedObjects struct (which owns all functions). One alternative would + // be to have RevivedObjects store shared_ptr instead, and + // change RestoredResource's constructor take shared_ptr. + // To keep things simple, I've stuck to raw pointers for now. + // + // Params: + // device - The device string associated with the SavedResource + // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/saved_object_graph.proto#L182 + // Conceptually, this is the same device used in CapturableResource: + // https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/python/training/tracking/tracking.py#L222-L225 + // Implementation-wise, it is device used when invoking the + // create_resource function to produce the resource_handle + // associated with the object: + // https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/python/training/tracking/tracking.py#L246-L247 + // create_resource - Non owning pointer to the create_resource function + // associated with this object. Must be NON-NULL. + // initialize - Non owning pointer to the initialize function associated with + // this object. Must be NON-NULL. + // destroy_resource - Non owning pointer to the destroy_resource function + // associated with this object. Ideally this should be + // NON-NULL, but in order to support models saved prior to + // https://github.com/tensorflow/tensorflow/commit/3c806101f57768e479f8646e7518bbdff1632ca3 + // we allow null here. This will, however, leak resources. + RestoredResource(const std::string& device, + TFConcreteFunction* create_resource, + TFConcreteFunction* initialize, + TFConcreteFunction* destroy_resource, + ImmediateTensorHandlePtr resource_handle); + + Status Initialize() const; + + // RestoredResource is movable, but not copyable. 
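  // Illustrative lifecycle sketch (not part of this change): `create_fn`,
  // `init_fn`, `destroy_fn`, and `handle` are assumed to come from
  // RevivedObjects and from executing the create_resource function, as
  // BuildResources does in partially_revived_objects.cc:
  //
  //   RestoredResource resource(/*device=*/"", create_fn, init_fn, destroy_fn,
  //                             std::move(handle));
  //   TF_RETURN_IF_ERROR(resource.Initialize());  // runs the saved _initialize
  //   // On destruction, destroy_resource (when present) is executed; failures
  //   // there are only logged, not returned (see restored_resource.cc).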
+ RestoredResource(RestoredResource&& other) = default; + RestoredResource& operator=(RestoredResource&& other) = default; + + ~RestoredResource() override; + + private: + std::string device_; + TFConcreteFunction* create_resource_; + TFConcreteFunction* initialize_; + TFConcreteFunction* destroy_resource_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_H_ diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource_revival_state.h b/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource_revival_state.h new file mode 100644 index 00000000000..48d00308cc1 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource_revival_state.h @@ -0,0 +1,38 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_REVIVAL_STATE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_REVIVAL_STATE_H_ + +#include + +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h" + +namespace tensorflow { + +// All "Resources" should have these 3 saved functions: +// https://github.com/tensorflow/tensorflow/blob/86dc281333d7d277ddc1882f2bca4b17e7ec40e5/tensorflow/python/training/tracking/tracking.py#L277-L281 +struct RestoredResourceRevivalState { + std::string device; + TFConcreteFunctionRevivalState* create_resource = nullptr; + TFConcreteFunctionRevivalState* initialize = nullptr; + TFConcreteFunctionRevivalState* destroy_resource = nullptr; + ImmediateTensorHandlePtr resource_handle = nullptr; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_REVIVAL_STATE_H_ diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h b/tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h new file mode 100644 index 00000000000..cc9be0b937d --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h @@ -0,0 +1,52 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_REVIVED_OBJECTS_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_REVIVED_OBJECTS_H_ + +#include +#include + +#include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h" +#include "tensorflow/core/lib/gtl/flatmap.h" + +namespace tensorflow { + +// RevivedObjects is mainly used as a container for all the "state" owned by +// SavedModel. It stores all non-"user object" nodes from a SavedModel +// (https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/core/protobuf/saved_object_graph.proto#L57-L62) +// in a "fully constructed" state. It is effectively a strongly typed map, where +// each member is a map from the node id in the SavedObjectGraph's nodes +// (https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/core/protobuf/saved_object_graph.proto#L25-L29) +// to the revived object of the corresponding type. +struct RevivedObjects { + gtl::FlatMap> variables; + gtl::FlatMap> assets; + gtl::FlatMap> constants; + gtl::FlatMap> concrete_functions; + gtl::FlatMap> + signature_def_functions; + gtl::FlatMap restored_resources; + gtl::FlatMap signatures_map; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_REVIVED_OBJECTS_H_ diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc index f734f9eca66..d9773a4520f 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc +++ b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/immediate_execution_operation.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" -#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/platform/errors.h" @@ -33,32 +33,20 @@ limitations under the License. namespace tensorflow { -TFConcreteFunction::TFConcreteFunction( - const std::string& name, - std::vector captures, - FunctionMetadata metadata, ImmediateExecutionContext* ctx) - : name_(name), - captures_(std::move(captures)), - metadata_(std::move(metadata)), - ctx_(ctx) {} - -TFConcreteFunction::~TFConcreteFunction() { - Status status = ctx_->RemoveFunction(name_); - if (!status.ok()) { - LOG(ERROR) << "Failed to remove functiondef " << name_ << ". 
" - << status.error_message(); - } -} +TFConcreteFunction::TFConcreteFunction(std::unique_ptr func, + FunctionMetadata metadata) + : func_(std::move(func)), metadata_(std::move(metadata)) {} Status TFConcreteFunction::Create( const FunctionDef* function_def, std::vector captures, FunctionMetadata metadata, ImmediateExecutionContext* ctx, std::unique_ptr* out) { - TF_RETURN_IF_ERROR(ctx->AddFunctionDef(*function_def)); - out->reset(new TFConcreteFunction(function_def->signature().name(), - std::move(captures), std::move(metadata), - ctx)); + std::unique_ptr func; + TF_RETURN_IF_ERROR(FlatTensorFunction::Create( + function_def, std::move(captures), ctx, &func)); + + out->reset(new TFConcreteFunction(std::move(func), std::move(metadata))); return Status(); } @@ -66,30 +54,9 @@ const FunctionMetadata& TFConcreteFunction::GetFunctionMetadata() const { return metadata_; } -Status TFConcreteFunction::GetCallOp( - absl::Span inputs, ImmediateOpPtr* out) { - out->reset(ctx_->CreateOperation()); - // In eager mode, TF2 python executes functions by constructing an op with - // the name of the functiondef: - // https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L545 - // In graph mode, we create a PartitionedCallOp instead: - // https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L573 - - // TODO(bmzhao): After discussing with Allen, we should execute this via a - // PartitionedCallOp for compatibility with "tooling that assumes functions in - // graphs are PartitionedCallOps". - TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr)); - - // Adding the user-provided inputs to the function. - TF_RETURN_IF_ERROR((*out)->AddInputList(inputs)); - - absl::Span captures( - reinterpret_cast(captures_.data()), - captures_.size()); - - // Adding the captures of the function. - TF_RETURN_IF_ERROR((*out)->AddInputList(captures)); - return Status(); +Status TFConcreteFunction::MakeCallOp( + absl::Span inputs, ImmediateOpPtr* out) const { + return func_->MakeCallOp(inputs, out); } } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h index d38f3546f91..edc26f4d5aa 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h +++ b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h @@ -27,7 +27,7 @@ limitations under the License. #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/experimental/saved_model/core/concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/function_metadata.h" -#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/protobuf/saved_object_graph.pb.h" @@ -58,26 +58,22 @@ class TFConcreteFunction : public ConcreteFunction { std::unique_ptr* out); // This method returns the "Call" Op used to execute the function. 
- Status GetCallOp(absl::Span inputs, - ImmediateOpPtr* out) override; + Status MakeCallOp(absl::Span inputs, + ImmediateOpPtr* out) const override; const FunctionMetadata& GetFunctionMetadata() const override; - ~TFConcreteFunction() override; + ~TFConcreteFunction() override = default; private: - TFConcreteFunction(const std::string& name, - std::vector captures, - FunctionMetadata metadata, ImmediateExecutionContext* ctx); + TFConcreteFunction(std::unique_ptr func, + FunctionMetadata metadata); TFConcreteFunction(const TFConcreteFunction&) = delete; TFConcreteFunction& operator=(const TFConcreteFunction&) = delete; - // Name of the FunctionDef corresponding to this TFConcreteFunction - std::string name_; - std::vector captures_; + std::unique_ptr func_; FunctionMetadata metadata_; - ImmediateExecutionContext* ctx_; }; } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h new file mode 100644 index 00000000000..3dd7a6eecc4 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h @@ -0,0 +1,61 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_CONCRETE_FUNCTION_REVIVAL_STATE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_CONCRETE_FUNCTION_REVIVAL_STATE_H_ + +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" + +namespace tensorflow { + +// TFConcreteFunctionRevivalState wraps the state needed for building a +// TF_ConcreteFunction. This is mainly used in PartiallyRevivedObjects, which +// wraps partially constructed Function and Resource objects. +struct TFConcreteFunctionRevivalState { + // Index of the node in the SavedObjectGraph it was loaded from. + int node_id; + + // Pointer to the original functiondef. fdef_ is guaranteed to be + // non-null. + const FunctionDef* fdef; + + // TensorHandle captures for this funtion + std::vector captures; + + // SavedConcreteFunction contains much of the metadata of the expected "types" + // of the inputs and outputs of a function. + // Note(bmzhao): saved_concrete_func_ is guaranteed to be non-null. + const SavedConcreteFunction* saved_concrete_func; + + // This field is only present on TF2 ConcreteFunctions, and is useful for + // determining the original argument *names* of the function, (since the + // "canonicalized_input_signature" may append extra uniquifying integers). 
+ // However, SavedBareConcreteFunctions do not have a FunctionSpec. + // Note(bmzhao): if function_spec_.has_value(), *function_spec_ is guaranteed + // to be non-null. + absl::optional function_spec; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_CONCRETE_FUNCTION_REVIVAL_STATE_H_ diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.cc b/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.cc new file mode 100644 index 00000000000..ab1745dcd47 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.cc @@ -0,0 +1,64 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h" + +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" +#include "tensorflow/core/protobuf/struct.pb.h" + +namespace tensorflow { + +TFSignatureDefFunction::TFSignatureDefFunction( + std::unique_ptr func, + SignatureDefFunctionMetadata metadata) + : func_(std::move(func)), metadata_(std::move(metadata)) {} + +Status TFSignatureDefFunction::Create( + const FunctionDef* function_def, + std::vector captures, + SignatureDefFunctionMetadata metadata, ImmediateExecutionContext* ctx, + std::unique_ptr* out) { + std::unique_ptr func; + TF_RETURN_IF_ERROR(FlatTensorFunction::Create( + function_def, std::move(captures), ctx, &func)); + + out->reset(new TFSignatureDefFunction(std::move(func), std::move(metadata))); + return Status(); +} + +const SignatureDefFunctionMetadata& +TFSignatureDefFunction::GetFunctionMetadata() const { + return metadata_; +} + +Status TFSignatureDefFunction::MakeCallOp( + absl::Span inputs, ImmediateOpPtr* out) const { + return func_->MakeCallOp(inputs, out); +} + +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h b/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h new file mode 100644 index 00000000000..7b564185b8b --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h @@ -0,0 +1,85 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SIGNATURE_DEF_FUNCTION_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SIGNATURE_DEF_FUNCTION_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" + +namespace tensorflow { + +// This is the TF eager runtime implementation of SignatureDefFunction (separate +// from the TFRT implementation). The user-facing API of SignatureDefFunctions +// and their semantic differences from ConcreteFunction are described here: +// https://github.com/tensorflow/tensorflow/blob/e2db60c9d9598ebae0b7741587ce6f5d473584d9/tensorflow/cc/saved_model/experimental/public/signature_def_function.h#L30-L59 +// Additional implementation notes are available here: +// https://github.com/tensorflow/tensorflow/blob/e2db60c9d9598ebae0b7741587ce6f5d473584d9/tensorflow/c/experimental/saved_model/core/signature_def_function.h#L31-L48 +class TFSignatureDefFunction : public SignatureDefFunction { + public: + // Factory function for creating a TFSignatureDefFunction. + // + // Params: + // function_def - The function_def associated with the created + // TFSignatureDefFunction. TFSignatureDefFunction will + // register this function_def with `ctx` on creation, and + // de-register it on destruction. function_def must be + // non-null, but otherwise has no lifetime requirements. + // captures - The captured TensorHandles associated with this + // TFConcreteFunction. + // metadata - FunctionMetadata associated with this TFSignatureDefFunction. + // ctx - A handle to the Tensorflow runtime. This MUST be non-null and + // outlive TFSignatureDefFunction. + // out - The output TFSignatureDefFunction. + static Status Create(const FunctionDef* function_def, + std::vector captures, + SignatureDefFunctionMetadata metadata, + ImmediateExecutionContext* ctx, + std::unique_ptr* out); + + // This method creates a "Call" Op used to execute the function. 
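  // Callers would typically consult GetFunctionMetadata() for the named
  // argument/return specs before building the call op. Illustrative sketch
  // (`signature_fn` and `inputs` are assumed to exist in the caller):
  //
  //   const SignatureDefFunctionMetadata& meta =
  //       signature_fn->GetFunctionMetadata();
  //   for (const SignatureDefParam& param : meta.arguments()) {
  //     VLOG(1) << param.name() << " expects "
  //             << DataTypeString(param.spec().dtype());
  //   }
  //   ImmediateOpPtr call_op;
  //   TF_RETURN_IF_ERROR(signature_fn->MakeCallOp(inputs, &call_op));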
+ Status MakeCallOp(absl::Span inputs, + ImmediateOpPtr* out) const override; + + const SignatureDefFunctionMetadata& GetFunctionMetadata() const override; + + ~TFSignatureDefFunction() override = default; + + private: + TFSignatureDefFunction(std::unique_ptr func, + SignatureDefFunctionMetadata metadata); + + TFSignatureDefFunction(const TFSignatureDefFunction&) = delete; + TFSignatureDefFunction& operator=(const TFSignatureDefFunction&) = delete; + + std::unique_ptr func_; + SignatureDefFunctionMetadata metadata_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SIGNATURE_DEF_FUNCTION_H_ diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h b/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h new file mode 100644 index 00000000000..ac1b20e474b --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h @@ -0,0 +1,55 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_SIGNATURE_DEF_FUNCTION_REVIVAL_STATE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_SIGNATURE_DEF_FUNCTION_REVIVAL_STATE_H_ + +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" + +namespace tensorflow { + +// FunctionBuilder wraps the state needed for building a SignatureDefFunction. +// This is mainly used in PartiallyRevivedObjects, which wraps partially +// constructed Function and Resource objects. +struct TFSignatureDefFunctionRevivalState { + // Index of the node in the SavedObjectGraph it was loaded from. + int node_id = 0; + + // Pointer to the original functiondef. fdef_ is guaranteed to be + // non-null. + const FunctionDef* fdef = nullptr; + + // SavedConcreteFunction contains much of the metadata of the expected "types" + // of the inputs and outputs of a function. + // Note(bmzhao): saved_concrete_func_ is guaranteed to be non-null. + const SavedConcreteFunction* saved_concrete_func = nullptr; + + // The name of the SignatureDef key. 
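  // e.g. "serving_default" for the default serving signature
  // (tensorflow::kDefaultServingSignatureDefKey).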
+ std::string signature_key; + + // TensorHandle captures for this funtion + std::vector captures; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_SIGNATURE_DEF_FUNCTION_REVIVAL_STATE_H_ diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/variable.cc b/tensorflow/c/experimental/saved_model/core/revived_types/variable.cc index a212c25bd28..2ede228e4ed 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/variable.cc +++ b/tensorflow/c/experimental/saved_model/core/revived_types/variable.cc @@ -20,8 +20,10 @@ limitations under the License. #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h" #include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/logging.h" @@ -62,15 +64,53 @@ Status Variable::ReadValue(ImmediateTensorHandlePtr* out) { return internal::ReadVariable(ctx_, handle_.get(), dtype_, out); } -Status Variable::CreateUninitialized(ImmediateExecutionContext* ctx, - DataType dtype, TensorShape shape, - absl::optional name, - const char* raw_device_name, - std::unique_ptr* output) { +Status Variable::CreateUninitialized( + ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape, + absl::optional name, const char* raw_device_name, + const std::vector& component_devices, + std::unique_ptr* output) { ImmediateTensorHandlePtr handle; - TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable( - ctx, dtype, shape, raw_device_name, &handle)); + if (component_devices.empty()) { + TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable( + ctx, dtype, shape, raw_device_name, &handle)); + output->reset( + new Variable(ctx, dtype, shape, std::move(name), std::move(handle))); + return Status(); + } + + if (!tensorflow::isa(ctx)) { + return errors::InvalidArgument( + "Can only load distributed variables with EagerContext."); + } + + EagerContext* eager_ctx = reinterpret_cast(ctx); + + std::vector handles; + for (const auto& device : component_devices) { + ImmediateTensorHandlePtr handlePtr; + TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable( + ctx, dtype, shape, device.empty() ? nullptr : device.c_str(), + &handlePtr)); + if (!tensorflow::isa(handlePtr.get())) { + return errors::Internal("Returned replica handle has unsupported type."); + } + handles.push_back(reinterpret_cast(handlePtr.release())); + } + TensorHandle* packed_handle; + TF_RETURN_IF_ERROR(TensorHandle::CreatePackedHandle( + std::move(handles), eager_ctx, &packed_handle)); + // The call to `CreatePackedHandle` incremented the handles' reference count, + // which we must now decrement to make the packed handle the owner of those + // handles. We can't loop through the `handles` vector because it was + // `std::move`d in the call above. 
+ for (int i = 0; i != packed_handle->NumPackedHandles(); ++i) { + TensorHandle* component; + TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &component)); + component->Unref(); + } + + handle.reset(packed_handle); output->reset( new Variable(ctx, dtype, shape, std::move(name), std::move(handle))); return Status(); diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/variable.h b/tensorflow/c/experimental/saved_model/core/revived_types/variable.h index 13f56fda5f3..6d630b54562 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/variable.h +++ b/tensorflow/c/experimental/saved_model/core/revived_types/variable.h @@ -34,11 +34,11 @@ class Variable : public TensorHandleConvertible { public: // Creates an uninitialized resource variable. Note that a caller must // call "assign" to associate a value with the variable. - static Status CreateUninitialized(ImmediateExecutionContext* ctx, - DataType dtype, TensorShape shape, - absl::optional name, - const char* raw_device_name, - std::unique_ptr* output); + static Status CreateUninitialized( + ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape, + absl::optional name, const char* raw_device_name, + const std::vector& component_devices, + std::unique_ptr* output); // The dtype of the underlying variable. DataType dtype(); diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc b/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc index e79fd8d7001..2a4297e2b67 100644 --- a/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc +++ b/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc @@ -17,14 +17,22 @@ limitations under the License. #include #include +#include +#include #include "absl/strings/str_split.h" +#include "absl/types/optional.h" #include "tensorflow/c/experimental/saved_model/core/function_metadata.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource_revival_state.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h" #include "tensorflow/c/tf_tensor_internal.h" +#include "tensorflow/cc/saved_model/loader_util.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/protobuf.h" @@ -40,6 +48,83 @@ namespace { using StructuredValueDictEntry = protobuf::MapPair; +// Maps from a Nodedef's name to its corresponding AttrValues, for a given +// Graphdef +using NodeAttrMap = + gtl::FlatMap; + +// Maps from a FunctionDef's name to FunctionDef, for a given FunctionDefLibrary +using FunctionDefMap = gtl::FlatMap; + +// Looks up a SavedConstant's associated tensorproto from the NodeAttrMap and +// returns a tensorflow::Constant. 
+Status ConstantFromSavedConstant( + ImmediateExecutionContext* ctx, + const tensorflow::SavedConstant& saved_constant, + const NodeAttrMap& node_attr_map, std::unique_ptr* output) { + const std::string& const_op_name = saved_constant.operation(); + const auto& node_name_and_attrs = node_attr_map.find(const_op_name); + if (node_name_and_attrs == node_attr_map.end()) { + return errors::FailedPrecondition( + "Unable to find Const operation with name'", const_op_name, + "' in SavedModel graphdef"); + } + const AttrValueMap* attrs = node_name_and_attrs->second; + const auto& attr_name_and_value = attrs->find("value"); + if (attr_name_and_value == attrs->end()) { + return errors::FailedPrecondition("Unable to find Const operation '", + const_op_name, "'s value attribute"); + } + const TensorProto& tensor_proto = attr_name_and_value->second.tensor(); + return internal::TensorProtoToConstant(ctx, tensor_proto, output); +} + +// Finds the "signatures" object in the object graph, and fills a mapping of +// each signature's name to the corresponding function's node in the object +// graph. +Status GetSignaturesMap(const SavedObjectGraph& saved_objects, + gtl::FlatMap* signatures_map) { + if (saved_objects.nodes().empty()) { + return errors::FailedPrecondition("Saved Object Graph was empty."); + } + const SavedObject& root = saved_objects.nodes(0); + const SavedObject* signatures = nullptr; + for (const auto& child : root.children()) { + if (child.local_name() == "signatures") { + if (child.node_id() >= saved_objects.nodes().size()) { + return errors::FailedPrecondition( + "Signature object had child node id ", child.node_id(), + " which exceeds the size of the set of nodes"); + } + signatures = &saved_objects.nodes(child.node_id()); + } + } + + // Some basic sanity checks that this object is actually our "signatures" map + if (signatures == nullptr) { + // This is where the "signatures" attribute is always set: + // https://github.com/tensorflow/tensorflow/blob/a2c542a0d83227568f9214a2af9a38ae3625976f/tensorflow/python/saved_model/save.py#L1106-L1109 + return errors::FailedPrecondition( + "SavedObjectGraph's root object must have a child 'signatures' object"); + } + if (signatures->kind_case() != SavedObject::kUserObject) { + return errors::FailedPrecondition( + "Signatures must be a SavedObject of type UserObject."); + } + if (signatures->user_object().identifier() != "signature_map") { + // This is where the string comes from: + // https://github.com/tensorflow/tensorflow/blob/c59af2913aaec235d883f50428efef1086f4c0e6/tensorflow/python/saved_model/signature_serialization.py#L220 + return errors::FailedPrecondition( + "Signatures SavedObject must have identifier 'signature_map'."); + } + + for (const auto& child : signatures->children()) { + (*signatures_map)[child.local_name()] = child.node_id(); + } + return Status(); +} + // Perform some basic sanity checks on SavedConcreteFunction's input and // output signatures with respect to the corresponding FunctionDef's input // and output args. @@ -98,8 +183,37 @@ Status ValidateSavedFunctionCompatibleWithFunctionDef( return Status(); } +Status ValidateSingleConcreteFunction(const SavedFunction& saved_function) { + // We only allow loading functions that have an annotated input signature, + // which means there is 1:1 correspondence between tf.function + // <=> SavedFunction <=> SavedConcreteFunction <=> FunctionDef. 
This is + // the same restriction that MLIR has: + // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2677-L2707 + if (saved_function.concrete_functions_size() != 1) { + return errors::FailedPrecondition( + "Only tf.functions annotated with an input signature are supported " + "by SavedModelAPI. This means that there should only be a single " + "ConcreteFunction per tf.function"); + } + return Status(); +} + } // namespace +Status LoadSavedAsset(ImmediateExecutionContext* ctx, const SavedAsset& asset, + const std::string& saved_model_dir, + absl::Span assets, + std::unique_ptr* output) { + int asset_index = asset.asset_file_def_index(); + if (asset_index >= assets.size()) { + return errors::FailedPrecondition( + "SavedAsset contained asset index ", asset_index, + " but AssetFileDef only contains ", assets.size(), " # of assets"); + } + const std::string& asset_filename = assets[asset_index].filename(); + return Asset::Create(ctx, saved_model_dir, asset_filename, output); +} + Status TensorProtoToConstant(ImmediateExecutionContext* ctx, const TensorProto& proto, std::unique_ptr* output) { @@ -121,10 +235,17 @@ Status LoadSavedVariable(ImmediateExecutionContext* ctx, const std::string& name = variable.name(); tensorflow::TensorShape shape(variable.shape()); tensorflow::DataType dtype = variable.dtype(); + std::vector component_devices; + + for (const auto& component : + variable.experimental_distributed_variable_components()) { + component_devices.push_back(component.device()); + } TF_RETURN_IF_ERROR(Variable::CreateUninitialized( ctx, dtype, shape, name, - variable.device().empty() ? nullptr : variable.device().c_str(), output)); + variable.device().empty() ? nullptr : variable.device().c_str(), + component_devices, output)); return Status(); } @@ -210,16 +331,17 @@ Status FlattenSignature(const StructuredValue& signature, } } -const SavedObject* FindNodeAtPath(StringPiece path, - const SavedObjectGraph& object_graph) { +absl::optional FindNodeAtPath(StringPiece path, + const SavedObjectGraph& object_graph) { const auto& nodes = object_graph.nodes(); if (nodes.empty()) { - return nullptr; + return absl::nullopt; } // Starting from the root, iterate through the saved object graph, matching // object names as we go. 
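  // Under the new signature the caller receives the node id rather than a
  // SavedObject pointer. Illustrative call (the path is hypothetical):
  //
  //   absl::optional<int> node_id =
  //       internal::FindNodeAtPath("foo.bar", object_graph);
  //   if (!node_id.has_value()) { /* no object at that path */ }
  //   const SavedObject& obj = object_graph.nodes(*node_id);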
- const SavedObject* current_node = &nodes.Get(0); + int node_id = 0; + const SavedObject* current_node = &nodes.Get(node_id); for (absl::string_view object_name : absl::StrSplit(path, '.')) { auto child_node_iter = std::find_if( @@ -229,29 +351,28 @@ const SavedObject* FindNodeAtPath(StringPiece path, return object_name == obj.local_name(); }); if (child_node_iter == current_node->children().end()) { - return nullptr; + return absl::nullopt; } - current_node = &nodes.Get(child_node_iter->node_id()); + + node_id = child_node_iter->node_id(); + current_node = &nodes.Get(node_id); } - return current_node; + return node_id; } -std::unordered_map -NodeToAttrMap(const tensorflow::GraphDef& graphdef) { - std::unordered_map - result; +gtl::FlatMap NodeToAttrMap( + const tensorflow::GraphDef& graphdef) { + gtl::FlatMap result; for (const tensorflow::NodeDef& node : graphdef.node()) { result[node.name()] = &node.attr(); } return result; } -std::unordered_map +gtl::FlatMap FunctionNameToFunctionDefMap(const FunctionDefLibrary& library) { - std::unordered_map + gtl::FlatMap result; for (const FunctionDef& function_def : library.function()) { result[function_def.signature().name()] = &function_def; @@ -259,5 +380,156 @@ FunctionNameToFunctionDefMap(const FunctionDefLibrary& library) { return result; } +Status PartiallyReviveSavedModelObjects(const MetaGraphDef& metagraph, + ImmediateExecutionContext* context, + const std::string& directory, + PartiallyRevivedObjects* objects) { + // This is needed to restore "Constant" nodes by looking up their + // "Value" attribute. + NodeAttrMap node_attr_map = NodeToAttrMap(metagraph.graph_def()); + + // These are needed for creating "Assets", by looking up their filenames. + std::vector assets; + TF_RETURN_IF_ERROR(GetAssetFileDefs(metagraph, &assets)); + + // Signatures are needed for determining whether a function is a + // SignatureDefFunction or not. + gtl::FlatMap signatures_map; + TF_RETURN_IF_ERROR( + GetSignaturesMap(metagraph.object_graph_def(), &signatures_map)); + + gtl::FlatMap reversed_signatures_map; + reversed_signatures_map.reserve(signatures_map.size()); + for (const auto& signature_key_and_node : signatures_map) { + reversed_signatures_map.emplace(signature_key_and_node.second, + signature_key_and_node.first); + } + + // FunctionDefs are needed to help construct + // TFConcreteFunction/SignatureDefFunctions + const FunctionDefMap function_def_map = + internal::FunctionNameToFunctionDefMap(metagraph.graph_def().library()); + + // Iterate through all the saved objects, restoring objects (if we can) as we + // go. For objects that dependencies on other objects (resources/functions), + // we partially initialize "builders" that correspond to their currently known + // state, and gradually fill them out in subsequent passes. 
+ for (int i = 0; i < metagraph.object_graph_def().nodes_size(); ++i) { + const SavedObject& node = metagraph.object_graph_def().nodes(i); + if (node.kind_case() == SavedObject::kVariable) { + std::unique_ptr variable; + TF_RETURN_IF_ERROR( + LoadSavedVariable(context, node.variable(), &variable)); + objects->variables[i] = std::move(variable); + } else if (node.kind_case() == SavedObject::kConstant) { + std::unique_ptr constant; + TF_RETURN_IF_ERROR(ConstantFromSavedConstant(context, node.constant(), + node_attr_map, &constant)); + objects->constants[i] = std::move(constant); + } else if (node.kind_case() == SavedObject::kAsset) { + std::unique_ptr asset; + TF_RETURN_IF_ERROR( + LoadSavedAsset(context, node.asset(), directory, assets, &asset)); + objects->assets[i] = std::move(asset); + } else if (node.kind_case() == SavedObject::kResource) { + RestoredResourceRevivalState resource_revival_state; + // We'll set the resource's functions in a subsequent pass, once we get + // all functions in a partially revived state. + resource_revival_state.device = node.resource().device(); + objects->restored_resources[i] = std::move(resource_revival_state); + } else if (node.kind_case() == SavedObject::kFunction) { + // Get the SavedFunction node and validate it has a single concrete func. + const SavedFunction& saved_function = node.function(); + TF_RETURN_IF_ERROR(ValidateSingleConcreteFunction(saved_function)); + + // Retrieve related function information. + const std::string& function_name = saved_function.concrete_functions(0); + const FunctionDef* function_def = function_def_map.at(function_name); + const SavedConcreteFunction& saved_concrete_func = + metagraph.object_graph_def().concrete_functions().at(function_name); + const FunctionSpec& function_spec = saved_function.function_spec(); + + // Construct either a SignatureDefFunctionBuilder or a + // ConcreteFunctionBuilder, depending on whether this node was a child + // of the "signatures" attribute from root object. + auto reverse_signature_iter = reversed_signatures_map.find(i); + if (reverse_signature_iter != reversed_signatures_map.end()) { + TFSignatureDefFunctionRevivalState func_revival_state; + func_revival_state.node_id = i; + func_revival_state.fdef = function_def; + func_revival_state.saved_concrete_func = &saved_concrete_func; + func_revival_state.signature_key = reverse_signature_iter->second; + objects->signature_def_functions[i] = std::move(func_revival_state); + } else { + TFConcreteFunctionRevivalState func_revival_state; + func_revival_state.node_id = i; + func_revival_state.fdef = function_def; + func_revival_state.saved_concrete_func = &saved_concrete_func; + func_revival_state.function_spec = &function_spec; + objects->concrete_functions[i] = std::move(func_revival_state); + } + } else if (node.kind_case() == SavedObject::kBareConcreteFunction) { + const SavedBareConcreteFunction& bare_cf = node.bare_concrete_function(); + + // Retrieve related function information. + const std::string& function_name = bare_cf.concrete_function_name(); + const FunctionDef* function_def = function_def_map.at(function_name); + const SavedConcreteFunction& saved_concrete_func = + metagraph.object_graph_def().concrete_functions().at(function_name); + + // Check whether this is a SignatureDefFunction, or not. 
+ auto reverse_signature_iter = reversed_signatures_map.find(i); + if (reverse_signature_iter != reversed_signatures_map.end()) { + TFSignatureDefFunctionRevivalState func_revival_state; + func_revival_state.node_id = i; + func_revival_state.fdef = function_def; + func_revival_state.saved_concrete_func = &saved_concrete_func; + func_revival_state.signature_key = reverse_signature_iter->second; + objects->signature_def_functions[i] = std::move(func_revival_state); + } else { + TFConcreteFunctionRevivalState func_revival_state; + func_revival_state.node_id = i; + func_revival_state.fdef = function_def; + func_revival_state.saved_concrete_func = &saved_concrete_func; + objects->concrete_functions[i] = std::move(func_revival_state); + } + } + } + + // Now that we've partially restored all functions, we can have resources + // point to them + for (auto& node_and_resource_revival_state : objects->restored_resources) { + int node_id = node_and_resource_revival_state.first; + const SavedObjectGraph& obj_graph = metagraph.object_graph_def(); + const SavedObject& node = obj_graph.nodes(node_id); + RestoredResourceRevivalState& resource = + node_and_resource_revival_state.second; + for (const TrackableObjectGraph::TrackableObject::ObjectReference& child : + node.children()) { + int child_node_id = child.node_id(); + // Note(bmzhao): The expected functions saved by a resource object are: + // "_create_resource", "_initialize", and "_destroy_resource". + // https://github.com/tensorflow/tensorflow/blob/ad66f588c1666ade8051feb42811fa27b285271c/tensorflow/python/training/tracking/tracking.py#L277-L281 + if (child.local_name() == "_create_resource" && + obj_graph.nodes(child.node_id()).kind_case() == + SavedObject::kFunction) { + resource.create_resource = &objects->concrete_functions[child_node_id]; + } else if (child.local_name() == "_initialize" && + obj_graph.nodes(child.node_id()).kind_case() == + SavedObject::kFunction) { + resource.initialize = &objects->concrete_functions[child_node_id]; + } else if (child.local_name() == "_destroy_resource" && + obj_graph.nodes(child.node_id()).kind_case() == + SavedObject::kFunction) { + resource.destroy_resource = &objects->concrete_functions[child_node_id]; + } + } + } + + objects->signatures_map = std::move(signatures_map); + + return Status(); +} + } // namespace internal } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_utils.h b/tensorflow/c/experimental/saved_model/core/saved_model_utils.h index 68bfbe32222..db45e28087f 100644 --- a/tensorflow/c/experimental/saved_model/core/saved_model_utils.h +++ b/tensorflow/c/experimental/saved_model/core/saved_model_utils.h @@ -22,15 +22,21 @@ limitations under the License. 
#include #include +#include "absl/types/optional.h" +#include "absl/types/span.h" #include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tensorflow/core/protobuf/saved_object_graph.pb.h" #include "tensorflow/core/protobuf/struct.pb.h" @@ -52,6 +58,11 @@ Status LoadSavedVariable(ImmediateExecutionContext* ctx, const SavedVariable& variable, std::unique_ptr* output); +Status LoadSavedAsset(ImmediateExecutionContext* ctx, const SavedAsset& asset, + const std::string& saved_model_dir, + absl::Span assets, + std::unique_ptr* output); + // Creates a TFConcreteFunction from a SavedConcreteFunction. Status LoadTFConcreteFunction( const SavedConcreteFunction& saved_concrete_function, @@ -67,24 +78,30 @@ Status LoadTFConcreteFunction( Status FlattenSignature(const StructuredValue& signature, std::vector* flattened_specs); -// Find the SavedObject in `object_graph` at location `path`. `path` must be +// Find the node id in `object_graph` at location `path`. `path` must be // a dot-delimited string of object names relative to the root object. If no -// object is found, returns nullptr. Callers must ensure `object_graph` -// outlives the returned pointer. -const SavedObject* FindNodeAtPath(StringPiece path, - const SavedObjectGraph& object_graph); +// object is found, returns absl::nullopt. +absl::optional FindNodeAtPath(StringPiece path, + const SavedObjectGraph& object_graph); // Maps each node in `graphdef` to its corresponding Attribute Map. // Callers must ensure that `graphdef` outlives the returned map. -std::unordered_map -NodeToAttrMap(const tensorflow::GraphDef& graphdef); +gtl::FlatMap NodeToAttrMap( + const tensorflow::GraphDef& graphdef); // Maps the name of each FunctionDef in `library` to its corresponding // FunctionDef. Callers must ensure `library` outlives the returned map. -std::unordered_map +gtl::FlatMap FunctionNameToFunctionDefMap(const FunctionDefLibrary& library); +// Walks through the SavedObjectGraph in metagraph, and restores all nodes +// (except "UserDefinedObjects") with their corresponding type in +// "PartiallyRevivedObjects". 
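// A sketch of the intended two-phase flow (local names are illustrative; the
// call site is expected to be the SavedModel loader, which includes these
// types):
//
//   PartiallyRevivedObjects partial;
//   TF_RETURN_IF_ERROR(internal::PartiallyReviveSavedModelObjects(
//       metagraph, ctx, directory, &partial));
//   RevivedObjects revived;
//   TF_RETURN_IF_ERROR(
//       partial.Build(ctx, metagraph.object_graph_def(), &revived));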
+Status PartiallyReviveSavedModelObjects(const MetaGraphDef& metagraph, + ImmediateExecutionContext* context, + const std::string& directory, + PartiallyRevivedObjects* objects); + } // namespace internal } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/saved_variable_loading_test.cc b/tensorflow/c/experimental/saved_model/core/saved_variable_loading_test.cc index 45b0ac00c9b..a5a4e900843 100644 --- a/tensorflow/c/experimental/saved_model/core/saved_variable_loading_test.cc +++ b/tensorflow/c/experimental/saved_model/core/saved_variable_loading_test.cc @@ -119,7 +119,7 @@ TEST_P(SavedVariableLoadingTest, AssignAndReadVariableSuccesful) { Status status; std::unique_ptr var; TF_EXPECT_OK(Variable::CreateUninitialized(context(), dtype, shape, - absl::nullopt, nullptr, &var)); + absl::nullopt, nullptr, {}, &var)); // Create a TensorHandle ImmediateTensorHandlePtr expected_handle = diff --git a/tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.cc b/tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.cc new file mode 100644 index 00000000000..4e455f08f49 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.cc @@ -0,0 +1,42 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h" + +namespace tensorflow { + +SignatureDefParam::SignatureDefParam(std::string name, TensorSpec spec) + : name_(std::move(name)), spec_(std::move(spec)) {} + +const std::string& SignatureDefParam::name() const { return name_; } + +const TensorSpec& SignatureDefParam::spec() const { return spec_; } + +SignatureDefFunctionMetadata::SignatureDefFunctionMetadata( + std::vector arguments, + std::vector returns) + : arguments_(std::move(arguments)), returns_(std::move(returns)) {} + +const std::vector& SignatureDefFunctionMetadata::arguments() + const { + return arguments_; +} + +const std::vector& SignatureDefFunctionMetadata::returns() + const { + return returns_; +} + +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h b/tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h index 5a579676d4e..e9cc0b11b00 100644 --- a/tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h +++ b/tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h @@ -16,10 +16,42 @@ limitations under the License. 
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_ #define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_ +#include +#include + +#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/struct.pb.h" + namespace tensorflow { +// SignatureDefParam represents a named Tensor input or output to a +// SignatureDefFunction. +class SignatureDefParam { + public: + SignatureDefParam(std::string name, TensorSpec spec); + + const std::string& name() const; + + const TensorSpec& spec() const; + + private: + std::string name_; + TensorSpec spec_; +}; + class SignatureDefFunctionMetadata { - // TODO(bmzhao): Fill in with fields as necessary + public: + SignatureDefFunctionMetadata() = default; + SignatureDefFunctionMetadata(std::vector arguments, + std::vector returns); + + const std::vector& arguments() const; + const std::vector& returns() const; + + private: + std::vector arguments_; + std::vector returns_; }; } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/tensor_spec.cc b/tensorflow/c/experimental/saved_model/core/tensor_spec.cc new file mode 100644 index 00000000000..4d68ec73b1b --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/tensor_spec.cc @@ -0,0 +1,38 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h" + +#include + +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace tensorflow { + +TensorSpec::TensorSpec() + : shape_(std::initializer_list()), dtype_(DT_FLOAT) {} + +TensorSpec::TensorSpec(PartialTensorShape shape, DataType dtype) + : shape_(std::move(shape)), dtype_(dtype) {} + +TensorSpec::TensorSpec(const TensorSpecProto& proto) + : shape_(proto.shape()), dtype_(proto.dtype()) {} + +const PartialTensorShape& TensorSpec::shape() const { return shape_; } + +DataType TensorSpec::dtype() const { return dtype_; } + +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/tensor_spec.h b/tensorflow/c/experimental/saved_model/core/tensor_spec.h new file mode 100644 index 00000000000..dcdff8900bd --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/tensor_spec.h @@ -0,0 +1,51 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSOR_SPEC_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSOR_SPEC_H_ + +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/protobuf/struct.pb.h" + +namespace tensorflow { + +// Note(bmzhao): TensorSpec deliberately does not store the "name" from a +// TensorSpecProto. From edloper@, "Names should really be associated with +// parameters, not the tensors inside those parameters. This would be +// inconsistent with the corresponding Python class, but I don't think that's +// necessarily a problem. If it turns out later that we really need a name +// attribute here, we can always add it back in; but let's see how far we can +// get without it." +class TensorSpec { + public: + // Constructs a scalar, DT_FLOAT TensorSpec + TensorSpec(); + + TensorSpec(PartialTensorShape shape, DataType dtype); + + explicit TensorSpec(const TensorSpecProto& proto); + + const PartialTensorShape& shape() const; + DataType dtype() const; + + private: + PartialTensorShape shape_; + DataType dtype_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSOR_SPEC_H_ diff --git a/tensorflow/c/experimental/saved_model/core/test_utils.cc b/tensorflow/c/experimental/saved_model/core/test_utils.cc index d551919ea94..988f7e382a8 100644 --- a/tensorflow/c/experimental/saved_model/core/test_utils.cc +++ b/tensorflow/c/experimental/saved_model/core/test_utils.cc @@ -45,11 +45,9 @@ EagerContextPtr CreateTestingEagerContext(DeviceMgr* device_mgr) { return EagerContextPtr(new EagerContext( SessionOptions(), tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, - tensorflow::ContextMirroringPolicy::MIRRORING_NONE, /* async= */ false, /* lazy_copy_function_remote_inputs= */ false, device_mgr, /* device_mgr_owned= */ false, /* rendezvous= */ nullptr, - /* custom_kernel_creator= */ nullptr, /* cluster_flr= */ nullptr)); } diff --git a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc index ab7052b52ed..f0990235963 100644 --- a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc +++ b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include @@ -30,6 +29,9 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/core/concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/ops/restore_ops.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h" @@ -45,6 +47,7 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/errors.h" @@ -61,139 +64,15 @@ limitations under the License. namespace tensorflow { // Maps from a FunctionDef's name to FunctionDef, for a given FunctionDefLibrary -using FunctionDefMap = - std::unordered_map; - -// Maps from a Nodedef's name to its corresponding AttrValues, for a given -// Graphdef -using NodeAttrMap = - std::unordered_map; - -// Maps from Node ID to an "Revived Object" implementing -// "TensorHandleConvertible" -using RevivedObjectMap = - std::unordered_map>; +using FunctionDefMap = gtl::FlatMap; // Maps from a functiondef's name to the corresponding "TFConcreteFunction" -using ConcreteFunctionMap = - std::unordered_map>; +using FlatTensorFunctionMap = + gtl::FlatMap>; namespace { -Status ConstantFromSavedConstant( - ImmediateExecutionContext* ctx, - const tensorflow::SavedConstant& saved_constant, - const NodeAttrMap& node_attr_map, std::unique_ptr* output) { - const std::string& const_op_name = saved_constant.operation(); - const auto& node_name_and_attrs = node_attr_map.find(const_op_name); - if (node_name_and_attrs == node_attr_map.end()) { - return errors::FailedPrecondition( - "Unable to find Const operation with name'", const_op_name, - "' in SavedModel graphdef"); - } - const AttrValueMap* attrs = node_name_and_attrs->second; - const auto& attr_name_and_value = attrs->find("value"); - if (attr_name_and_value == attrs->end()) { - return errors::FailedPrecondition("Unable to find Const operation '", - const_op_name, "'s value attribute"); - } - const TensorProto& tensor_proto = attr_name_and_value->second.tensor(); - return internal::TensorProtoToConstant(ctx, tensor_proto, output); -} - -// Restores all non-function objects in the SavedModel's object graph. -// This function walks through the metagraph's saved object graph, and -// constructs revived versions of SavedVariable, SavedConstant, SavedAsset, and -// SavedResources. These are returned via the `out` parameter. -Status ReviveObjects( - const MetaGraphDef& metagraph, ImmediateExecutionContext* context, - std::unordered_map>* - revived_objects) { - // This is needed to restore "Constant" nodes by looking up their - // "Value" attribute. - NodeAttrMap node_attr_map = internal::NodeToAttrMap(metagraph.graph_def()); - - // Iterate through all the saved objects, restoring objects as we go. - // We don't recreate functions until all other objects have been created. - for (int i = 0; i < metagraph.object_graph_def().nodes_size(); ++i) { - const SavedObject& node = metagraph.object_graph_def().nodes(i); - if (node.kind_case() == SavedObject::kVariable) { - std::unique_ptr variable; - TF_RETURN_IF_ERROR( - internal::LoadSavedVariable(context, node.variable(), &variable)); - (*revived_objects)[i] = std::move(variable); - } else if (node.kind_case() == SavedObject::kConstant) { - std::unique_ptr constant; - TF_RETURN_IF_ERROR(ConstantFromSavedConstant(context, node.constant(), - node_attr_map, &constant)); - (*revived_objects)[i] = std::move(constant); - } else if (node.kind_case() == SavedObject::kAsset) { - // TODO(bmzhao): Implement Asset C++ class. 
This should be just recreating - // the full path to the asset file: - // https://github.com/tensorflow/tensorflow/blob/6a0bdbdb7c48a3491ae1277083ae3dafb4ab4d7a/tensorflow/python/saved_model/load.py#L395-L396 - // and storing it as a string tensor: - // https://github.com/tensorflow/tensorflow/blob/6a0bdbdb7c48a3491ae1277083ae3dafb4ab4d7a/tensorflow/python/training/tracking/tracking.py#L324-L325 - return errors::Unimplemented("SavedAsset loading is not implemented yet"); - } else if (node.kind_case() == SavedObject::kResource) { - // TODO(bmzhao): Figure out how resource loading works and implement it - return errors::Unimplemented( - "SavedResource loading is not implemented yet"); - } - } - return Status(); -} - -Status ReviveFunctions(const MetaGraphDef& metagraph, - const RevivedObjectMap& revived_objects, - ImmediateExecutionContext* context, - ConcreteFunctionMap* restored_functions) { - const FunctionDefMap function_def_map = - internal::FunctionNameToFunctionDefMap(metagraph.graph_def().library()); - - // Iterate through all objects, only examining functions. - for (const SavedObject& node : metagraph.object_graph_def().nodes()) { - if (node.kind_case() == SavedObject::kBareConcreteFunction) { - const std::string& function_name = - node.bare_concrete_function().concrete_function_name(); - - const SavedConcreteFunction& saved_concrete_function = - metagraph.object_graph_def().concrete_functions().at(function_name); - - const FunctionDef* function_def = function_def_map.at(function_name); - std::unique_ptr concrete_function; - TF_RETURN_IF_ERROR(internal::LoadTFConcreteFunction( - saved_concrete_function, function_def, revived_objects, context, - &concrete_function)); - (*restored_functions)[function_name] = std::move(concrete_function); - } else if (node.kind_case() == SavedObject::kFunction) { - // We only allow loading functions that have an annotated input signature, - // which means there is 1:1 correspondence between tf.function - // <=> SavedFunction <=> SavedConcreteFunction <=> FunctionDef. This is - // the same restriction that MLIR has: - // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2677-L2707 - const SavedFunction& saved_function = node.function(); - if (saved_function.concrete_functions_size() != 1) { - return errors::FailedPrecondition( - "Only tf.functions annotated with an input signature are supported " - "by SavedModelAPI. 
This means that there should only be a single " - "ConcreteFunction per tf.function"); - } - const std::string& function_name = saved_function.concrete_functions(0); - const SavedConcreteFunction& saved_concrete_function = - metagraph.object_graph_def().concrete_functions().at(function_name); - - const FunctionDef* function_def = function_def_map.at(function_name); - - std::unique_ptr concrete_function; - TF_RETURN_IF_ERROR(internal::LoadTFConcreteFunction( - saved_concrete_function, function_def, revived_objects, context, - &concrete_function)); - (*restored_functions)[function_name] = std::move(concrete_function); - } - } - return Status(); -} const TrackableObjectGraph::TrackableObject::SerializedTensor* FindSerializedTensorInTrackable( @@ -230,7 +109,7 @@ FindSerializedTensorInTrackable( // overridden "restore" method: // https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L85 Status RestoreCheckpoint(SavedModelV2Bundle* bundle, - const RevivedObjectMap& revived_objects, + const RevivedObjects& revived_objects, const std::string& directory, ImmediateExecutionContext* context) { // TODO(bmzhao): Batch up all the restores into a single restore op per @@ -250,8 +129,7 @@ Status RestoreCheckpoint(SavedModelV2Bundle* bundle, return Status::OK(); } - Variable* variable = - down_cast(revived_objects.at(node).get()); + Variable* variable = revived_objects.variables.at(node).get(); // Restore the tensor's value from the checkpoint const TrackableObjectGraph::TrackableObject::SerializedTensor* @@ -264,6 +142,12 @@ Status RestoreCheckpoint(SavedModelV2Bundle* bundle, } const std::string& checkpoint_key = attribute->checkpoint_key(); + if (!bundle->variable_reader()->Contains(checkpoint_key)) { + LOG(WARNING) << "No checkpoint entry found for " << checkpoint_key + << ". 
Variable will be uninitialized."; + return Status(); + } + std::string variables_path_prefix = io::JoinPath(directory, kSavedModelVariablesDirectory, kSavedModelVariablesFilename); @@ -279,58 +163,86 @@ Status RestoreCheckpoint(SavedModelV2Bundle* bundle, return Status(); } +Status InitializeAllResources(const RevivedObjects& revived) { + for (const auto& node_and_resource : revived.restored_resources) { + const RestoredResource& resource = node_and_resource.second; + TF_RETURN_IF_ERROR(resource.Initialize()); + } + return Status(); +} + } // namespace Status TFSavedModelAPI::GetFunction(const std::string& function_path, ConcreteFunction** function) { - const SavedObject* object = + absl::optional node = internal::FindNodeAtPath(function_path, bundle_.saved_object_graph()); - if (object == nullptr) { + if (!node.has_value()) { return errors::NotFound("No saved object found at path ", function_path); } - if (object->kind_case() == SavedObject::kBareConcreteFunction) { - *function = - concrete_functions_ - .at(object->bare_concrete_function().concrete_function_name()) - .get(); - } else if (object->kind_case() == SavedObject::kFunction) { - *function = - concrete_functions_.at(object->function().concrete_functions(0)).get(); - } else { - return errors::InvalidArgument(function_path, - " is not a path to a Function."); + auto function_iter = revived_objects_.concrete_functions.find(*node); + if (function_iter == revived_objects_.concrete_functions.end()) { + return errors::NotFound("No function found at path ", function_path); } + *function = function_iter->second.get(); return Status(); } Status TFSavedModelAPI::GetSignatureDefFunction( const std::string& signature_def_key, SignatureDefFunction** function) { - // TODO(bmzhao): Add support for retrieving a signaturedef function. 
- return errors::Unimplemented( - "Retrieving SignatureDef functions is unimplemented currently"); + auto signatures_iter = + revived_objects_.signatures_map.find(signature_def_key); + if (signatures_iter == revived_objects_.signatures_map.end()) { + return errors::NotFound("No signature with key ", signature_def_key, + " was found"); + } + int node = signatures_iter->second; + + auto function_iter = revived_objects_.signature_def_functions.find(node); + if (function_iter == revived_objects_.signature_def_functions.end()) { + return errors::Internal( + "Unable to find SignatureDefFunction associated with key ", + signature_def_key, " despite key being valid."); + } + + *function = function_iter->second.get(); + return Status(); } std::vector TFSavedModelAPI::ListFunctions() { std::vector result; - result.reserve(concrete_functions_.size()); - for (auto& index_and_function : concrete_functions_) { + result.reserve(revived_objects_.concrete_functions.size()); + for (auto& index_and_function : revived_objects_.concrete_functions) { result.push_back(index_and_function.second.get()); } return result; } -TFSavedModelAPI::TFSavedModelAPI( - const std::string& directory, SavedModelV2Bundle bundle, - std::unordered_map> - revived_objects, - std::unordered_map> - concrete_functions) +Status TFSavedModelAPI::GetVariable(const std::string& variable_path, + Variable** variable) { + absl::optional node = + internal::FindNodeAtPath(variable_path, bundle_.saved_object_graph()); + if (!node.has_value()) { + return errors::NotFound("No saved object found at path ", variable_path); + } + + auto variables_iter = revived_objects_.variables.find(*node); + if (variables_iter == revived_objects_.variables.end()) { + return errors::NotFound("No variable found at path ", variable_path); + } + + *variable = variables_iter->second.get(); + return Status(); +} + +TFSavedModelAPI::TFSavedModelAPI(const std::string& directory, + SavedModelV2Bundle bundle, + RevivedObjects revived_objects) : directory_(directory), bundle_(std::move(bundle)), - revived_objects_(std::move(revived_objects)), - concrete_functions_(std::move(concrete_functions)) {} + revived_objects_(std::move(revived_objects)) {} Status TFSavedModelAPI::Load( const std::string& directory, @@ -351,28 +263,25 @@ Status TFSavedModelAPI::Load( // This occurs in python here: // https://github.com/tensorflow/tensorflow/blob/285b5fa15405c5e2c084080f52a1818be8648079/tensorflow/python/saved_model/function_deserialization.py#L438-L454 - RevivedObjectMap revived_objects; - TF_RETURN_IF_ERROR( - ReviveObjects(bundle.meta_graph_def(), context, &revived_objects)); + // Step 1: For each node in the graph, we should initialize an object of the + // corresponding type. For objects that depend on the initialization of other + // objects (like functions which capture resources), we will initialize them + // in step 2. + PartiallyRevivedObjects partially_revived_objects; + TF_RETURN_IF_ERROR(internal::PartiallyReviveSavedModelObjects( + bundle.meta_graph_def(), context, directory, &partially_revived_objects)); - // TODO(bmzhao): When we later add support for loading resources, we need to - // handle the case where materializing a function's captures requires invoking - // other functions. 
This occurs when retrieving the resource handle for a - // TrackableResource: - // https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/saved_model/load.py#L240 - // https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/training/tracking/tracking.py#L233 - // This requires restoring functions in a topological sort order by capture - // dependencies. - ConcreteFunctionMap function_map; - TF_RETURN_IF_ERROR(ReviveFunctions(bundle.meta_graph_def(), revived_objects, - context, &function_map)); + RevivedObjects revived_objects; + TF_RETURN_IF_ERROR(partially_revived_objects.Build( + context, bundle.saved_object_graph(), &revived_objects)); TF_RETURN_IF_ERROR( RestoreCheckpoint(&bundle, revived_objects, directory, context)); + TF_RETURN_IF_ERROR(InitializeAllResources(revived_objects)); + out->reset(new TFSavedModelAPI(directory, std::move(bundle), - std::move(revived_objects), - std::move(function_map))); + std::move(revived_objects))); return Status(); } diff --git a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h index fd07c09474b..bc39a974ad2 100644 --- a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h +++ b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h @@ -25,8 +25,10 @@ limitations under the License. #include "absl/types/optional.h" #include "tensorflow/c/eager/immediate_execution_context.h" #include "tensorflow/c/experimental/saved_model/core/concrete_function.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h" #include "tensorflow/c/experimental/saved_model/core/saved_model_api.h" #include "tensorflow/c/experimental/saved_model/core/signature_def_function.h" #include "tensorflow/cc/saved_model/bundle_v2.h" @@ -68,20 +70,15 @@ class TFSavedModelAPI : public SavedModelAPI { ~TFSavedModelAPI() override = default; + Status GetVariable(const std::string& variable_path, Variable** variable); + private: - TFSavedModelAPI( - const std::string& directory, SavedModelV2Bundle bundle, - std::unordered_map> - revived_objects, - std::unordered_map> - concrete_functions); + TFSavedModelAPI(const std::string& directory, SavedModelV2Bundle bundle, + RevivedObjects revived_objects); std::string directory_; SavedModelV2Bundle bundle_; - std::unordered_map> - revived_objects_; - std::unordered_map> - concrete_functions_; + RevivedObjects revived_objects_; }; } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/internal/BUILD b/tensorflow/c/experimental/saved_model/internal/BUILD index c0d121a4aee..06fbc7aef0a 100644 --- a/tensorflow/c/experimental/saved_model/internal/BUILD +++ b/tensorflow/c/experimental/saved_model/internal/BUILD @@ -9,6 +9,8 @@ # Note(bmzhao): The *.cc files in this directory form the direct implementation of the # C API functions exposed in tf/c/experimental/saved_model/public/. +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + # Note(bmzhao): All *type.h files in this directory are the internal definitions of # the opaque C types. These headers should only be visible to internal tensorflow # implementors. 
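The loader rework above replaces the old single-pass ReviveObjects/ReviveFunctions flow with a two-phase scheme followed by checkpoint restore and resource initialization. The block below is a minimal sketch of that sequence, not part of the change itself: it reuses only names visible in this diff (PartiallyRevivedObjects, RevivedObjects, internal::PartiallyReviveSavedModelObjects, SavedModelV2Bundle), assumes the internal loader header that declares the internal:: functions is also included, and takes `context`, `bundle`, and `directory` from the caller.

```c++
// Sketch of the two-phase load performed by TFSavedModelAPI::Load.
// Assumed setup: the internal loader header declaring
// internal::PartiallyReviveSavedModelObjects is included as well.
#include <string>

#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h"
#include "tensorflow/cc/saved_model/bundle_v2.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {

Status TwoPhaseLoadSketch(ImmediateExecutionContext* context,
                          SavedModelV2Bundle* bundle,
                          const std::string& directory,
                          RevivedObjects* revived) {
  // Phase 1: create an object of the matching revived type for every node in
  // the SavedObjectGraph, deferring anything that depends on other objects
  // (for example, functions that capture resource handles).
  PartiallyRevivedObjects partial;
  TF_RETURN_IF_ERROR(internal::PartiallyReviveSavedModelObjects(
      bundle->meta_graph_def(), context, directory, &partial));

  // Phase 2: resolve cross-object dependencies into a complete RevivedObjects.
  TF_RETURN_IF_ERROR(
      partial.Build(context, bundle->saved_object_graph(), revived));

  // TFSavedModelAPI::Load then restores the checkpoint (variables whose
  // checkpoint key is missing are left uninitialized, with a warning) and
  // calls Initialize() on every restored resource.
  return Status();
}

}  // namespace tensorflow
```

GetFunction, GetSignatureDefFunction, and the new GetVariable accessor all follow the same pattern after this load: resolve the user-supplied path or signature key to a node id first (via FindNodeAtPath or the signatures map), then look that id up in the corresponding RevivedObjects container, which is why the per-name concrete_functions_ map is gone.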
@@ -222,6 +224,8 @@ cc_library( ], deps = [ ":signature_def_function_metadata_type", + ":signature_def_param_list", + ":signature_def_param_list_type", "//tensorflow/c:c_api_macros", "//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata", ], @@ -238,6 +242,104 @@ cc_library( ], ) +cc_library( + name = "signature_def_param", + srcs = [ + "signature_def_param.cc", + ], + hdrs = [ + "//tensorflow/c/experimental/saved_model/public:signature_def_param.h", + ], + copts = tf_copts(), + visibility = [ + "//tensorflow/c/experimental/saved_model/public:__pkg__", + ], + deps = [ + ":signature_def_param_type", + ":tensor_spec", + ":tensor_spec_type", + "//tensorflow/c:c_api_macros", + "//tensorflow/c:tf_shape_internal", + "//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata", + ], +) + +cc_library( + name = "signature_def_param_type", + hdrs = [ + "signature_def_param_type.h", + ], + deps = [ + "//tensorflow/c:conversion_macros", + "//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata", + ], +) + +cc_library( + name = "signature_def_param_list", + srcs = [ + "signature_def_param_list.cc", + ], + hdrs = [ + "//tensorflow/c/experimental/saved_model/public:signature_def_param_list.h", + ], + copts = tf_copts(), + visibility = [ + "//tensorflow/c/experimental/saved_model/public:__pkg__", + ], + deps = [ + ":signature_def_param", + ":signature_def_param_list_type", + ":signature_def_param_type", + "//tensorflow/c:c_api_macros", + ], +) + +cc_library( + name = "signature_def_param_list_type", + hdrs = [ + "signature_def_param_list_type.h", + ], + deps = [ + "//tensorflow/c:conversion_macros", + "//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata", + ], +) + +cc_library( + name = "tensor_spec", + srcs = [ + "tensor_spec.cc", + ], + hdrs = [ + "//tensorflow/c/experimental/saved_model/public:tensor_spec.h", + ], + copts = tf_copts(), + visibility = [ + "//tensorflow/c/experimental/saved_model/public:__pkg__", + ], + deps = [ + ":tensor_spec_type", + "//tensorflow/c:c_api_macros", + "//tensorflow/c:tf_datatype", + "//tensorflow/c:tf_shape", + "//tensorflow/c:tf_shape_internal", + "//tensorflow/c/experimental/saved_model/core:tensor_spec", + ], +) + +cc_library( + name = "tensor_spec_type", + hdrs = [ + "tensor_spec_type.h", + ], + deps = [ + "//tensorflow/c:conversion_macros", + "//tensorflow/c:tf_shape_internal", + "//tensorflow/c/experimental/saved_model/core:tensor_spec", + ], +) + tf_cc_test( name = "saved_model_api_test", size = "small", @@ -245,16 +347,26 @@ tf_cc_test( "saved_model_api_test.cc", ], data = [ + "//tensorflow/c/experimental/saved_model/internal/testdata:saved_models", "//tensorflow/cc/saved_model:saved_model_half_plus_two", ], deps = [ + ":saved_model_api_type", + "//tensorflow/c:tf_datatype", + "//tensorflow/c:tf_shape", "//tensorflow/c:tf_status", "//tensorflow/c:tf_tensor", "//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api_experimental", "//tensorflow/c/eager:c_api_test_util", + "//tensorflow/c/experimental/saved_model/core:tf_saved_model_api", "//tensorflow/c/experimental/saved_model/public:concrete_function", "//tensorflow/c/experimental/saved_model/public:saved_model_api", + "//tensorflow/c/experimental/saved_model/public:signature_def_function", + "//tensorflow/c/experimental/saved_model/public:signature_def_function_metadata", + "//tensorflow/c/experimental/saved_model/public:signature_def_param", + "//tensorflow/c/experimental/saved_model/public:signature_def_param_list", + 
"//tensorflow/c/experimental/saved_model/public:tensor_spec", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/c/experimental/saved_model/internal/concrete_function.cc b/tensorflow/c/experimental/saved_model/internal/concrete_function.cc index 65c6eca5623..2beed8f4119 100644 --- a/tensorflow/c/experimental/saved_model/internal/concrete_function.cc +++ b/tensorflow/c/experimental/saved_model/internal/concrete_function.cc @@ -34,15 +34,15 @@ TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(TF_ConcreteFunction* func) { &tensorflow::unwrap(func)->GetFunctionMetadata())); } -TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func, - TFE_TensorHandle** inputs, int num_inputs, - TF_Status* status) { +TFE_Op* TF_ConcreteFunctionMakeCallOp(TF_ConcreteFunction* func, + TFE_TensorHandle** inputs, int num_inputs, + TF_Status* status) { tensorflow::ImmediateOpPtr call_op; absl::Span input_span( reinterpret_cast( tensorflow::unwrap(inputs)), static_cast(num_inputs)); - status->status = tensorflow::unwrap(func)->GetCallOp(input_span, &call_op); + status->status = tensorflow::unwrap(func)->MakeCallOp(input_span, &call_op); if (!status->status.ok()) { return nullptr; } diff --git a/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc b/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc index e58b232f9c9..5a4f676ec06 100644 --- a/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc +++ b/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc @@ -21,15 +21,28 @@ limitations under the License. #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h" +#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h" #include "tensorflow/c/experimental/saved_model/public/concrete_function.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_param.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_param_list.h" +#include "tensorflow/c/experimental/saved_model/public/tensor_spec.h" +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/c/tf_shape.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_tensor.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/tstring.h" namespace { +using tensorflow::tstring; + constexpr char kTestData[] = "cc/saved_model/testdata"; const char* kServeTag[] = {"serve"}; @@ -107,7 +120,7 @@ TEST_P(CSavedModelAPITest, LoadsSavedModel) { compute_fn_inputs.push_back(input_a); compute_fn_inputs.push_back(input_b); - TFE_Op* compute_fn_op = TF_ConcreteFunctionGetCallOp( + TFE_Op* compute_fn_op = TF_ConcreteFunctionMakeCallOp( compute_fn, compute_fn_inputs.data(), compute_fn_inputs.size(), status); EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); @@ -137,6 +150,380 @@ TEST_P(CSavedModelAPITest, LoadsSavedModel) { TFE_DeleteContext(ctx); } +// This tests running the "serving_default" SignatureDefFunction from the +// VarsAndArithmeticObjectGraph savedmodel. 
Here's what the signature_defs +// protobuf in the metagraph looks like: +// signature_def: { +// key : "serving_default" +// value: { +// inputs: { +// key : "a" +// value: { +// name : "serving_default_a:0" +// dtype: DT_FLOAT +// tensor_shape: { +// } +// } +// } +// inputs: { +// key : "b" +// value: { +// name : "serving_default_b:0" +// dtype: DT_FLOAT +// tensor_shape: { +// } +// } +// } +// outputs: { +// key : "output_0" +// value: { +// name : "StatefulPartitionedCall:0" +// dtype: DT_FLOAT +// tensor_shape: { +// } +// } +// } +// method_name: "tensorflow/serving/predict" +// } +// } +TEST_P(CSavedModelAPITest, RunsSignatureDefFunction) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + bool use_tfrt = GetParam(); + if (use_tfrt) { + TFE_DeleteContextOptions(opts); + TF_DeleteStatus(status); + GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced. + } + + TFE_ContextOptionsSetTfrt(opts, use_tfrt); + + TFE_Context* ctx = TFE_NewContext(opts, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph"); + + TF_SavedModel* saved_model = + TF_LoadSavedModel(model_dir.c_str(), ctx, status); + + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TF_SignatureDefFunction* serving_default = + TF_GetSavedModelSignatureDefFunction(saved_model, "serving_default", + status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + TF_SignatureDefFunctionMetadata* metadata = + TF_SignatureDefFunctionGetMetadata(serving_default); + + const TF_SignatureDefParamList* args = + TF_SignatureDefFunctionMetadataArgs(metadata); + const TF_SignatureDefParamList* returns = + TF_SignatureDefFunctionMetadataReturns(metadata); + + EXPECT_EQ(TF_SignatureDefParamListSize(args), 2); + const TF_SignatureDefParam* param_a = TF_SignatureDefParamListGet(args, 0); + const TF_TensorSpec* tensor_spec_a = TF_SignatureDefParamTensorSpec(param_a); + const TF_Shape* shape_a = TF_TensorSpecShape(tensor_spec_a); + + // Input "a" is a scalar, float32 tensor + EXPECT_EQ("a", std::string(TF_SignatureDefParamName(param_a))); + EXPECT_EQ(TF_FLOAT, TF_TensorSpecDataType(tensor_spec_a)); + EXPECT_EQ(0, TF_ShapeDims(shape_a)); + + const TF_SignatureDefParam* param_b = TF_SignatureDefParamListGet(args, 1); + const TF_TensorSpec* tensor_spec_b = TF_SignatureDefParamTensorSpec(param_b); + const TF_Shape* shape_b = TF_TensorSpecShape(tensor_spec_b); + + // Input "b" is a scalar, float32 tensor + EXPECT_EQ("b", std::string(TF_SignatureDefParamName(param_b))); + EXPECT_EQ(TF_FLOAT, TF_TensorSpecDataType(tensor_spec_b)); + EXPECT_EQ(0, TF_ShapeDims(shape_b)); + + EXPECT_EQ(TF_SignatureDefParamListSize(returns), 1); + + const TF_SignatureDefParam* param_out = + TF_SignatureDefParamListGet(returns, 0); + const TF_TensorSpec* tensor_spec_out = + TF_SignatureDefParamTensorSpec(param_out); + const TF_Shape* shape_out = TF_TensorSpecShape(tensor_spec_out); + + // Output "output_0" is a scalar, float32 tensor + EXPECT_EQ("output_0", std::string(TF_SignatureDefParamName(param_out))); + EXPECT_EQ(TF_FLOAT, TF_TensorSpecDataType(tensor_spec_out)); + EXPECT_EQ(0, TF_ShapeDims(shape_out)); + + std::vector compute_fn_inputs; + TFE_TensorHandle* input_a = TestScalarTensorHandle(ctx, 2.0f); + TFE_TensorHandle* input_b = TestScalarTensorHandle(ctx, 1.0f); + compute_fn_inputs.push_back(input_a); + compute_fn_inputs.push_back(input_b); + + TFE_Op* 
serving_default_op = TF_SignatureDefFunctionMakeCallOp( + serving_default, compute_fn_inputs.data(), compute_fn_inputs.size(), + status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + std::vector compute_fn_outputs( + TF_SignatureDefParamListSize(returns)); + int num_retvals = TF_SignatureDefParamListSize(returns); + + TFE_Execute(serving_default_op, compute_fn_outputs.data(), &num_retvals, + status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + TF_Tensor* result = TFE_TensorHandleResolve(compute_fn_outputs[0], status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + EXPECT_EQ(TF_NumDims(result), 0); + float output_value = *static_cast(TF_TensorData(result)); + // (1 + 2) * (2 + 1) / 3 + 5 should be 8 + EXPECT_FLOAT_EQ(output_value, 8.0); + + TF_DeleteTensor(result); + TFE_DeleteTensorHandle(compute_fn_outputs[0]); + TFE_DeleteTensorHandle(input_a); + TFE_DeleteTensorHandle(input_b); + TFE_DeleteOp(serving_default_op); + TF_DeleteSavedModel(saved_model); + TF_DeleteStatus(status); + TFE_DeleteContext(ctx); +} + +TEST_P(CSavedModelAPITest, LoadsAssetSavedModel) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + bool use_tfrt = GetParam(); + if (use_tfrt) { + TFE_DeleteContextOptions(opts); + TF_DeleteStatus(status); + GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced. + } + + TFE_ContextOptionsSetTfrt(opts, use_tfrt); + + TFE_Context* ctx = TFE_NewContext(opts, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + std::string model_dir = SavedModelPath("AssetModule"); + + TF_SavedModel* saved_model = + TF_LoadSavedModel(model_dir.c_str(), ctx, status); + + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TF_ConcreteFunction* read_file_fn = + TF_GetSavedModelConcreteFunction(saved_model, "read_file", status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + TFE_Op* read_file_op = + TF_ConcreteFunctionMakeCallOp(read_file_fn, nullptr, 0, status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + // TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many + // inputs + outputs a function has. + TFE_TensorHandle* read_file_fn_outputs[1] = {nullptr}; + int num_retvals = 1; + + TFE_Execute(read_file_op, &read_file_fn_outputs[0], &num_retvals, status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + TF_Tensor* result = TFE_TensorHandleResolve(read_file_fn_outputs[0], status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + EXPECT_EQ(TF_NumDims(result), 0); + tensorflow::tstring* output_value = + static_cast(TF_TensorData(result)); + std::string file_contents(*output_value); + EXPECT_NE(file_contents.find("TEST ASSET FILE CONTENTS"), std::string::npos); + + TF_DeleteTensor(result); + TFE_DeleteTensorHandle(read_file_fn_outputs[0]); + TFE_DeleteOp(read_file_op); + TF_DeleteSavedModel(saved_model); + TF_DeleteStatus(status); + TFE_DeleteContext(ctx); +} + +TEST_P(CSavedModelAPITest, LoadsStaticHashtableSavedModel) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + bool use_tfrt = GetParam(); + if (use_tfrt) { + TFE_DeleteContextOptions(opts); + TF_DeleteStatus(status); + GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced. 
+ } + + TFE_ContextOptionsSetTfrt(opts, use_tfrt); + + TFE_Context* ctx = TFE_NewContext(opts, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + std::string model_dir = SavedModelPath("StaticHashTableModule"); + + TF_SavedModel* saved_model = + TF_LoadSavedModel(model_dir.c_str(), ctx, status); + + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TF_ConcreteFunction* lookup_fn = + TF_GetSavedModelConcreteFunction(saved_model, "lookup", status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + // Note(bmzhao): Based on static_hashtable_asset.txt, we expect the following + // mapping: + // "foo" -> 0 + // "bar" -> 1 + // "baz" -> 2 + // "wombat" -> 3 + // all other strings -> -1 + + // Call lookup function with input "foo", expecting an output of 0 + { + std::vector lookup_fn_inputs; + TFE_TensorHandle* input_foo = TestScalarTensorHandle(ctx, tstring("foo")); + lookup_fn_inputs.push_back(input_foo); + + TFE_Op* lookup_op = TF_ConcreteFunctionMakeCallOp( + lookup_fn, lookup_fn_inputs.data(), lookup_fn_inputs.size(), status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + // TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many + // inputs + outputs a function has. + TFE_TensorHandle* lookup_fn_outputs[1] = {nullptr}; + int num_retvals = 1; + + TFE_Execute(lookup_op, &lookup_fn_outputs[0], &num_retvals, status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + TF_Tensor* result = TFE_TensorHandleResolve(lookup_fn_outputs[0], status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + EXPECT_EQ(TF_NumDims(result), 0); + tensorflow::int64* output_value = + static_cast(TF_TensorData(result)); + EXPECT_EQ(*output_value, 0); + + TF_DeleteTensor(result); + TFE_DeleteTensorHandle(input_foo); + TFE_DeleteTensorHandle(lookup_fn_outputs[0]); + TFE_DeleteOp(lookup_op); + } + + // Call lookup function with input "baz", expecting an output of 2 + { + std::vector lookup_fn_inputs; + TFE_TensorHandle* input_foo = TestScalarTensorHandle(ctx, tstring("baz")); + lookup_fn_inputs.push_back(input_foo); + + TFE_Op* lookup_op = TF_ConcreteFunctionMakeCallOp( + lookup_fn, lookup_fn_inputs.data(), lookup_fn_inputs.size(), status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + // TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many + // inputs + outputs a function has. 
+ TFE_TensorHandle* lookup_fn_outputs[1] = {nullptr}; + int num_retvals = 1; + + TFE_Execute(lookup_op, &lookup_fn_outputs[0], &num_retvals, status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + TF_Tensor* result = TFE_TensorHandleResolve(lookup_fn_outputs[0], status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + EXPECT_EQ(TF_NumDims(result), 0); + tensorflow::int64* output_value = + static_cast(TF_TensorData(result)); + EXPECT_EQ(*output_value, 2); + + TF_DeleteTensor(result); + TFE_DeleteTensorHandle(input_foo); + TFE_DeleteTensorHandle(lookup_fn_outputs[0]); + TFE_DeleteOp(lookup_op); + } + + // Call lookup function w/input "NON-EXISTENT-KEY", expecting an output of -1 + { + std::vector lookup_fn_inputs; + TFE_TensorHandle* input_foo = + TestScalarTensorHandle(ctx, tstring("NON-EXISTENT-KEY")); + lookup_fn_inputs.push_back(input_foo); + + TFE_Op* lookup_op = TF_ConcreteFunctionMakeCallOp( + lookup_fn, lookup_fn_inputs.data(), lookup_fn_inputs.size(), status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + // TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many + // inputs + outputs a function has. + TFE_TensorHandle* lookup_fn_outputs[1] = {nullptr}; + int num_retvals = 1; + + TFE_Execute(lookup_op, &lookup_fn_outputs[0], &num_retvals, status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + TF_Tensor* result = TFE_TensorHandleResolve(lookup_fn_outputs[0], status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + EXPECT_EQ(TF_NumDims(result), 0); + tensorflow::int64* output_value = + static_cast(TF_TensorData(result)); + EXPECT_EQ(*output_value, -1); + + TF_DeleteTensor(result); + TFE_DeleteTensorHandle(input_foo); + TFE_DeleteTensorHandle(lookup_fn_outputs[0]); + TFE_DeleteOp(lookup_op); + } + + TF_DeleteSavedModel(saved_model); + TF_DeleteStatus(status); + TFE_DeleteContext(ctx); +} + +TEST_P(CSavedModelAPITest, LoadSavedModelWithUninitializedVariable) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + bool use_tfrt = GetParam(); + if (use_tfrt) { + TFE_DeleteContextOptions(opts); + TF_DeleteStatus(status); + GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced. 
+ } + + TFE_ContextOptionsSetTfrt(opts, use_tfrt); + + TFE_Context* ctx = TFE_NewContext(opts, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + std::string model_dir = tensorflow::io::JoinPath( + tensorflow::testing::TensorFlowSrcRoot(), + "c/experimental/saved_model/internal/testdata/UninitializedVariable"); + + TF_SavedModel* saved_model = + TF_LoadSavedModel(model_dir.c_str(), ctx, status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + tensorflow::TFSavedModelAPI* model_api = + tensorflow::down_cast( + tensorflow::unwrap(saved_model)); + tensorflow::Variable* uninitialized_variable; + ASSERT_EQ(tensorflow::Status::OK(), + model_api->GetVariable("uninitialized_variable", + &uninitialized_variable)); + ASSERT_EQ(tensorflow::DT_FLOAT, uninitialized_variable->dtype()); + + ASSERT_EQ(tensorflow::Status::OK(), + model_api->GetVariable("sub_module.uninitialized_variable", + &uninitialized_variable)); + ASSERT_EQ(tensorflow::DT_INT64, uninitialized_variable->dtype()); + + TF_DeleteSavedModel(saved_model); + TF_DeleteStatus(status); + TFE_DeleteContext(ctx); +} + INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticSavedModelTests, CSavedModelAPITest, ::testing::Bool()); diff --git a/tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata.cc b/tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata.cc index c5c3616211c..1c547a94155 100644 --- a/tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata.cc +++ b/tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata.cc @@ -16,5 +16,18 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h" #include "tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata_type.h" +#include "tensorflow/c/experimental/saved_model/internal/signature_def_param_list_type.h" -// TODO(bmzhao): Add getter functions here as necessary. +extern "C" { + +extern const TF_SignatureDefParamList* TF_SignatureDefFunctionMetadataArgs( + const TF_SignatureDefFunctionMetadata* list) { + return tensorflow::wrap(&tensorflow::unwrap(list)->arguments()); +} + +extern const TF_SignatureDefParamList* TF_SignatureDefFunctionMetadataReturns( + const TF_SignatureDefFunctionMetadata* list) { + return tensorflow::wrap(&tensorflow::unwrap(list)->returns()); +} + +} // end extern "C" diff --git a/tensorflow/c/experimental/saved_model/internal/signature_def_param.cc b/tensorflow/c/experimental/saved_model/internal/signature_def_param.cc new file mode 100644 index 00000000000..ac54f8f5700 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/signature_def_param.cc @@ -0,0 +1,33 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/public/signature_def_param.h" + +#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h" +#include "tensorflow/c/experimental/saved_model/internal/signature_def_param_type.h" +#include "tensorflow/c/experimental/saved_model/internal/tensor_spec_type.h" + +extern "C" { + +extern const char* TF_SignatureDefParamName(const TF_SignatureDefParam* param) { + return tensorflow::unwrap(param)->name().c_str(); +} + +extern const TF_TensorSpec* TF_SignatureDefParamTensorSpec( + const TF_SignatureDefParam* param) { + return tensorflow::wrap(&tensorflow::unwrap(param)->spec()); +} + +} // end extern "C" diff --git a/tensorflow/c/experimental/saved_model/internal/signature_def_param_list.cc b/tensorflow/c/experimental/saved_model/internal/signature_def_param_list.cc new file mode 100644 index 00000000000..328f21635c3 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/signature_def_param_list.cc @@ -0,0 +1,33 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/public/signature_def_param_list.h" + +#include "tensorflow/c/experimental/saved_model/internal/signature_def_param_list_type.h" +#include "tensorflow/c/experimental/saved_model/internal/signature_def_param_type.h" + +extern "C" { + +extern size_t TF_SignatureDefParamListSize( + const TF_SignatureDefParamList* list) { + return tensorflow::unwrap(list)->size(); +} + +extern const TF_SignatureDefParam* TF_SignatureDefParamListGet( + const TF_SignatureDefParamList* list, int i) { + return tensorflow::wrap(&tensorflow::unwrap(list)->at(i)); +} + +} // end extern "C" diff --git a/tensorflow/c/experimental/saved_model/internal/signature_def_param_list_type.h b/tensorflow/c/experimental/saved_model/internal/signature_def_param_list_type.h new file mode 100644 index 00000000000..6f535110cee --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/signature_def_param_list_type.h @@ -0,0 +1,33 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_LIST_TYPE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_LIST_TYPE_H_ + +#include + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h" + +typedef struct TF_SignatureDefParamList TF_SignatureDefParamList; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(std::vector, + TF_SignatureDefParamList) + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_LIST_TYPE_H_ diff --git a/tensorflow/c/experimental/saved_model/internal/signature_def_param_type.h b/tensorflow/c/experimental/saved_model/internal/signature_def_param_type.h new file mode 100644 index 00000000000..fd634bcddb0 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/signature_def_param_type.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_TYPE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_TYPE_H_ + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h" + +typedef struct TF_SignatureDefParam TF_SignatureDefParam; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(tensorflow::SignatureDefParam, TF_SignatureDefParam) + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_TYPE_H_ diff --git a/tensorflow/c/experimental/saved_model/internal/tensor_spec.cc b/tensorflow/c/experimental/saved_model/internal/tensor_spec.cc new file mode 100644 index 00000000000..f310adef449 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/tensor_spec.cc @@ -0,0 +1,32 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/public/tensor_spec.h" + +#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h" +#include "tensorflow/c/experimental/saved_model/internal/tensor_spec_type.h" +#include "tensorflow/c/tf_shape_internal.h" + +extern "C" { + +TF_DataType TF_TensorSpecDataType(const TF_TensorSpec* spec) { + return static_cast(tensorflow::unwrap(spec)->dtype()); +} + +const TF_Shape* TF_TensorSpecShape(const TF_TensorSpec* spec) { + return tensorflow::wrap(&tensorflow::unwrap(spec)->shape()); +} + +} // end extern "C" diff --git a/tensorflow/c/experimental/saved_model/internal/tensor_spec_type.h b/tensorflow/c/experimental/saved_model/internal/tensor_spec_type.h new file mode 100644 index 00000000000..7284c8a8fb2 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/tensor_spec_type.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_TENSOR_SPEC_TYPE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_TENSOR_SPEC_TYPE_H_ + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h" + +typedef struct TF_TensorSpec TF_TensorSpec; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(tensorflow::TensorSpec, TF_TensorSpec) + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_TENSOR_SPEC_TYPE_H_ diff --git a/tensorflow/c/experimental/saved_model/internal/testdata/BUILD b/tensorflow/c/experimental/saved_model/internal/testdata/BUILD new file mode 100644 index 00000000000..f446401ae77 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/testdata/BUILD @@ -0,0 +1,37 @@ +load("//tensorflow:tensorflow.bzl", "filegroup") +load("//tensorflow:tensorflow.bzl", "py_strict_binary") + +package( + licenses = ["notice"], # Apache 2.0 +) + +# Run this binary manually, with an argument pointing to the testdata/ +# directory, to generate the test files used by the filegroup rule below. +py_strict_binary( + name = "gen_saved_models", + srcs = ["gen_saved_models.py"], + python_version = "PY3", + deps = [ + "//tensorflow/python:dtypes", + "//tensorflow/python:platform", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:tensor_spec", + "//tensorflow/python:variables", + "//tensorflow/python/compat:v2_compat", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/module", + "//tensorflow/python/saved_model", + "//tensorflow/python/saved_model:save_options", + ], +) + +# Files generated by the binary above. 
+filegroup( + name = "saved_models", + srcs = glob([ + "UninitializedVariable/**", + ]), + visibility = [ + "//tensorflow/c/experimental/saved_model/internal:__pkg__", + ], +) diff --git a/tensorflow/c/experimental/saved_model/internal/testdata/UninitializedVariable/saved_model.pb b/tensorflow/c/experimental/saved_model/internal/testdata/UninitializedVariable/saved_model.pb new file mode 100644 index 00000000000..81ce8fe662b Binary files /dev/null and b/tensorflow/c/experimental/saved_model/internal/testdata/UninitializedVariable/saved_model.pb differ diff --git a/tensorflow/c/experimental/saved_model/internal/testdata/UninitializedVariable/variables/variables.data-00000-of-00001 b/tensorflow/c/experimental/saved_model/internal/testdata/UninitializedVariable/variables/variables.data-00000-of-00001 new file mode 100644 index 00000000000..b68ed0f5a6e Binary files /dev/null and b/tensorflow/c/experimental/saved_model/internal/testdata/UninitializedVariable/variables/variables.data-00000-of-00001 differ diff --git a/tensorflow/c/experimental/saved_model/internal/testdata/UninitializedVariable/variables/variables.index b/tensorflow/c/experimental/saved_model/internal/testdata/UninitializedVariable/variables/variables.index new file mode 100644 index 00000000000..ed07d0514c7 Binary files /dev/null and b/tensorflow/c/experimental/saved_model/internal/testdata/UninitializedVariable/variables/variables.index differ diff --git a/tensorflow/c/experimental/saved_model/internal/testdata/gen_saved_models.py b/tensorflow/c/experimental/saved_model/internal/testdata/gen_saved_models.py new file mode 100644 index 00000000000..f2a8bd5a9a4 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/testdata/gen_saved_models.py @@ -0,0 +1,84 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Lint as: python3 +"""Creates saved models used for testing. + +This executable should be run with an argument pointing to the testdata/ folder +in this directory. It will re-generate the saved models that are used for +testing. 
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import google_type_annotations +from __future__ import print_function + +import os + +from tensorflow.python.compat import v2_compat + +from tensorflow.python.eager import def_function +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_spec +from tensorflow.python.module import module +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import app +from tensorflow.python.saved_model import saved_model + + +def _gen_uninitialized_variable(base_dir): + """Generates a saved model with an uninitialized variable.""" + + class SubModule(module.Module): + """A module with an UninitializedVariable.""" + + def __init__(self): + self.uninitialized_variable = resource_variable_ops.UninitializedVariable( + name="uninitialized_variable", dtype=dtypes.int64) + + class Module(module.Module): + """A module with an UninitializedVariable.""" + + def __init__(self): + super(Module, self).__init__() + self.sub_module = SubModule() + self.initialized_variable = variables.Variable( + 1.0, name="initialized_variable") + # An UninitializedVariable with the same name as the variable in the + # SubModule, but with a different type. + self.uninitialized_variable = resource_variable_ops.UninitializedVariable( + name="uninitialized_variable", dtype=dtypes.float32) + + @def_function.function( + input_signature=[tensor_spec.TensorSpec((), dtypes.float32)]) + def compute(self, value): + return self.initialized_variable + value + + to_save = Module() + saved_model.save( + to_save, export_dir=os.path.join(base_dir, "UninitializedVariable")) + + +def main(args): + if len(args) != 2: + raise app.UsageError("Expected one argument (base_dir).") + _, base_dir = args + _gen_uninitialized_variable(base_dir) + + +if __name__ == "__main__": + v2_compat.enable_v2_behavior() + app.run(main) diff --git a/tensorflow/c/experimental/saved_model/public/BUILD b/tensorflow/c/experimental/saved_model/public/BUILD index d29585ae1ba..4198b0e7ee7 100644 --- a/tensorflow/c/experimental/saved_model/public/BUILD +++ b/tensorflow/c/experimental/saved_model/public/BUILD @@ -8,6 +8,8 @@ # programmatic checks that all "public" headers only include other "public" # headers. 
+load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + package( # This is intentionally public default_visibility = [ @@ -26,6 +28,9 @@ exports_files( "saved_model_api.h", "signature_def_function.h", "signature_def_function_metadata.h", + "signature_def_param.h", + "signature_def_param_list.h", + "tensor_spec.h", ], visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"], ) @@ -43,6 +48,9 @@ cc_library( ":saved_model_api", ":signature_def_function", ":signature_def_function_metadata", + ":signature_def_param", + ":signature_def_param_list", + ":tensor_spec", ], ) @@ -75,3 +83,18 @@ alias( name = "signature_def_function_metadata", actual = "//tensorflow/c/experimental/saved_model/internal:signature_def_function_metadata", ) + +alias( + name = "signature_def_param", + actual = "//tensorflow/c/experimental/saved_model/internal:signature_def_param", +) + +alias( + name = "signature_def_param_list", + actual = "//tensorflow/c/experimental/saved_model/internal:signature_def_param_list", +) + +alias( + name = "tensor_spec", + actual = "//tensorflow/c/experimental/saved_model/internal:tensor_spec", +) diff --git a/tensorflow/c/experimental/saved_model/public/README.md b/tensorflow/c/experimental/saved_model/public/README.md new file mode 100644 index 00000000000..9b3f392d7a8 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/public/README.md @@ -0,0 +1,28 @@ +# TensorFlow Saved Model C API + +## Small ConcreteFunction Example + +The following example loads a saved model from `"/path/to/model"` and +executes a function `f` taking no arguments and returning one single +value (error checking is omitted for simplicity): + +```c +TF_Status* status = TF_NewStatus(); +TFE_ContextOptions* ctx_options = TFE_NewContextOptions(); +TFE_Context* ctx = TFE_NewContext(ctx_options, status); + +TF_SavedModel* saved_model = TF_LoadSavedModel("/path/to/model", ctx, status); +TF_ConcreteFunction* f = TF_GetSavedModelConcreteFunction(saved_model, "f", status); +TFE_Op* op = TF_ConcreteFunctionMakeCallOp(f, NULL, 0, status); + +TFE_TensorHandle* output; +int nouts = 1; +TFE_Execute(op, &output, &nouts, status); + +TFE_DeleteTensorHandle(output); +TFE_DeleteOp(op); +TFE_DeleteSavedModel(saved_model); +TFE_DeleteContext(ctx); +TFE_DeleteContextOptions(ctx_options); +TF_DeleteStatus(status); +``` diff --git a/tensorflow/c/experimental/saved_model/public/c_saved_model_api.h b/tensorflow/c/experimental/saved_model/public/c_saved_model_api.h index cedb9de66b8..68f1ece2991 100644 --- a/tensorflow/c/experimental/saved_model/public/c_saved_model_api.h +++ b/tensorflow/c/experimental/saved_model/public/c_saved_model_api.h @@ -23,6 +23,9 @@ limitations under the License. 
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h" #include "tensorflow/c/experimental/saved_model/public/signature_def_function.h" #include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_param.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_param_list.h" +#include "tensorflow/c/experimental/saved_model/public/tensor_spec.h" // IWYU pragma: end_exports #endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_ diff --git a/tensorflow/c/experimental/saved_model/public/concrete_function.h b/tensorflow/c/experimental/saved_model/public/concrete_function.h index 0fd0f70cf16..ff8a245961a 100644 --- a/tensorflow/c/experimental/saved_model/public/concrete_function.h +++ b/tensorflow/c/experimental/saved_model/public/concrete_function.h @@ -47,7 +47,7 @@ TF_CAPI_EXPORT extern TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata( // high-level API here. A strawman for what this interface could look like: // TF_Value* TF_ExecuteFunction(TFE_Context*, TF_ConcreteFunction*, TF_Value* // inputs, int num_inputs, TF_Status* status); -TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionGetCallOp( +TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionMakeCallOp( TF_ConcreteFunction* func, TFE_TensorHandle** inputs, int num_inputs, TF_Status* status); diff --git a/tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h b/tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h index 6f4459732c4..b7a7f67eb19 100644 --- a/tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h +++ b/tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_ #define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_ +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_param_list.h" + #ifdef __cplusplus extern "C" { #endif // __cplusplus @@ -24,6 +27,18 @@ extern "C" { // SavedModel. typedef struct TF_SignatureDefFunctionMetadata TF_SignatureDefFunctionMetadata; +// Retrieves the arguments of the SignatureDefFunction. The caller is not +// responsible for freeing the returned pointer. +TF_CAPI_EXPORT extern const TF_SignatureDefParamList* +TF_SignatureDefFunctionMetadataArgs( + const TF_SignatureDefFunctionMetadata* list); + +// Retrieves the returns of the SignatureDefFunction. The caller is not +// responsible for freeing the returned pointer. +TF_CAPI_EXPORT extern const TF_SignatureDefParamList* +TF_SignatureDefFunctionMetadataReturns( + const TF_SignatureDefFunctionMetadata* list); + #ifdef __cplusplus } // end extern "C" #endif // __cplusplus diff --git a/tensorflow/c/experimental/saved_model/public/signature_def_param.h b/tensorflow/c/experimental/saved_model/public/signature_def_param.h new file mode 100644 index 00000000000..82993d7fedf --- /dev/null +++ b/tensorflow/c/experimental/saved_model/public/signature_def_param.h @@ -0,0 +1,44 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_H_ + +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/experimental/saved_model/public/tensor_spec.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// An opaque type that contains metadata of an input/output of a +// TF_SignatureDefFunction loaded from a SavedModel. +typedef struct TF_SignatureDefParam TF_SignatureDefParam; + +// Returns the name of the given parameter. The caller is not responsible for +// freeing the returned char*. +TF_CAPI_EXPORT extern const char* TF_SignatureDefParamName( + const TF_SignatureDefParam* param); + +// Returns the TensorSpec associated with the given parameter. The caller is +// not responsible for freeing the returned TF_TensorSpec*. +TF_CAPI_EXPORT extern const TF_TensorSpec* TF_SignatureDefParamTensorSpec( + const TF_SignatureDefParam* param); + +#ifdef __cplusplus +} // end extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_H_ diff --git a/tensorflow/c/experimental/saved_model/public/signature_def_param_list.h b/tensorflow/c/experimental/saved_model/public/signature_def_param_list.h new file mode 100644 index 00000000000..0cb3a0d6d33 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/public/signature_def_param_list.h @@ -0,0 +1,44 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_LIST_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_LIST_H_ + +#include <stddef.h> + +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_param.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// An opaque type that contains metadata of an input/output of a +// ConcreteFunction loaded from a SavedModel. +typedef struct TF_SignatureDefParamList TF_SignatureDefParamList; + +// Returns the size of `list`. +TF_CAPI_EXPORT extern size_t TF_SignatureDefParamListSize( + const TF_SignatureDefParamList* list); + +// Returns the `i`th TF_SignatureDefParam in the list.
+TF_CAPI_EXPORT extern const TF_SignatureDefParam* TF_SignatureDefParamListGet( + const TF_SignatureDefParamList* list, int i); + +#ifdef __cplusplus +} // end extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_LIST_H_ diff --git a/tensorflow/c/experimental/saved_model/public/tensor_spec.h b/tensorflow/c/experimental/saved_model/public/tensor_spec.h new file mode 100644 index 00000000000..82972ef74ef --- /dev/null +++ b/tensorflow/c/experimental/saved_model/public/tensor_spec.h @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSOR_SPEC_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSOR_SPEC_H_ + +#include + +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/c/tf_shape.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// An opaque type corresponding to TensorSpec +typedef struct TF_TensorSpec TF_TensorSpec; + +// Returns the dtype associated with the TensorSpec. +TF_CAPI_EXPORT extern TF_DataType TF_TensorSpecDataType( + const TF_TensorSpec* spec); + +// Returns the shape associated with the TensorSpec. The returned Shape is not +// owned by the caller. Caller must not call TF_DeleteShape on the returned +// shape. +TF_CAPI_EXPORT extern const TF_Shape* TF_TensorSpecShape( + const TF_TensorSpec* spec); + +#ifdef __cplusplus +} // end extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSOR_SPEC_H_ diff --git a/tensorflow/c/experimental/stream_executor/BUILD b/tensorflow/c/experimental/stream_executor/BUILD index 7daa311d461..214313c960a 100644 --- a/tensorflow/c/experimental/stream_executor/BUILD +++ b/tensorflow/c/experimental/stream_executor/BUILD @@ -1,6 +1,7 @@ # Description: # StreamExecutor C API. 
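Taken together, the new `signature_def_param.h`, `signature_def_param_list.h`, and `tensor_spec.h` headers above let a caller walk a function's signature metadata. A minimal sketch, assuming the `TF_SignatureDefFunctionMetadata*` was obtained from a loaded function elsewhere; the helper name is hypothetical:

```c
#include <stddef.h>
#include <stdio.h>

#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_param.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_param_list.h"
#include "tensorflow/c/experimental/saved_model/public/tensor_spec.h"

/* Hypothetical helper: prints the name and dtype of every parameter in `list`. */
static void PrintSignatureDefParams(const TF_SignatureDefParamList* list) {
  for (size_t i = 0; i < TF_SignatureDefParamListSize(list); ++i) {
    const TF_SignatureDefParam* param = TF_SignatureDefParamListGet(list, (int)i);
    const TF_TensorSpec* spec = TF_SignatureDefParamTensorSpec(param);
    printf("%s: dtype=%d\n", TF_SignatureDefParamName(param),
           (int)TF_TensorSpecDataType(spec));
    /* TF_TensorSpecShape(spec) returns a TF_Shape* owned by the spec;
       do not call TF_DeleteShape on it. */
  }
}

/* Usage, given metadata for a loaded SignatureDefFunction:
     PrintSignatureDefParams(TF_SignatureDefFunctionMetadataArgs(metadata));
     PrintSignatureDefParams(TF_SignatureDefFunctionMetadataReturns(metadata)); */
```

The lists, params, and tensor specs are all owned by the metadata, matching the "caller is not responsible for freeing" contract stated in these headers.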
+load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", "tf_cc_test", @@ -10,17 +11,29 @@ package( licenses = ["notice"], # Apache 2.0 ) +cc_library( + name = "stream_executor_hdrs", + hdrs = ["stream_executor.h"], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/c:c_api_macros", + "//tensorflow/c:tf_status", + ], +) + cc_library( name = "stream_executor", srcs = ["stream_executor.cc"], hdrs = ["stream_executor.h"], - visibility = ["//visibility:public"], + visibility = ["//tensorflow:internal"], deps = [ ":stream_executor_internal", "//tensorflow/c:c_api_macros", "//tensorflow/c:tf_status", "//tensorflow/c:tf_status_helper", "//tensorflow/core:lib", + "//tensorflow/core/platform:regexp", + "//tensorflow/core/platform:strcat", "//tensorflow/stream_executor:executor_cache", "//tensorflow/stream_executor:multi_platform_manager", "//tensorflow/stream_executor:platform", diff --git a/tensorflow/c/experimental/stream_executor/stream_executor.cc b/tensorflow/c/experimental/stream_executor/stream_executor.cc index 0e55ba3d72a..ec2bada791e 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor.cc +++ b/tensorflow/c/experimental/stream_executor/stream_executor.cc @@ -28,10 +28,14 @@ limitations under the License. #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/strcat.h" +#include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/stream_executor/executor_cache.h" #include "tensorflow/stream_executor/multi_platform_manager.h" #include "tensorflow/stream_executor/platform.h" +#include "tensorflow/stream_executor/stream.h" #include "tensorflow/stream_executor/stream_executor_internal.h" #include "tensorflow/stream_executor/stream_executor_pimpl.h" #include "tensorflow/stream_executor/timer.h" @@ -39,6 +43,8 @@ limitations under the License. using tensorflow::StatusFromTF_Status; namespace stream_executor { +using tensorflow::StringPiece; + namespace { #define VALIDATE_STRUCT_SIZE(STRUCT_NAME, STRUCT_OBJ, SIZE_VALUE_NAME) \ @@ -58,17 +64,50 @@ namespace { } \ } while (0) +port::Status ValidateDeviceType(StringPiece type) { + // Validate device type. Device type must start with a capital letter and + // consist of capital letters and underscores. Reasoning behind this decision: + // * At the minimum we want to disallow '/' and ':' since + // these characters are used in device spec, for e.g. + // /job:foo/replica:12/device:GPU:1. + // * Underscores seem useful, for e.g. XLA_GPU uses underscores. + // * Allowing lowercase might get confusing. For example, say someone + // registers a new type called "Gpu". It might be confusing for users that + // "Gpu" is not the same device type as "GPU". 
+ // Note that lowercase "cpu" and "gpu" are currently supported only for + // legacy reasons: + // https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/python/framework/device_spec.py;l=46;drc=d3a378f9665d8eee827c74cb9ecbee81e4c288dd + static const LazyRE2 kTfDeviceTypeRegEx = {"[A-Z][A-Z_]*"}; + bool matches = RE2::FullMatch(type, *kTfDeviceTypeRegEx); + if (!matches) { + return port::FailedPreconditionError( + tensorflow::strings::StrCat("Device name/type '", type, "' must match ", + kTfDeviceTypeRegEx->pattern(), ".")); + } + return port::Status::OK(); +} + port::Status ValidateSPPlatform(const SP_Platform& platform) { VALIDATE_STRUCT_SIZE(SP_Platform, platform, SP_PLATFORM_STRUCT_SIZE); VALIDATE_MEMBER(SP_Platform, platform, name); VALIDATE_MEMBER(SP_Platform, platform, type); - VALIDATE_MEMBER(SP_Platform, platform, visible_device_count); - VALIDATE_MEMBER(SP_Platform, platform, create_device); - VALIDATE_MEMBER(SP_Platform, platform, destroy_device); - VALIDATE_MEMBER(SP_Platform, platform, create_stream_executor); - VALIDATE_MEMBER(SP_Platform, platform, destroy_stream_executor); - VALIDATE_MEMBER(SP_Platform, platform, create_timer_fns); - VALIDATE_MEMBER(SP_Platform, platform, destroy_timer_fns); + TF_RETURN_IF_ERROR(ValidateDeviceType(platform.name)); + TF_RETURN_IF_ERROR(ValidateDeviceType(platform.type)); + // `visible_device_count` could be 0 at initialization time. + return port::Status::OK(); +} + +port::Status ValidateSPPlatformFns(const SP_PlatformFns& platform_fns) { + VALIDATE_STRUCT_SIZE(SP_PlatformFns, platform_fns, + SP_PLATFORM_FNS_STRUCT_SIZE); + VALIDATE_MEMBER(SP_PlatformFns, platform_fns, create_device); + VALIDATE_MEMBER(SP_PlatformFns, platform_fns, destroy_device); + VALIDATE_MEMBER(SP_PlatformFns, platform_fns, create_stream_executor); + VALIDATE_MEMBER(SP_PlatformFns, platform_fns, destroy_stream_executor); + VALIDATE_MEMBER(SP_PlatformFns, platform_fns, create_timer_fns); + VALIDATE_MEMBER(SP_PlatformFns, platform_fns, destroy_timer_fns); + VALIDATE_MEMBER(SP_PlatformFns, platform_fns, create_device_fns); + VALIDATE_MEMBER(SP_PlatformFns, platform_fns, destroy_device_fns); return port::Status::OK(); } @@ -97,11 +136,24 @@ port::Status ValidateSPDevice(const SP_Device& device) { return port::Status::OK(); } -port::Status ValidateSPStreamExecutor(const SP_StreamExecutor& se) { +port::Status ValidateSPDeviceFns(const SP_DeviceFns& device_fns) { + VALIDATE_STRUCT_SIZE(SP_DeviceFns, device_fns, SP_DEVICE_FNS_STRUCT_SIZE); + // All other fields could theoretically be zero/null. 
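To make the `ValidateDeviceType` rule above concrete, a few example values (illustrative only, not taken from this change):

```c
/* Accepted by ValidateDeviceType's pattern [A-Z][A-Z_]*:
     "MY_DEVICE", "XLA_GPU", "GPU"
   Rejected, since lowercase letters, digits, ':' and '/' do not match:
     "MyDevice", "GPU:0", "/device:GPU:1" */
static const char kExamplePlatformName[] = "MY_DEVICE"; /* passes validation */
```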
+ return port::Status::OK(); +} + +port::Status ValidateSPStreamExecutor(const SP_StreamExecutor& se, + const SP_Platform& platform) { VALIDATE_STRUCT_SIZE(SP_StreamExecutor, se, SP_STREAM_EXECUTOR_STRUCT_SIZE); VALIDATE_MEMBER(SP_StreamExecutor, se, allocate); VALIDATE_MEMBER(SP_StreamExecutor, se, deallocate); VALIDATE_MEMBER(SP_StreamExecutor, se, get_allocator_stats); + VALIDATE_MEMBER(SP_StreamExecutor, se, host_memory_allocate); + VALIDATE_MEMBER(SP_StreamExecutor, se, host_memory_deallocate); + if (platform.supports_unified_memory) { + VALIDATE_MEMBER(SP_StreamExecutor, se, unified_memory_allocate); + VALIDATE_MEMBER(SP_StreamExecutor, se, unified_memory_deallocate); + } VALIDATE_MEMBER(SP_StreamExecutor, se, device_memory_usage); VALIDATE_MEMBER(SP_StreamExecutor, se, create_stream); VALIDATE_MEMBER(SP_StreamExecutor, se, destroy_stream); @@ -131,9 +183,9 @@ port::Status ValidateSEPlatformRegistrationParams( VALIDATE_STRUCT_SIZE(SE_PlatformRegistrationParams, params, SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE); VALIDATE_MEMBER(SE_PlatformRegistrationParams, params, destroy_platform); + VALIDATE_MEMBER(SE_PlatformRegistrationParams, params, destroy_platform_fns); return port::Status::OK(); } - #undef VALIDATE_MEMBER struct TFStatusDeleter { @@ -297,19 +349,23 @@ void HostCallbackTrampoline(void* ctx, TF_Status* status) { class CStreamExecutor : public internal::StreamExecutorInterface { public: - explicit CStreamExecutor(SP_Device device, - void (*destroy_device)(SP_Device* const device), + explicit CStreamExecutor(SP_Device device, SP_DeviceFns* device_fns, SP_StreamExecutor* stream_executor, + SP_Platform* platform, SP_PlatformFns* platform_fns, SP_TimerFns* timer_fns, const std::string& name, int visible_device_count) : device_(std::move(device)), - destroy_device_(destroy_device), + device_fns_(device_fns), stream_executor_(stream_executor), + platform_(platform), + platform_fns_(platform_fns), timer_fns_(timer_fns), platform_name_(name), visible_device_count_(visible_device_count) {} - ~CStreamExecutor() override { destroy_device_(&device_); } + ~CStreamExecutor() override { + platform_fns_->destroy_device(platform_, &device_); + } port::Status Init(int device_ordinal, DeviceOptions device_options) override { return port::Status::OK(); @@ -348,6 +404,16 @@ class CStreamExecutor : public internal::StreamExecutorInterface { bool HostMemoryRegister(void* mem, uint64 size) override { return false; } bool HostMemoryUnregister(void* mem) override { return false; } + void* UnifiedMemoryAllocate(uint64 size) override { + CHECK(stream_executor_->unified_memory_allocate); + return stream_executor_->unified_memory_allocate(&device_, size); + } + + void UnifiedMemoryDeallocate(void* mem) override { + CHECK(stream_executor_->unified_memory_deallocate); + stream_executor_->unified_memory_deallocate(&device_, mem); + } + absl::optional GetAllocatorStats() override { SP_AllocatorStats c_stats{SP_ALLOCATORSTATS_STRUCT_SIZE}; TF_Bool has_stats = @@ -597,11 +663,19 @@ class CStreamExecutor : public internal::StreamExecutorInterface { port::Status BlockHostUntilDone(Stream* stream) override { OwnedTFStatus c_status(TF_NewStatus()); + SP_Stream stream_handle = + static_cast(stream->implementation())->Handle(); + + // If `block_host_until_done` is set, use it. 
+ if (stream_executor_->block_host_until_done != nullptr) { + stream_executor_->block_host_until_done(&device_, stream_handle, + c_status.get()); + return StatusFromTF_Status(c_status.get()); + } + // Create and record an event and then wait for it. SP_Event event_handle; stream_executor_->create_event(&device_, &event_handle, c_status.get()); TF_RETURN_IF_ERROR(StatusFromTF_Status(c_status.get())); - SP_Stream stream_handle = - static_cast(stream->implementation())->Handle(); stream_executor_->record_event(&device_, stream_handle, event_handle, c_status.get()); port::Status s = StatusFromTF_Status(c_status.get()); @@ -644,9 +718,35 @@ class CStreamExecutor : public internal::StreamExecutorInterface { // Ownership is transferred to the caller. port::StatusOr> CreateDeviceDescription() const override { - // TODO(annarev): Figure out if we need to support more description fields. + OwnedTFStatus c_status(TF_NewStatus()); + internal::DeviceDescriptionBuilder builder; - builder.set_name(platform_name_); + if (device_.hardware_name != nullptr) { + builder.set_name(device_.hardware_name); + } + if (device_.device_vendor != nullptr) { + builder.set_device_vendor(device_.device_vendor); + } + if (device_.pci_bus_id != nullptr) { + builder.set_pci_bus_id(device_.pci_bus_id); + } + + if (device_fns_->get_numa_node != nullptr) { + int32_t numa_node = device_fns_->get_numa_node(&device_); + if (numa_node >= 0) { + builder.set_numa_node(numa_node); + } + } + + if (device_fns_->get_memory_bandwidth != nullptr) { + int64_t memory_bandwidth = device_fns_->get_memory_bandwidth(&device_); + if (memory_bandwidth >= 0) { + builder.set_memory_bandwidth(memory_bandwidth); + } + } + // TODO(annarev): Add gflops field in DeviceDescription and set it here. + // TODO(annarev): Perhaps add `supports_unified_memory` in + // DeviceDescription. 
return builder.Build(); } @@ -674,8 +774,10 @@ class CStreamExecutor : public internal::StreamExecutorInterface { private: SP_Device device_; - void (*destroy_device_)(SP_Device* const device); + SP_DeviceFns* device_fns_; SP_StreamExecutor* stream_executor_; + SP_Platform* platform_; + SP_PlatformFns* platform_fns_; SP_TimerFns* timer_fns_; std::string platform_name_; int visible_device_count_; @@ -684,18 +786,26 @@ class CStreamExecutor : public internal::StreamExecutorInterface { CPlatform::CPlatform(SP_Platform platform, void (*destroy_platform)(SP_Platform*), - SP_StreamExecutor stream_executor, SP_TimerFns timer_fns) + SP_PlatformFns platform_fns, + void (*destroy_platform_fns)(SP_PlatformFns*), + SP_DeviceFns device_fns, SP_StreamExecutor stream_executor, + SP_TimerFns timer_fns) : platform_(std::move(platform)), destroy_platform_(destroy_platform), + platform_fns_(std::move(platform_fns)), + destroy_platform_fns_(destroy_platform_fns), + device_fns_(std::move(device_fns)), stream_executor_(std::move(stream_executor)), timer_fns_(std::move(timer_fns)), name_(platform.name) {} CPlatform::~CPlatform() { executor_cache_.DestroyAllExecutors(); - platform_.destroy_stream_executor(&stream_executor_); - platform_.destroy_timer_fns(&timer_fns_); + platform_fns_.destroy_device_fns(&platform_, &device_fns_); + platform_fns_.destroy_stream_executor(&platform_, &stream_executor_); + platform_fns_.destroy_timer_fns(&platform_, &timer_fns_); destroy_platform_(&platform_); + destroy_platform_fns_(&platform_fns_); } port::StatusOr> @@ -735,48 +845,59 @@ port::StatusOr> CPlatform::GetUncachedExecutor( OwnedTFStatus c_status(TF_NewStatus()); // Create Device - platform_.create_device(&device_params, c_status.get()); + platform_fns_.create_device(&platform_, &device_params, c_status.get()); TF_RETURN_IF_ERROR(StatusFromTF_Status(c_status.get())); TF_RETURN_IF_ERROR(ValidateSPDevice(device)); auto executor = absl::make_unique( - std::move(device), platform_.destroy_device, &stream_executor_, - &timer_fns_, name_, platform_.visible_device_count); + std::move(device), &device_fns_, &stream_executor_, &platform_, + &platform_fns_, &timer_fns_, name_, platform_.visible_device_count); auto result = absl::make_unique(this, std::move(executor), config.ordinal); return result; } -port::Status RegisterDevicePlugin(const std::string& dso_path) { - // Step 1: Load plugin +port::Status InitStreamExecutorPlugin(void* dso_handle) { tensorflow::Env* env = tensorflow::Env::Default(); - void* dso_handle; - TF_RETURN_IF_ERROR(env->LoadDynamicLibrary(dso_path.c_str(), &dso_handle)); - // Step 2: Load symbol for `TF_InitPlugin` + // Step 1: Load symbol for `TF_InitPlugin` void* dso_symbol; TF_RETURN_IF_ERROR( env->GetSymbolFromLibrary(dso_handle, "SE_InitPlugin", &dso_symbol)); - // Step 3: Call `TF_InitPlugin` - auto init_fn = reinterpret_cast(dso_symbol); - return RegisterDevicePlugin(init_fn); + // Step 2: Call `TF_InitPlugin` + auto init_fn = reinterpret_cast(dso_symbol); + return InitStreamExecutorPlugin(init_fn); } -port::Status RegisterDevicePlugin(SEPluginInitFn init_fn) { +port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn) { SE_PlatformRegistrationParams params{ SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE}; SP_Platform platform{SP_PLATFORM_STRUCT_SIZE}; + SP_PlatformFns platform_fns{SP_PLATFORM_FNS_STRUCT_SIZE}; params.major_version = SE_MAJOR; params.minor_version = SE_MINOR; - params.revision_version = SE_REVISION; + params.patch_version = SE_PATCH; params.platform = &platform; + params.platform_fns 
= &platform_fns; OwnedTFStatus c_status(TF_NewStatus()); init_fn(¶ms, c_status.get()); TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get())); TF_RETURN_IF_ERROR(ValidateSEPlatformRegistrationParams(params)); TF_RETURN_IF_ERROR(ValidateSPPlatform(platform)); + TF_RETURN_IF_ERROR(ValidateSPPlatformFns(platform_fns)); + + // Fill SP_DeviceFns creation params + SE_CreateDeviceFnsParams device_fns_params{ + SE_CREATE_DEVICE_FNS_PARAMS_STRUCT_SIZE}; + SP_DeviceFns device_fns{SP_DEVICE_FNS_STRUCT_SIZE}; + device_fns_params.device_fns = &device_fns; + + // Create StreamExecutor + platform_fns.create_device_fns(&platform, &device_fns_params, c_status.get()); + TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get())); + TF_RETURN_IF_ERROR(ValidateSPDeviceFns(device_fns)); // Fill stream executor creation params SE_CreateStreamExecutorParams se_params{ @@ -785,21 +906,26 @@ port::Status RegisterDevicePlugin(SEPluginInitFn init_fn) { se_params.stream_executor = &se; // Create StreamExecutor - platform.create_stream_executor(&se_params, c_status.get()); + platform_fns.create_stream_executor(&platform, &se_params, c_status.get()); TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get())); - TF_RETURN_IF_ERROR(ValidateSPStreamExecutor(se)); + TF_RETURN_IF_ERROR(ValidateSPStreamExecutor(se, platform)); SP_TimerFns timer_fns{SP_TIMER_FNS_STRUCT_SIZE}; - platform.create_timer_fns(&timer_fns, c_status.get()); + platform_fns.create_timer_fns(&platform, &timer_fns, c_status.get()); + TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get())); + TF_RETURN_IF_ERROR(ValidateSPTimerFns(timer_fns)); + + platform_fns.create_timer_fns(&platform, &timer_fns, c_status.get()); TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get())); TF_RETURN_IF_ERROR(ValidateSPTimerFns(timer_fns)); // Register new platform std::string platform_name = std::string(platform.name); std::unique_ptr cplatform( - new stream_executor::CPlatform(std::move(platform), - params.destroy_platform, std::move(se), - std::move(timer_fns))); + new stream_executor::CPlatform( + std::move(platform), params.destroy_platform, std::move(platform_fns), + params.destroy_platform_fns, std::move(device_fns), std::move(se), + std::move(timer_fns))); SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform( std::move(cplatform))); diff --git a/tensorflow/c/experimental/stream_executor/stream_executor.h b/tensorflow/c/experimental/stream_executor/stream_executor.h index b3459a29ccc..bec77ef520b 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor.h +++ b/tensorflow/c/experimental/stream_executor/stream_executor.h @@ -52,10 +52,11 @@ limitations under the License. // params.device = &device; // // /* Plugin code below */ -// constexpr char DEVICE_NAME[] = "MyDevice"; +// constexpr char DEVICE_NAME[] = "MY_DEVICE"; // constexpr char DEVICE_TYPE[] = "GPU"; // -// void create_device(SE_CreateDeviceParams* params, TF_Status* status) { +// void create_device(const SP_Platform* platform, +// SE_CreateDeviceParams* params, TF_Status* status) { // // Custom actions based on TensorFlow's view of SP_Device. // OnTFDeviceView(params->device->struct_size); // params->device = { SP_DEVICE_STRUCT_SIZE }; @@ -64,7 +65,7 @@ limitations under the License. // ... // } // -// void destroy_device(SP_Device* device) { +// void destroy_device(const SP_Platform* platform, SP_Device* device) { // delete_my_device_handle(device->device_handle); // } // @@ -76,14 +77,14 @@ limitations under the License. 
// params->platform->name = DEVICE_NAME; // params->platform->type = DEVICE_TYPE; // params->platform->visible_device_count = 2; -// params->platform->create_device = create_device; -// params->platform->destroy_device = destroy_device; +// params->platform_fns->create_device = create_device; +// params->platform_fns->destroy_device = destroy_device; // ... // } #define SE_MAJOR 0 #define SE_MINOR 0 -#define SE_REVISION 1 +#define SE_PATCH 1 #ifdef __cplusplus extern "C" { @@ -147,7 +148,7 @@ typedef struct SP_DeviceMemoryBase { } SP_DeviceMemoryBase; #define SP_DEVICE_MEMORY_BASE_STRUCT_SIZE \ - TF_OFFSET_OF_END(SP_DeviceMemoryBase, size) + TF_OFFSET_OF_END(SP_DeviceMemoryBase, payload) typedef struct SP_Device { size_t struct_size; @@ -157,9 +158,30 @@ typedef struct SP_Device { // Device vendor can store handle to their device representation // here. void* device_handle; + + // [Optional] + // Device hardware name. Used for printing. + // Must be null-terminated. + const char* hardware_name; + + // [Optional] + // Device vendor name. Used for printing. + // Must be null-terminated. + const char* device_vendor; + + // [Optional] + // Returns the PCI bus identifier for this device, of the form + // [domain]:[bus]:[device].[function] + // where domain number is usually 0000. + // Example: 0000:00:02.1 + // For more information see: + // https://en.wikipedia.org/wiki/PCI_configuration_space + // https://www.oreilly.com/library/view/linux-device-drivers/0596005903/ch12.html + // Used for printing. Must be null-terminated. + const char* pci_bus_id; } SP_Device; -#define SP_DEVICE_STRUCT_SIZE TF_OFFSET_OF_END(SP_Device, device_handle) +#define SP_DEVICE_STRUCT_SIZE TF_OFFSET_OF_END(SP_Device, pci_bus_id) typedef struct SE_CreateDeviceParams { size_t struct_size; @@ -173,6 +195,42 @@ typedef struct SE_CreateDeviceParams { #define SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE \ TF_OFFSET_OF_END(SE_CreateDeviceParams, device) +typedef struct SP_DeviceFns { + size_t struct_size; + void* ext; // reserved for future use + + // [Optional] + // Returns the NUMA node associated with this device, for use in + // determining socket locality. If the NUMA node could not be determined, -1 + // is returned. + // Negative values are treated as "unset". + int32_t (*get_numa_node)(const SP_Device* device); + + // [Optional] + // Device's memory bandwidth in bytes/sec. (This is for reads/writes to/from + // the device's own memory, not for transfers between the host and device.) + // Negative values are treated as "unset". + int64_t (*get_memory_bandwidth)(const SP_Device* device); + + // [Optional] + // Estimate of average number of floating point operations per second for + // this device * 10e-9. + // Negative values are treated as "unset". + double (*get_gflops)(const SP_Device* device); +} SP_DeviceFns; + +#define SP_DEVICE_FNS_STRUCT_SIZE TF_OFFSET_OF_END(SP_DeviceFns, get_gflops) + +typedef struct SE_CreateDeviceFnsParams { + size_t struct_size; + void* ext; // reserved for future use + + SP_DeviceFns* device_fns; // output, to be filled by plugin +} SE_CreateDeviceFnsParams; + +#define SE_CREATE_DEVICE_FNS_PARAMS_STRUCT_SIZE \ + TF_OFFSET_OF_END(SE_CreateDeviceFnsParams, device_fns) + typedef struct SP_StreamExecutor { size_t struct_size; void* ext; // reserved for future use @@ -198,6 +256,17 @@ typedef struct SP_StreamExecutor { // Deallocates a region of host memory allocated by `host_memory_allocate`. 
void (*host_memory_deallocate)(const SP_Device* device, void* mem); + // Allocates unified memory space of the given size, if supported. Unified + // memory support should be added by setting `supports_unified_memory` field + // in `SP_Platform`. + void* (*unified_memory_allocate)(const SP_Device* device, uint64_t bytes); + + // Deallocates unified memory space previously allocated with + // `unified_memory_allocate`. Unified + // memory support should be added by setting `supports_unified_memory` field + // in `SP_Platform`. + void (*unified_memory_deallocate)(const SP_Device* device, void* location); + // Fills SP_AllocatorStats with allocator statistics, if it is available. // If it is not available, return false. TF_Bool (*get_allocator_stats)(const SP_Device* device, @@ -309,13 +378,23 @@ typedef struct SP_StreamExecutor { void (*block_host_for_event)(const SP_Device* device, SP_Event event, TF_Status* status); + // [Optional] + // Causes the host code to synchronously wait for operations entrained onto + // stream to complete. Effectively a join on the asynchronous device + // operations enqueued on the stream before this program point. + // If not set, then corresponding functionality will be implemented + // by registering an event on the `stream` and waiting for it using + // `block_host_for_event`. + void (*block_host_until_done)(const SP_Device* device, SP_Stream stream, + TF_Status* status); + // Synchronizes all activity occurring in the StreamExecutor's context (most // likely a whole device). void (*synchronize_all_activity)(const SP_Device* device, TF_Status* status); // Enqueues on a stream a user-specified function to be run on the host. // `callback_arg` should be passed as the first argument to `callback_fn`. - TF_Bool (*host_callback)(SP_Device* device, SP_Stream stream, + TF_Bool (*host_callback)(const SP_Device* device, SP_Stream stream, SE_StatusCallbackFn callback_fn, void* callback_arg); } SP_StreamExecutor; @@ -337,36 +416,70 @@ typedef struct SP_Platform { void* ext; // free-form data set by plugin - // Platform name. Must be null-terminated. + // Platform name (also referred to as subtype), for example MY_DEVICE. + // The name must start with a capital letter and consist of + // capital letters and underscores. + // Must be null-terminated. const char* name; // Device type name, for example GPU. Must be null-terminated. + // The name must start with a capital letter and consist of + // capital letters and underscores. const char* type; // Number of visible devices size_t visible_device_count; + // Whether this platform supports unified memory. + // Unified memory is a single memory address space accessible from any device. + TF_Bool supports_unified_memory; +} SP_Platform; + +#define SP_PLATFORM_STRUCT_SIZE \ + TF_OFFSET_OF_END(SP_Platform, supports_unified_memory) + +typedef struct SP_PlatformFns { + size_t struct_size; + + void* ext; // reserved for future use + // Callbacks for creating/destroying SP_Device. - void (*create_device)(SE_CreateDeviceParams* params, TF_Status* status); + void (*create_device)(const SP_Platform* platform, + SE_CreateDeviceParams* params, TF_Status* status); // Clean up fields inside SP_Device that were allocated // by the plugin. `device` itself should not be deleted here. - void (*destroy_device)(SP_Device* device); + void (*destroy_device)(const SP_Platform* platform, SP_Device* device); + + // Callbacks for creating/destroying SP_DeviceFns. 
+ void (*create_device_fns)(const SP_Platform* platform, + SE_CreateDeviceFnsParams* params, + TF_Status* status); + + // Clean up fields inside SP_DeviceFns that were allocated + // by the plugin. `device_fns` itself should not be deleted here. + void (*destroy_device_fns)(const SP_Platform* platform, + SP_DeviceFns* device_fns); // Callbacks for creating/destroying SP_StreamExecutor. - void (*create_stream_executor)(SE_CreateStreamExecutorParams* params, + void (*create_stream_executor)(const SP_Platform* platform, + SE_CreateStreamExecutorParams* params, TF_Status* status); // Clean up fields inside SP_StreamExecutor that were allocated // by the plugin. `stream_executor` itself should not be deleted here. - void (*destroy_stream_executor)(SP_StreamExecutor* stream_executor); + void (*destroy_stream_executor)(const SP_Platform* platform, + SP_StreamExecutor* stream_executor); // Callbacks for creating/destroying SP_TimerFns. - void (*create_timer_fns)(SP_TimerFns* timer, TF_Status* status); + void (*create_timer_fns)(const SP_Platform* platform, SP_TimerFns* timer, + TF_Status* status); - void (*destroy_timer_fns)(SP_TimerFns* timer_fns); -} SP_Platform; + void (*destroy_timer_fns)(const SP_Platform* platform, + SP_TimerFns* timer_fns); +} SP_PlatformFns; -#define SP_PLATFORM_STRUCT_SIZE TF_OFFSET_OF_END(SP_Platform, destroy_timer_fns) +#define SP_PLATFORM_FNS_STRUCT_SIZE \ + TF_OFFSET_OF_END(SP_PlatformFns, destroy_timer_fns) typedef struct SE_PlatformRegistrationParams { size_t struct_size; @@ -375,16 +488,19 @@ typedef struct SE_PlatformRegistrationParams { // StreamExecutor C API version. int32_t major_version; int32_t minor_version; - int32_t revision_version; + int32_t patch_version; - SP_Platform* platform; // output, set by plugin + SP_Platform* platform; // output, set by plugin + SP_PlatformFns* platform_fns; // output, set by plugin // Clean up fields inside SP_Platform that were allocated // by the plugin. `platform` itself should not be deleted here. void (*destroy_platform)(SP_Platform* platform); // out, set by plugin + void (*destroy_platform_fns)( + SP_PlatformFns* platform_fns); // out, set by plugin } SE_PlatformRegistrationParams; #define SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE \ - TF_OFFSET_OF_END(SE_PlatformRegistrationParams, destroy_platform) + TF_OFFSET_OF_END(SE_PlatformRegistrationParams, destroy_platform_fns) void SE_InitPlugin(SE_PlatformRegistrationParams* params, TF_Status* status); diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h index 2285fe85867..52ae4ba77e0 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h +++ b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h @@ -27,20 +27,24 @@ namespace stream_executor { // Plugin initialization function that a device plugin // must define. -typedef void (*SEPluginInitFn)(SE_PlatformRegistrationParams* const, +typedef void (*SEInitPluginFn)(SE_PlatformRegistrationParams* const, TF_Status* const); -// Loads dso and registers StreamExecutor-based pluggable device. -port::Status RegisterDevicePlugin(const std::string& dso_path); +// Registers StreamExecutor platform. +port::Status InitStreamExecutorPlugin(void* dso_handle); -// Allow registering a plugin using a function (used for testing). -port::Status RegisterDevicePlugin(SEPluginInitFn init_fn); +// Allow registering a StreamExecutor plugin using a function (used for +// testing). 
+port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn); class CPlatform : public Platform { public: explicit CPlatform(SP_Platform platform, void (*destroy_platform)(SP_Platform*), - SP_StreamExecutor stream_executor, SP_TimerFns timer_fns); + SP_PlatformFns platform_fns, + void (*destroy_platform_fns)(SP_PlatformFns*), + SP_DeviceFns device_fns, SP_StreamExecutor stream_executor, + SP_TimerFns timer_fns); ~CPlatform() override; Id id() const override { return const_cast(&plugin_id_value_); } @@ -69,6 +73,9 @@ class CPlatform : public Platform { private: SP_Platform platform_; void (*destroy_platform_)(SP_Platform*); + SP_PlatformFns platform_fns_; + void (*destroy_platform_fns_)(SP_PlatformFns*); + SP_DeviceFns device_fns_; SP_StreamExecutor stream_executor_; SP_TimerFns timer_fns_; const std::string name_; diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_test.cc b/tensorflow/c/experimental/stream_executor/stream_executor_test.cc index 86fe00fe5ad..56c4ea09052 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor_test.cc +++ b/tensorflow/c/experimental/stream_executor/stream_executor_test.cc @@ -41,15 +41,19 @@ struct SP_Timer_st { namespace stream_executor { namespace { -constexpr int DEVICE_COUNT = 2; -constexpr char DEVICE_NAME[] = "MyDevice"; -constexpr char DEVICE_TYPE[] = "GPU"; +constexpr int kDeviceCount = 2; +constexpr char kDeviceName[] = "MY_DEVICE"; +constexpr char kDeviceType[] = "GPU"; /*** Create SP_StreamExecutor (with empty functions) ***/ void allocate(const SP_Device* const device, uint64_t size, int64_t memory_space, SP_DeviceMemoryBase* const mem) {} void deallocate(const SP_Device* const device, SP_DeviceMemoryBase* const mem) { } +void* host_memory_allocate(const SP_Device* const device, uint64_t size) { + return nullptr; +} +void host_memory_deallocate(const SP_Device* const device, void* mem) {} TF_Bool get_allocator_stats(const SP_Device* const device, SP_AllocatorStats* const stats) { return true; @@ -104,16 +108,18 @@ void block_host_for_event(const SP_Device* const device, SP_Event event, TF_Status* const status) {} void synchronize_all_activity(const SP_Device* const device, TF_Status* const status) {} -TF_Bool host_callback(SP_Device* const device, SP_Stream stream, +TF_Bool host_callback(const SP_Device* const device, SP_Stream stream, SE_StatusCallbackFn const callback_fn, void* const callback_arg) { return true; } void PopulateDefaultStreamExecutor(SP_StreamExecutor* se) { - se->struct_size = SP_STREAMEXECUTOR_STRUCT_SIZE; + *se = {SP_STREAMEXECUTOR_STRUCT_SIZE}; se->allocate = allocate; se->deallocate = deallocate; + se->host_memory_allocate = host_memory_allocate; + se->host_memory_deallocate = host_memory_deallocate; se->get_allocator_stats = get_allocator_stats; se->device_memory_usage = device_memory_usage; se->create_stream = create_stream; @@ -138,6 +144,10 @@ void PopulateDefaultStreamExecutor(SP_StreamExecutor* se) { se->host_callback = host_callback; } +void PopulateDefaultDeviceFns(SP_DeviceFns* device_fns) { + *device_fns = {SP_DEVICE_FNS_STRUCT_SIZE}; +} + /*** Create SP_TimerFns ***/ uint64_t nanoseconds(SP_Timer timer) { return timer->timer_id; } @@ -146,91 +156,158 @@ void PopulateDefaultTimerFns(SP_TimerFns* timer_fns) { } /*** Create SP_Platform ***/ -void create_timer_fns(SP_TimerFns* timer_fns, TF_Status* status) { +void create_timer_fns(const SP_Platform* platform, SP_TimerFns* timer_fns, + TF_Status* status) { TF_SetStatus(status, TF_OK, ""); PopulateDefaultTimerFns(timer_fns); } -void 
destroy_timer_fns(SP_TimerFns* timer_fns) {} +void destroy_timer_fns(const SP_Platform* platform, SP_TimerFns* timer_fns) {} -void create_stream_executor(SE_CreateStreamExecutorParams* params, +void create_stream_executor(const SP_Platform* platform, + SE_CreateStreamExecutorParams* params, TF_Status* status) { TF_SetStatus(status, TF_OK, ""); PopulateDefaultStreamExecutor(params->stream_executor); } -void destroy_stream_executor(SP_StreamExecutor* se) {} +void destroy_stream_executor(const SP_Platform* platform, + SP_StreamExecutor* se) {} -void create_device(SE_CreateDeviceParams* params, TF_Status* status) { +void create_device(const SP_Platform* platform, SE_CreateDeviceParams* params, + TF_Status* status) { TF_SetStatus(status, TF_OK, ""); - params->device->struct_size = SP_DEVICE_STRUCT_SIZE; + params->device->struct_size = {SP_DEVICE_STRUCT_SIZE}; } -void destroy_device(SP_Device* device) {} +void destroy_device(const SP_Platform* platform, SP_Device* device) {} -void PopulateDefaultPlatform(SP_Platform* platform) { - platform->struct_size = SP_PLATFORM_STRUCT_SIZE; - platform->name = DEVICE_NAME; - platform->type = DEVICE_TYPE; - platform->visible_device_count = DEVICE_COUNT; - platform->create_device = create_device; - platform->destroy_device = destroy_device; - platform->create_stream_executor = create_stream_executor; - platform->destroy_stream_executor = destroy_stream_executor; - platform->create_timer_fns = create_timer_fns; - platform->destroy_timer_fns = destroy_timer_fns; +void create_device_fns(const SP_Platform* platform, + SE_CreateDeviceFnsParams* params, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + params->device_fns->struct_size = {SP_DEVICE_FNS_STRUCT_SIZE}; +} +void destroy_device_fns(const SP_Platform* platform, SP_DeviceFns* device_fns) { +} + +void PopulateDefaultPlatform(SP_Platform* platform, + SP_PlatformFns* platform_fns) { + *platform = {SP_PLATFORM_STRUCT_SIZE}; + platform->name = kDeviceName; + platform->type = kDeviceType; + platform->visible_device_count = kDeviceCount; + platform_fns->create_device = create_device; + platform_fns->destroy_device = destroy_device; + platform_fns->create_device_fns = create_device_fns; + platform_fns->destroy_device_fns = destroy_device_fns; + platform_fns->create_stream_executor = create_stream_executor; + platform_fns->destroy_stream_executor = destroy_stream_executor; + platform_fns->create_timer_fns = create_timer_fns; + platform_fns->destroy_timer_fns = destroy_timer_fns; } void destroy_platform(SP_Platform* const platform) {} +void destroy_platform_fns(SP_PlatformFns* const platform_fns) {} /*** Registration tests ***/ TEST(StreamExecutor, SuccessfulRegistration) { auto plugin_init = [](SE_PlatformRegistrationParams* const params, TF_Status* const status) -> void { TF_SetStatus(status, TF_OK, ""); - PopulateDefaultPlatform(params->platform); + PopulateDefaultPlatform(params->platform, params->platform_fns); params->destroy_platform = destroy_platform; + params->destroy_platform_fns = destroy_platform_fns; }; - port::Status status = RegisterDevicePlugin(plugin_init); + port::Status status = InitStreamExecutorPlugin(plugin_init); TF_ASSERT_OK(status); port::StatusOr maybe_platform = - MultiPlatformManager::PlatformWithName("MyDevice"); + MultiPlatformManager::PlatformWithName("MY_DEVICE"); TF_ASSERT_OK(maybe_platform.status()); Platform* platform = maybe_platform.ConsumeValueOrDie(); - ASSERT_EQ(platform->Name(), DEVICE_NAME); - ASSERT_EQ(platform->VisibleDeviceCount(), DEVICE_COUNT); + 
ASSERT_EQ(platform->Name(), kDeviceName); + ASSERT_EQ(platform->VisibleDeviceCount(), kDeviceCount); port::StatusOr maybe_executor = platform->ExecutorForDevice(0); TF_ASSERT_OK(maybe_executor.status()); - StreamExecutor* executor = maybe_executor.ConsumeValueOrDie(); - ASSERT_EQ(executor->GetDeviceDescription().name(), "MyDevice"); } TEST(StreamExecutor, NameNotSet) { auto plugin_init = [](SE_PlatformRegistrationParams* const params, TF_Status* const status) -> void { TF_SetStatus(status, TF_OK, ""); - PopulateDefaultPlatform(params->platform); + PopulateDefaultPlatform(params->platform, params->platform_fns); params->platform->name = nullptr; params->destroy_platform = destroy_platform; + params->destroy_platform_fns = destroy_platform_fns; }; - port::Status status = RegisterDevicePlugin(plugin_init); + port::Status status = InitStreamExecutorPlugin(plugin_init); ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); ASSERT_EQ(status.error_message(), "'name' field in SP_Platform must be set."); } +TEST(StreamExecutor, InvalidNameWithSemicolon) { + auto plugin_init = [](SE_PlatformRegistrationParams* const params, + TF_Status* const status) -> void { + TF_SetStatus(status, TF_OK, ""); + PopulateDefaultPlatform(params->platform, params->platform_fns); + params->platform->name = "INVALID:NAME"; + params->destroy_platform = destroy_platform; + params->destroy_platform_fns = destroy_platform_fns; + }; + + port::Status status = InitStreamExecutorPlugin(plugin_init); + ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); + EXPECT_THAT( + status.error_message(), + testing::ContainsRegex("Device name/type 'INVALID:NAME' must match")); +} + +TEST(StreamExecutor, InvalidNameWithSlash) { + auto plugin_init = [](SE_PlatformRegistrationParams* const params, + TF_Status* const status) -> void { + TF_SetStatus(status, TF_OK, ""); + PopulateDefaultPlatform(params->platform, params->platform_fns); + params->platform->name = "INVALID/"; + params->destroy_platform = destroy_platform; + params->destroy_platform_fns = destroy_platform_fns; + }; + + port::Status status = InitStreamExecutorPlugin(plugin_init); + ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); + EXPECT_THAT(status.error_message(), + testing::ContainsRegex("Device name/type 'INVALID/' must match")); +} + TEST(StreamExecutor, CreateDeviceNotSet) { auto plugin_init = [](SE_PlatformRegistrationParams* const params, TF_Status* const status) -> void { TF_SetStatus(status, TF_OK, ""); - PopulateDefaultPlatform(params->platform); - params->platform->create_device = nullptr; + PopulateDefaultPlatform(params->platform, params->platform_fns); + params->platform_fns->create_device = nullptr; params->destroy_platform = destroy_platform; + params->destroy_platform_fns = destroy_platform_fns; }; - port::Status status = RegisterDevicePlugin(plugin_init); + port::Status status = InitStreamExecutorPlugin(plugin_init); ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); ASSERT_EQ(status.error_message(), - "'create_device' field in SP_Platform must be set."); + "'create_device' field in SP_PlatformFns must be set."); +} + +TEST(StreamExecutor, UnifiedMemoryAllocateNotSet) { + auto plugin_init = [](SE_PlatformRegistrationParams* const params, + TF_Status* const status) -> void { + TF_SetStatus(status, TF_OK, ""); + PopulateDefaultPlatform(params->platform, params->platform_fns); + params->platform->supports_unified_memory = true; + params->destroy_platform = destroy_platform; + params->destroy_platform_fns = 
destroy_platform_fns; + }; + + port::Status status = InitStreamExecutorPlugin(plugin_init); + ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); + ASSERT_EQ( + status.error_message(), + "'unified_memory_allocate' field in SP_StreamExecutor must be set."); } /*** StreamExecutor behavior tests ***/ @@ -238,7 +315,8 @@ class StreamExecutorTest : public ::testing::Test { protected: StreamExecutorTest() {} void SetUp() override { - PopulateDefaultPlatform(&platform_); + PopulateDefaultPlatform(&platform_, &platform_fns_); + PopulateDefaultDeviceFns(&device_fns_); PopulateDefaultStreamExecutor(&se_); PopulateDefaultTimerFns(&timer_fns_); } @@ -246,8 +324,9 @@ class StreamExecutorTest : public ::testing::Test { StreamExecutor* GetExecutor(int ordinal) { if (!cplatform_) { - cplatform_ = absl::make_unique(platform_, destroy_platform, - se_, timer_fns_); + cplatform_ = absl::make_unique( + platform_, destroy_platform, platform_fns_, destroy_platform_fns, + device_fns_, se_, timer_fns_); } port::StatusOr maybe_executor = cplatform_->ExecutorForDevice(ordinal); @@ -255,6 +334,8 @@ class StreamExecutorTest : public ::testing::Test { return maybe_executor.ConsumeValueOrDie(); } SP_Platform platform_; + SP_PlatformFns platform_fns_; + SP_DeviceFns device_fns_; SP_StreamExecutor se_; SP_TimerFns timer_fns_; std::unique_ptr cplatform_; @@ -264,13 +345,13 @@ TEST_F(StreamExecutorTest, Allocate) { se_.allocate = [](const SP_Device* const device, uint64_t size, int64_t memory_space, SP_DeviceMemoryBase* const mem) { mem->struct_size = SP_DEVICE_MEMORY_BASE_STRUCT_SIZE; - mem->opaque = std::malloc(size); + mem->opaque = malloc(size); mem->size = size; }; se_.deallocate = [](const SP_Device* const device, SP_DeviceMemoryBase* const mem) { EXPECT_EQ(mem->size, 2 * sizeof(int)); - std::free(mem->opaque); + free(mem->opaque); mem->opaque = nullptr; mem->size = 0; }; @@ -287,10 +368,10 @@ TEST_F(StreamExecutorTest, HostMemoryAllocate) { static bool deallocate_called = false; se_.host_memory_allocate = [](const SP_Device* const device, uint64_t size) { allocate_called = true; - return std::malloc(size); + return malloc(size); }; se_.host_memory_deallocate = [](const SP_Device* const device, void* mem) { - std::free(mem); + free(mem); deallocate_called = true; }; StreamExecutor* executor = GetExecutor(0); @@ -303,6 +384,28 @@ TEST_F(StreamExecutorTest, HostMemoryAllocate) { ASSERT_TRUE(deallocate_called); } +TEST_F(StreamExecutorTest, UnifiedMemoryAllocate) { + static bool allocate_called = false; + static bool deallocate_called = false; + se_.unified_memory_allocate = [](const SP_Device* const device, + uint64_t size) { + allocate_called = true; + return malloc(size); + }; + se_.unified_memory_deallocate = [](const SP_Device* const device, void* mem) { + free(mem); + deallocate_called = true; + }; + StreamExecutor* executor = GetExecutor(0); + ASSERT_FALSE(allocate_called); + void* mem = executor->UnifiedMemoryAllocate(8); + ASSERT_NE(mem, nullptr); + ASSERT_TRUE(allocate_called); + ASSERT_FALSE(deallocate_called); + executor->UnifiedMemoryDeallocate(mem); + ASSERT_TRUE(deallocate_called); +} + TEST_F(StreamExecutorTest, GetAllocatorStats) { se_.get_allocator_stats = [](const SP_Device* const device, SP_AllocatorStats* const stat) -> TF_Bool { @@ -745,6 +848,31 @@ TEST_F(StreamExecutorTest, BlockHostForEvent) { ASSERT_TRUE(block_host_for_event_called); } +TEST_F(StreamExecutorTest, BlockHostUntilDone) { + static bool block_host_until_done_called = false; + se_.create_stream = [](const SP_Device* 
const device, SP_Stream* stream, + TF_Status* const status) { + *stream = new SP_Stream_st(58); + }; + se_.destroy_stream = [](const SP_Device* const device, SP_Stream stream) { + delete stream; + }; + se_.block_host_until_done = [](const SP_Device* const device, + SP_Stream stream, + TF_Status* const status) -> void { + ASSERT_EQ(stream->stream_id, 58); + TF_SetStatus(status, TF_OK, ""); + block_host_until_done_called = true; + }; + + StreamExecutor* executor = GetExecutor(0); + Stream stream(executor); + stream.Init(); + ASSERT_FALSE(block_host_until_done_called); + TF_ASSERT_OK(stream.BlockHostUntilDone()); + ASSERT_TRUE(block_host_until_done_called); +} + TEST_F(StreamExecutorTest, SynchronizeAllActivity) { static bool synchronize_all_called = false; se_.synchronize_all_activity = [](const SP_Device* const device, @@ -760,7 +888,7 @@ TEST_F(StreamExecutorTest, SynchronizeAllActivity) { } TEST_F(StreamExecutorTest, HostCallbackOk) { - se_.host_callback = [](SP_Device* const device, SP_Stream stream, + se_.host_callback = [](const SP_Device* const device, SP_Stream stream, SE_StatusCallbackFn const callback_fn, void* const callback_arg) -> TF_Bool { TF_Status* status = TF_NewStatus(); @@ -780,7 +908,7 @@ TEST_F(StreamExecutorTest, HostCallbackOk) { } TEST_F(StreamExecutorTest, HostCallbackError) { - se_.host_callback = [](SP_Device* const device, SP_Stream stream, + se_.host_callback = [](const SP_Device* const device, SP_Stream stream, SE_StatusCallbackFn const callback_fn, void* const callback_arg) -> TF_Bool { TF_Status* status = TF_NewStatus(); @@ -798,5 +926,59 @@ TEST_F(StreamExecutorTest, HostCallbackError) { stream.ThenDoHostCallbackWithStatus(callback); ASSERT_FALSE(stream.ok()); } + +TEST_F(StreamExecutorTest, DeviceDescription) { + static const char* hardware_name = "TestName"; + static const char* vendor = "TestVendor"; + static const char* pci_bus_id = "TestPCIBusId"; + platform_fns_.create_device = [](const SP_Platform* platform, + SE_CreateDeviceParams* params, + TF_Status* status) { + params->device->hardware_name = hardware_name; + params->device->device_vendor = vendor; + params->device->pci_bus_id = pci_bus_id; + }; + + device_fns_.get_numa_node = [](const SP_Device* device) { return 123; }; + device_fns_.get_memory_bandwidth = [](const SP_Device* device) -> int64_t { + return 54; + }; + device_fns_.get_gflops = [](const SP_Device* device) -> double { return 32; }; + + StreamExecutor* executor = GetExecutor(0); + const DeviceDescription& description = executor->GetDeviceDescription(); + ASSERT_EQ(description.name(), "TestName"); + ASSERT_EQ(description.device_vendor(), "TestVendor"); + ASSERT_EQ(description.pci_bus_id(), "TestPCIBusId"); + ASSERT_EQ(description.numa_node(), 123); + ASSERT_EQ(description.memory_bandwidth(), 54); +} + +TEST_F(StreamExecutorTest, DeviceDescriptionNumaNodeNotSet) { + static const char* hardware_name = "TestName"; + static const char* vendor = "TestVendor"; + static const char* pci_bus_id = "TestPCIBusId"; + platform_fns_.create_device = [](const SP_Platform* platform, + SE_CreateDeviceParams* params, + TF_Status* status) { + params->device->hardware_name = hardware_name; + params->device->device_vendor = vendor; + params->device->pci_bus_id = pci_bus_id; + }; + + device_fns_.get_memory_bandwidth = [](const SP_Device* device) -> int64_t { + return 54; + }; + device_fns_.get_gflops = [](const SP_Device* device) -> double { return 32; }; + + StreamExecutor* executor = GetExecutor(0); + const DeviceDescription& description = 
executor->GetDeviceDescription(); + ASSERT_EQ(description.name(), "TestName"); + ASSERT_EQ(description.device_vendor(), "TestVendor"); + ASSERT_EQ(description.pci_bus_id(), "TestPCIBusId"); + ASSERT_EQ(description.numa_node(), -1); + ASSERT_EQ(description.memory_bandwidth(), 54); +} + } // namespace } // namespace stream_executor diff --git a/tensorflow/c/kernels/BUILD b/tensorflow/c/kernels/BUILD index 93b82b2396f..d89eda3eb4e 100644 --- a/tensorflow/c/kernels/BUILD +++ b/tensorflow/c/kernels/BUILD @@ -1,9 +1,13 @@ -load( - "//tensorflow:tensorflow.bzl", - "tf_cc_test", - "tf_gen_op_libs", - "tf_kernel_library", -) +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "filegroup") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "tf_kernel_library") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( default_visibility = ["//visibility:public"], @@ -132,6 +136,23 @@ tf_cc_test( ], ) +tf_cc_test( + name = "summary_op_benchmark_test", + size = "small", + srcs = ["summary_op_benchmark_test.cc"], + deps = [ + ":summary_op", + "//tensorflow/c:kernels", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + cc_library( name = "tensor_shape_utils", srcs = ["tensor_shape_utils.cc"], diff --git a/tensorflow/c/kernels/histogram_summary_op.cc b/tensorflow/c/kernels/histogram_summary_op.cc index 5de52703f5d..143a2675a05 100644 --- a/tensorflow/c/kernels/histogram_summary_op.cc +++ b/tensorflow/c/kernels/histogram_summary_op.cc @@ -93,11 +93,13 @@ void HistogramSummaryOp_Compute(void* kernel, TF_OpKernelContext* ctx) { std::ostringstream err; err << "Nan in summary histogram for: " << k->op_node_name; TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, err.str().c_str()); + TF_OpKernelContext_Failure(ctx, status.get()); return; } else if (Eigen::numext::isinf(double_val)) { std::ostringstream err; err << "Infinity in Histogram for: " << k->op_node_name; TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, err.str().c_str()); + TF_OpKernelContext_Failure(ctx, status.get()); return; } histo.Add(double_val); diff --git a/tensorflow/c/kernels/summary_op_benchmark_test.cc b/tensorflow/c/kernels/summary_op_benchmark_test.cc new file mode 100644 index 00000000000..887a86066d3 --- /dev/null +++ b/tensorflow/c/kernels/summary_op_benchmark_test.cc @@ -0,0 +1,71 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
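A note on the histogram_summary_op.cc hunk above: setting a locally owned TF_Status is not by itself enough to fail the op, so the patch pairs TF_SetStatus with TF_OpKernelContext_Failure before returning. The sketch below restates that pattern in isolation; the ReportInvalidArgument helper and its op_node_name parameter are illustrative only, not code from the patch.

#include <sstream>
#include <string>

#include "tensorflow/c/kernels.h"
#include "tensorflow/c/tf_status.h"

// Hypothetical helper showing the error path used by the C-API summary
// kernels: build a TF_Status, record the failure on the kernel context,
// then return.
static void ReportInvalidArgument(TF_OpKernelContext* ctx,
                                  const std::string& op_node_name,
                                  const char* what) {
  TF_Status* status = TF_NewStatus();
  std::ostringstream err;
  err << what << " in summary histogram for: " << op_node_name;
  TF_SetStatus(status, TF_INVALID_ARGUMENT, err.str().c_str());
  // This is the call the patch adds: without it the error only lives in the
  // local TF_Status and the op would appear to succeed.
  TF_OpKernelContext_Failure(ctx, status);
  TF_DeleteStatus(status);
}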
+==============================================================================*/ + +#include + +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace tensorflow { +namespace { + +Graph* BM_ScalarSummaryOp(TensorShape shape, std::string tag, float value) { + Graph* g = new Graph(OpRegistry::Global()); + Tensor tags(DT_STRING, shape); + Tensor values(DT_FLOAT, shape); + for (int i = 0; i < tags.NumElements(); ++i) { + tags.flat<tstring>()(i) = tag; + values.flat<float>()(i) = value; + } + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("dummy"), "ScalarSummary") + .Input(test::graph::Constant(g, tags)) + .Input(test::graph::Constant(g, values)) + .Attr("T", DT_FLOAT) + .Finalize(g, &ret)); + return g; +} + +// Macro used to parse initializer list for tensorshape +#define DIMARGS(...) \ + { __VA_ARGS__ } +// Random parameters for testing +constexpr char longTagParam[] = "LONGTAG____________________________"; +constexpr float largeValueParam = 2352352.2623433; + +#define BM_ScalarSummaryDev(device, dims, name, tag, value) \ + void BM_ScalarSummary##name##device(int iters) { \ + testing::StopTiming(); \ + TensorShape tensorshape(DIMARGS dims); \ + auto g = BM_ScalarSummaryOp(tensorshape, #tag, value); \ + testing::StartTiming(); \ + test::Benchmark("cpu", g).Run(iters); \ + } \ + BENCHMARK(BM_ScalarSummary##name##device); + +BM_ScalarSummaryDev(Cpu, (5, 10, 100), Base, Tag, 5.2); +// Benchmark for large shapes +BM_ScalarSummaryDev(Cpu, (500, 100, 100), LargeShape, Tag, 5.2); +// Benchmark for large tag tstring +BM_ScalarSummaryDev(Cpu, (5, 10, 100), LongTag, longTagParam, 5.2); +// Benchmark for large values +BM_ScalarSummaryDev(Cpu, (500, 100, 100), LargeValue, Tag, largeValueParam); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index 3bdaa866ee6..7ec1f4cc951 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -57,43 +57,7 @@ void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) { void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, TF_Status* status) { - mutex_lock l(graph->mu); - tensorflow::shape_inference::InferenceContext* ic = - graph->refiner.GetContext(&new_src.oper->node); - - if (ic->num_outputs() <= new_src.index) { - status->status = tensorflow::errors::OutOfRange( - "Cannot update edge. Output index [", new_src.index, - "] is greater than the number of total outputs [", ic->num_outputs(), - "]."); - return; - } - tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index); - - tensorflow::shape_inference::InferenceContext* ic_dst = - graph->refiner.GetContext(&dst.oper->node); - if (ic_dst->num_inputs() <= dst.index) { - status->status = tensorflow::errors::OutOfRange( - "Cannot update edge.
Input index [", dst.index, - "] is greater than the number of total inputs [", ic_dst->num_inputs(), - "]."); - return; - } - if (!ic_dst->MergeInput(dst.index, shape)) { - status->status = tensorflow::errors::InvalidArgument( - "Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape), - " and ", ic_dst->DebugString(ic_dst->input(dst.index)), "."); - return; - } - status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index, - &dst.oper->node, dst.index); - - if (TF_GetCode(status) == TF_OK) { - // This modification only updates the destination node for - // the purposes of running this graph in a session. Thus, we don't - // record the source node as being modified. - RecordMutation(graph, *dst.oper, "updating input tensor"); - } + TF_UpdateEdge(graph, new_src, dst, status); } void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) { @@ -136,6 +100,7 @@ std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) { auto* out_shape_and_type = handle_data.add_shape_and_type(); ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape()); out_shape_and_type->set_dtype(p.dtype); + out_shape_and_type->set_specialized_type(p.specialized_type); } } string result; @@ -163,7 +128,8 @@ void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto, status->status = ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape); if (TF_GetCode(status) != TF_OK) return; - shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype()); + shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype(), + shape_and_type_proto.specialized_type()); } ic->set_output_handle_shapes_and_types(output.index, shapes_and_types); } diff --git a/tensorflow/c/tf_shape.cc b/tensorflow/c/tf_shape.cc new file mode 100644 index 00000000000..a715544a13f --- /dev/null +++ b/tensorflow/c/tf_shape.cc @@ -0,0 +1,39 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/tf_shape.h" + +#include + +#include "tensorflow/c/tf_shape_internal.h" +#include "tensorflow/core/framework/tensor_shape.h" + +extern "C" { + +TF_Shape* TF_NewShape() { + return tensorflow::wrap(new tensorflow::PartialTensorShape()); +} + +int TF_ShapeDims(const TF_Shape* shape) { + return tensorflow::unwrap(shape)->dims(); +} + +int64_t TF_ShapeDimSize(const TF_Shape* shape, int d) { + return tensorflow::unwrap(shape)->dim_size(d); +} + +void TF_DeleteShape(TF_Shape* shape) { delete tensorflow::unwrap(shape); } + +} // end extern "C" diff --git a/tensorflow/c/tf_shape.h b/tensorflow/c/tf_shape.h new file mode 100644 index 00000000000..f218d05e274 --- /dev/null +++ b/tensorflow/c/tf_shape.h @@ -0,0 +1,50 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
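The new TF_Shape C API implemented above in tf_shape.cc (and declared in tf_shape.h below) is small enough to show end to end. This is a hedged usage sketch rather than code from the patch; it relies only on the four functions defined above.

#include <cstdio>

#include "tensorflow/c/tf_shape.h"

int main() {
  // TF_NewShape returns an unknown-rank shape, so TF_ShapeDims reports -1.
  TF_Shape* shape = TF_NewShape();
  int rank = TF_ShapeDims(shape);
  if (rank < 0) {
    std::printf("shape has unknown rank\n");
  } else {
    for (int d = 0; d < rank; ++d) {
      // A dimension size of -1 marks an unknown dimension.
      std::printf("dim %d = %lld\n", d,
                  static_cast<long long>(TF_ShapeDimSize(shape, d)));
    }
  }
  TF_DeleteShape(shape);  // The caller owns the shape and must delete it.
  return 0;
}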
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/c/c_api_macros.h" + +#ifndef TENSORFLOW_C_TF_SHAPE_H_ +#define TENSORFLOW_C_TF_SHAPE_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +// An opaque type corresponding to a shape in tensorflow. In the future, +// we may expose the ABI of TF_Shape for performance reasons. +typedef struct TF_Shape TF_Shape; + +// Return a new, unknown rank shape object. The caller is responsible for +// calling TF_DeleteShape to deallocate and destroy the returned shape. +TF_CAPI_EXPORT extern TF_Shape* TF_NewShape(); + +// Returns the rank of `shape`. If `shape` has unknown rank, returns -1. +TF_CAPI_EXPORT extern int TF_ShapeDims(const TF_Shape* shape); + +// Returns the `d`th dimension of `shape`. If `shape` has unknown rank, +// invoking this function is undefined behavior. Returns -1 if dimension is +// unknown. +TF_CAPI_EXPORT extern int64_t TF_ShapeDimSize(const TF_Shape* shape, int d); + +// Deletes `shape`. +TF_CAPI_EXPORT extern void TF_DeleteShape(TF_Shape* shape); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_TF_SHAPE_H_ diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/dialect_registration.cc b/tensorflow/c/tf_shape_internal.h similarity index 62% rename from tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/dialect_registration.cc rename to tensorflow/c/tf_shape_internal.h index 9d1c354690a..fe97726460f 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/dialect_registration.cc +++ b/tensorflow/c/tf_shape_internal.h @@ -13,11 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#ifndef TENSORFLOW_C_TF_SHAPE_INTERNAL_H_ +#define TENSORFLOW_C_TF_SHAPE_INTERNAL_H_ -// Static initialization for *HLO dialects registration. -static mlir::DialectRegistration mhlo_ops; -static mlir::DialectRegistration chlo_ops; -static mlir::DialectRegistration lmhlo_ops; +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/core/framework/tensor_shape.h" + +typedef struct TF_Shape TF_Shape; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(tensorflow::PartialTensorShape, TF_Shape); + +} + +#endif // TENSORFLOW_C_TF_SHAPE_INTERNAL_H_ diff --git a/tensorflow/c/tf_status_helper.h b/tensorflow/c/tf_status_helper.h index ff8085f1229..a895e608159 100644 --- a/tensorflow/c/tf_status_helper.h +++ b/tensorflow/c/tf_status_helper.h @@ -28,6 +28,14 @@ void Set_TF_Status_from_Status(TF_Status* tf_status, // Returns a "status" from "tf_status". 
tensorflow::Status StatusFromTF_Status(const TF_Status* tf_status); +namespace internal { +struct TF_StatusDeleter { + void operator()(TF_Status* tf_status) const { TF_DeleteStatus(tf_status); } +}; +} // namespace internal + +using TF_StatusPtr = std::unique_ptr; + } // namespace tensorflow #endif // TENSORFLOW_C_TF_STATUS_HELPER_H_ diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index 8602bfafff8..8f7e447d322 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -2,16 +2,22 @@ # TensorFlow is a computational framework, primarily for use in machine # learning applications. +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", "cc_library_with_android_deps", "tf_cc_binary", "tf_cc_test", "tf_copts", - "tf_gen_op_wrappers_cc", "transitive_hdrs", ) +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "filegroup") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrappers_cc") + package( default_visibility = ["//visibility:public"], licenses = ["notice"], # Apache 2.0 @@ -245,7 +251,6 @@ cc_library_with_android_deps( deps = [ "//tensorflow/core:core_cpu", "//tensorflow/core:lib", - "//tensorflow/core:lib_experimental", "//tensorflow/core:protos_all_cc", ], ) @@ -260,7 +265,6 @@ tf_cc_test( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", - "//tensorflow/core:lib_experimental", "//tensorflow/core:tensorflow", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/cc/client/client_session.h b/tensorflow/cc/client/client_session.h index 3765eaec9bf..a661319b074 100644 --- a/tensorflow/cc/client/client_session.h +++ b/tensorflow/cc/client/client_session.h @@ -64,7 +64,7 @@ class ClientSession { ClientSession(const Scope& scope, const string& target); /// Same as above, but use the empty string ("") as the target specification. - ClientSession(const Scope& scope); + explicit ClientSession(const Scope& scope); /// Create a new session, configuring it with `session_options`. ClientSession(const Scope& scope, const SessionOptions& session_options); diff --git a/tensorflow/cc/experimental/base/public/BUILD b/tensorflow/cc/experimental/base/public/BUILD index 045d4e6cd97..0aaf2238e6a 100644 --- a/tensorflow/cc/experimental/base/public/BUILD +++ b/tensorflow/cc/experimental/base/public/BUILD @@ -8,6 +8,8 @@ # 2. Are std:: types # 3. Wrap an opaque C type +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + package( # This is intentionally public default_visibility = [ diff --git a/tensorflow/cc/experimental/base/tests/BUILD b/tensorflow/cc/experimental/base/tests/BUILD index f449d618f72..f7f6e77c98f 100644 --- a/tensorflow/cc/experimental/base/tests/BUILD +++ b/tensorflow/cc/experimental/base/tests/BUILD @@ -1,3 +1,5 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + # Tests for the C++ header-only base types. load("//tensorflow:tensorflow.bzl", "tf_cc_test") diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc index e9173227aad..480243a29e6 100644 --- a/tensorflow/cc/gradients/array_grad.cc +++ b/tensorflow/cc/gradients/array_grad.cc @@ -15,13 +15,12 @@ limitations under the License. 
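The tf_status_helper.h addition above introduces TF_StatusPtr, a std::unique_ptr alias whose deleter calls TF_DeleteStatus. A minimal sketch of how a call site can use it instead of manual cleanup; the StatusIsOk helper is illustrative, not part of the patch.

#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"

// Illustrative only: the alias frees the TF_Status automatically when it
// goes out of scope, including on early returns.
bool StatusIsOk() {
  tensorflow::TF_StatusPtr status(TF_NewStatus());
  // ... pass status.get() to any C API call that reports errors ...
  return TF_GetCode(status.get()) == TF_OK;
}  // No explicit TF_DeleteStatus: TF_StatusDeleter runs here.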
#include +#include "tensorflow/cc/framework/grad_op_registry.h" +#include "tensorflow/cc/framework/gradients.h" #include "tensorflow/cc/ops/array_ops_internal.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/cc/framework/grad_op_registry.h" -#include "tensorflow/cc/framework/gradients.h" - namespace tensorflow { namespace ops { namespace { @@ -90,15 +89,25 @@ Status QuantizeAndDequantizeGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("QuantizeAndDequantize", QuantizeAndDequantizeGrad); -Status QuantizeAndDequantizeV2Grad(const Scope& scope, const Operation& op, - const std::vector& grad_inputs, - std::vector* grad_outputs) { - grad_outputs->push_back(Identity(scope, grad_inputs[0])); - grad_outputs->push_back(NoGradient()); - grad_outputs->push_back(NoGradient()); +Status QuantizeAndDequantizeV4GradHelper(const Scope& scope, + const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + Input input = Shape(scope, op.input(0)); + Input input_min = op.input(1); + Input input_max = op.input(2); + int64 axis; + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis)); + auto qdq_v4_grad = QuantizeAndDequantizeV4Grad( + scope, grad_inputs[0], input, input_min, input_max, + QuantizeAndDequantizeV4Grad::Axis(axis)); + grad_outputs->push_back(qdq_v4_grad.input_backprop); + grad_outputs->push_back(qdq_v4_grad.input_min_backprop); + grad_outputs->push_back(qdq_v4_grad.input_max_backprop); return scope.status(); } -REGISTER_GRADIENT_OP("QuantizeAndDequantizeV2", QuantizeAndDequantizeV2Grad); +REGISTER_GRADIENT_OP("QuantizeAndDequantizeV4", + QuantizeAndDequantizeV4GradHelper); Status QuantizeAndDequantizeV3Grad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, diff --git a/tensorflow/cc/gradients/grad_testutil.h b/tensorflow/cc/gradients/grad_testutil.h index 70c81f1a73a..43d533ad760 100644 --- a/tensorflow/cc/gradients/grad_testutil.h +++ b/tensorflow/cc/gradients/grad_testutil.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_CC_GRADIENTS_GRAD_TESTUTIL_H_ #define TENSORFLOW_CC_GRADIENTS_GRAD_TESTUTIL_H_ +#include + #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" diff --git a/tensorflow/cc/ops/const_op.h b/tensorflow/cc/ops/const_op.h index 424a683665f..9c888701b45 100644 --- a/tensorflow/cc/ops/const_op.h +++ b/tensorflow/cc/ops/const_op.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_CC_OPS_CONST_OP_H_ #define TENSORFLOW_CC_OPS_CONST_OP_H_ +#include + #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/core/graph/node_builder.h" diff --git a/tensorflow/cc/ops/while_loop.h b/tensorflow/cc/ops/while_loop.h index 727237b5c7a..6dbf1d23dba 100644 --- a/tensorflow/cc/ops/while_loop.h +++ b/tensorflow/cc/ops/while_loop.h @@ -16,6 +16,9 @@ limitations under the License. 
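The array_grad.cc hunk above replaces the QuantizeAndDequantizeV2 gradient with a V4 helper that reads the "axis" attribute and delegates to the dedicated gradient op. For reference, the registration pattern it follows looks roughly like the sketch below; "MyPassthroughOp" and MyPassthroughOpGrad are made-up names for illustration, not ops touched by the patch.

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/ops/standard_ops.h"

namespace tensorflow {
namespace ops {
namespace {

// Sketch of a gradient function: forward the incoming gradient for the first
// input and declare no gradient for a second, non-differentiable input.
Status MyPassthroughOpGrad(const Scope& scope, const Operation& op,
                           const std::vector<Output>& grad_inputs,
                           std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MyPassthroughOp", MyPassthroughOpGrad);

}  // namespace
}  // namespace ops
}  // namespace tensorflow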
#ifndef TENSORFLOW_CC_OPS_WHILE_LOOP_H_ #define TENSORFLOW_CC_OPS_WHILE_LOOP_H_ +#include +#include + #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" diff --git a/tensorflow/cc/profiler/BUILD b/tensorflow/cc/profiler/BUILD index 057ce7cb993..43240506f8c 100644 --- a/tensorflow/cc/profiler/BUILD +++ b/tensorflow/cc/profiler/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") package( diff --git a/tensorflow/cc/profiler/profiler.h b/tensorflow/cc/profiler/profiler.h index 64edbb5766c..dc60fd5fb37 100644 --- a/tensorflow/cc/profiler/profiler.h +++ b/tensorflow/cc/profiler/profiler.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_CC_PROFILER_PROFILER_H_ #define TENSORFLOW_CC_PROFILER_PROFILER_H_ +#include +#include + #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/profiler/internal/tfprof_stats.h" @@ -56,7 +59,7 @@ namespace tfprof { class Profiler { public: /// `graph` is the model's GraphDef. - Profiler(const GraphDef& graph); + explicit Profiler(const GraphDef& graph); /// Adds tracing information `run_meta` to profiler. A `run_meta` is /// generated by a TensorFlow session run call. `step` is the key diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index a3ea0c75bc7..056c99eed8e 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -1,6 +1,8 @@ # Description: # TensorFlow SavedModel. +load("//tensorflow:tensorflow.bzl", "filegroup") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", "if_android", @@ -19,10 +21,7 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files([ - "LICENSE", - "loader.h", -]) +exports_files(["loader.h"]) cc_library( name = "constants", @@ -43,13 +42,15 @@ cc_library( name = "reader", srcs = ["reader.cc"], hdrs = ["reader.h"], - deps = [":constants"] + if_not_mobile([ + deps = [ + ":constants", + "//tensorflow/core:protos_all_cc", + ] + if_not_mobile([ # TODO(b/111634734): :lib and :protos_all contain dependencies that # cannot be built on mobile platforms. Instead, include the appropriate # tf_lib depending on the build platform. 
"@com_google_absl//absl/memory:memory", "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", ]), ) @@ -57,7 +58,7 @@ tf_cc_test( name = "reader_test", srcs = ["reader_test.cc"], data = [ - ":saved_model_half_plus_two", + ":saved_model_test_files", ], linkstatic = 1, deps = [ @@ -149,7 +150,7 @@ tf_cc_test( name = "bundle_v2_test", srcs = ["bundle_v2_test.cc"], data = [ - ":saved_model_half_plus_two", + ":saved_model_test_files", ], linkstatic = 1, deps = [ @@ -166,7 +167,7 @@ tf_cc_test( name = "saved_model_bundle_test", srcs = ["saved_model_bundle_test.cc"], data = [ - ":saved_model_half_plus_two", + ":saved_model_test_files", ], linkstatic = 1, deps = [ @@ -188,7 +189,7 @@ tf_cc_test( name = "saved_model_bundle_lite_test", srcs = ["saved_model_bundle_lite_test.cc"], data = [ - ":saved_model_half_plus_two", + ":saved_model_test_files", ], linkstatic = 1, deps = [ @@ -209,11 +210,17 @@ tf_cc_test( py_binary( name = "testdata/generate_saved_models", srcs = ["testdata/generate_saved_models.py"], + data = [ + ":saved_model_asset_data", + ":saved_model_static_hashtable_asset_data", + ], python_version = "PY3", srcs_version = "PY3", deps = [ + "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:lookup_ops", "//tensorflow/python:tensor_spec", "//tensorflow/python:variables", "//tensorflow/python/compat:v2_compat", @@ -221,24 +228,47 @@ py_binary( "//tensorflow/python/module", "//tensorflow/python/saved_model", "//tensorflow/python/saved_model:save_options", + "//tensorflow/python/training/tracking", "@absl_py//absl:app", ], ) # TODO(b/32673259): add a test to continuously validate these files. filegroup( - name = "saved_model_half_plus_two", + name = "saved_model_test_files", srcs = glob([ + "testdata/AssetModule/**", "testdata/half_plus_two_pbtxt/**", "testdata/half_plus_two_main_op/**", "testdata/half_plus_two/**", "testdata/half_plus_two_v2/**", "testdata/x_plus_y_v2_debuginfo/**", "testdata/CyclicModule/**", + "testdata/StaticHashTableModule/**", "testdata/VarsAndArithmeticObjectGraph/**", + "testdata/fuzz_generated/**", ]), ) +alias( + name = "saved_model_half_plus_two", + actual = ":saved_model_test_files", +) + +filegroup( + name = "saved_model_asset_data", + srcs = [ + "testdata/test_asset.txt", + ], +) + +filegroup( + name = "saved_model_static_hashtable_asset_data", + srcs = [ + "testdata/static_hashtable_asset.txt", + ], +) + exports_files( glob([ "testdata/half_plus_two_pbtxt/**", @@ -248,5 +278,6 @@ exports_files( "testdata/x_plus_y_v2_debuginfo/**", "testdata/CyclicModule/**", "testdata/VarsAndArithmeticObjectGraph/**", + "testdata/fuzz_generated/**", ]), ) diff --git a/tensorflow/cc/saved_model/experimental/public/BUILD b/tensorflow/cc/saved_model/experimental/public/BUILD index 9640848ebf5..a0f8204c937 100644 --- a/tensorflow/cc/saved_model/experimental/public/BUILD +++ b/tensorflow/cc/saved_model/experimental/public/BUILD @@ -1,6 +1,8 @@ # Experimental C++ SavedModel Header Only APIs. See RFC # https://github.com/tensorflow/community/pull/207 +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + package( # This is intentionally public default_visibility = [ diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index ecefe7d0406..70d080a682f 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include "tensorflow/cc/saved_model/loader_util.h" #include "tensorflow/cc/saved_model/reader.h" #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/lib/io/path.h" @@ -73,26 +74,41 @@ uint64 GetLatencyMicroseconds(const uint64 start_microseconds) { // Ensure that constant tensors loaded from the saved model have valid shape. // Also ensure that constant nodes have a value assigned to them. // TODO(b/154763635): this is temporary and will be replaced with a better audit +static Status ValidateNode(const NodeDef& node) { + const auto node_iterator = node.attr().find("value"); + if (node_iterator != node.attr().end()) { + AttrValue node_value = node_iterator->second; + if (node_value.has_tensor()) { + const PartialTensorShape node_shape(node_value.tensor().tensor_shape()); + if (node_shape.num_elements() < 0) { + return errors::FailedPrecondition( + "Saved model contains node \"", node.name(), "\" (op \"", node.op(), + "\") which initializes from a tensor with ", + node_shape.num_elements(), " elements"); + } + } + } else if (node.op() == "Const") { + return errors::FailedPrecondition( + "Saved model contains node \"", node.name(), + "\" which is a constant tensor but no value has been provided"); + } + return Status::OK(); +} + static Status ValidateSavedTensors(const GraphDef& graph_def) { for (const auto& node : graph_def.node()) { - const auto node_iterator = node.attr().find("value"); - if (node_iterator != node.attr().end()) { - AttrValue node_value = node_iterator->second; - if (node_value.has_tensor()) { - const PartialTensorShape node_shape(node_value.tensor().tensor_shape()); - if (node_shape.num_elements() < 0) { - return errors::FailedPrecondition( - "Saved model contains node \"", node.name(), "\" (op \"", - node.op(), "\") which initializes from a tensor with ", - node_shape.num_elements(), " elements"); - } + TF_RETURN_IF_ERROR(ValidateNode(node)); + } + + if (graph_def.has_library()) { + const FunctionDefLibrary& library = graph_def.library(); + for (const auto& function : library.function()) { + for (const auto& node : function.node_def()) { + TF_RETURN_IF_ERROR(ValidateNode(node)); } - } else if (node.op() == "Const") { - return errors::FailedPrecondition( - "Saved model contains node \"", node.name(), - "\" which is a constant tensor but no value has been provided"); } } + return Status::OK(); } diff --git a/tensorflow/cc/saved_model/saved_model_bundle_test.cc b/tensorflow/cc/saved_model/saved_model_bundle_test.cc index 31f676920aa..127176002b9 100644 --- a/tensorflow/cc/saved_model/saved_model_bundle_test.cc +++ b/tensorflow/cc/saved_model/saved_model_bundle_test.cc @@ -45,6 +45,8 @@ constexpr char kTestFuzzGeneratedNegativeShape[] = "cc/saved_model/testdata/fuzz_generated/negative_shape"; constexpr char kTestFuzzGeneratedConstWithNoValue[] = "cc/saved_model/testdata/fuzz_generated/const_with_no_value"; +constexpr char kTestFuzzGeneratedBadNodeAttr[] = + "cc/saved_model/testdata/fuzz_generated/bad_node_attr"; class LoaderTest : public ::testing::Test { protected: @@ -308,6 +310,9 @@ TEST_F(LoaderTest, NegativeShapeDimension) { Status st = LoadSavedModel(session_options, run_options, export_dir, {kSavedModelTagServe}, &bundle); EXPECT_FALSE(st.ok()); + EXPECT_NE( + st.error_message().find("initializes from a tensor with -1 elements"), + std::string::npos); } TEST_F(LoaderTest, ConstNoValue) { @@ 
-320,6 +325,24 @@ TEST_F(LoaderTest, ConstNoValue) { Status st = LoadSavedModel(session_options, run_options, export_dir, {kSavedModelTagServe}, &bundle); EXPECT_FALSE(st.ok()); + EXPECT_NE( + st.error_message().find("constant tensor but no value has been provided"), + std::string::npos); +} + +TEST_F(LoaderTest, BadNodeAttr) { + SavedModelBundle bundle; + RunOptions run_options; + SessionOptions session_options; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestFuzzGeneratedBadNodeAttr); + Status st = LoadSavedModel(session_options, run_options, export_dir, + {kSavedModelTagServe}, &bundle); + EXPECT_FALSE(st.ok()); + EXPECT_NE( + st.error_message().find("constant tensor but no value has been provided"), + std::string::npos); } } // namespace diff --git a/tensorflow/cc/saved_model/testdata/AssetModule/assets/test_asset.txt b/tensorflow/cc/saved_model/testdata/AssetModule/assets/test_asset.txt new file mode 100644 index 00000000000..40d69b1aac4 --- /dev/null +++ b/tensorflow/cc/saved_model/testdata/AssetModule/assets/test_asset.txt @@ -0,0 +1 @@ +TEST ASSET FILE CONTENTS diff --git a/tensorflow/cc/saved_model/testdata/AssetModule/saved_model.pb b/tensorflow/cc/saved_model/testdata/AssetModule/saved_model.pb new file mode 100644 index 00000000000..4bf99e03c22 Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/AssetModule/saved_model.pb differ diff --git a/tensorflow/cc/saved_model/testdata/AssetModule/variables/variables.data-00000-of-00001 b/tensorflow/cc/saved_model/testdata/AssetModule/variables/variables.data-00000-of-00001 new file mode 100644 index 00000000000..4105bb4c15e Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/AssetModule/variables/variables.data-00000-of-00001 differ diff --git a/tensorflow/cc/saved_model/testdata/AssetModule/variables/variables.index b/tensorflow/cc/saved_model/testdata/AssetModule/variables/variables.index new file mode 100644 index 00000000000..3d903ca79a2 Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/AssetModule/variables/variables.index differ diff --git a/tensorflow/cc/saved_model/testdata/StaticHashTableModule/assets/static_hashtable_asset.txt b/tensorflow/cc/saved_model/testdata/StaticHashTableModule/assets/static_hashtable_asset.txt new file mode 100644 index 00000000000..e79f591665f --- /dev/null +++ b/tensorflow/cc/saved_model/testdata/StaticHashTableModule/assets/static_hashtable_asset.txt @@ -0,0 +1,4 @@ +foo +bar +baz +wombat diff --git a/tensorflow/cc/saved_model/testdata/StaticHashTableModule/saved_model.pb b/tensorflow/cc/saved_model/testdata/StaticHashTableModule/saved_model.pb new file mode 100644 index 00000000000..04e8ba62bdb Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/StaticHashTableModule/saved_model.pb differ diff --git a/tensorflow/cc/saved_model/testdata/StaticHashTableModule/variables/variables.data-00000-of-00001 b/tensorflow/cc/saved_model/testdata/StaticHashTableModule/variables/variables.data-00000-of-00001 new file mode 100644 index 00000000000..f6d62d9a51c Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/StaticHashTableModule/variables/variables.data-00000-of-00001 differ diff --git a/tensorflow/cc/saved_model/testdata/StaticHashTableModule/variables/variables.index b/tensorflow/cc/saved_model/testdata/StaticHashTableModule/variables/variables.index new file mode 100644 index 00000000000..df6c85e5783 Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/StaticHashTableModule/variables/variables.index 
differ diff --git a/tensorflow/examples/android/__init__.py b/tensorflow/cc/saved_model/testdata/fuzz_generated/bad_node_attr/assets/empty similarity index 100% rename from tensorflow/examples/android/__init__.py rename to tensorflow/cc/saved_model/testdata/fuzz_generated/bad_node_attr/assets/empty diff --git a/tensorflow/cc/saved_model/testdata/fuzz_generated/bad_node_attr/saved_model.pb b/tensorflow/cc/saved_model/testdata/fuzz_generated/bad_node_attr/saved_model.pb new file mode 100644 index 00000000000..0b33dbe7352 Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/fuzz_generated/bad_node_attr/saved_model.pb differ diff --git a/tensorflow/cc/saved_model/testdata/fuzz_generated/bad_node_attr/variables/variables.data-00000-of-00001 b/tensorflow/cc/saved_model/testdata/fuzz_generated/bad_node_attr/variables/variables.data-00000-of-00001 new file mode 100644 index 00000000000..3fd3ba2223d Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/fuzz_generated/bad_node_attr/variables/variables.data-00000-of-00001 differ diff --git a/tensorflow/cc/saved_model/testdata/fuzz_generated/bad_node_attr/variables/variables.index b/tensorflow/cc/saved_model/testdata/fuzz_generated/bad_node_attr/variables/variables.index new file mode 100644 index 00000000000..7357e8d57ed Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/fuzz_generated/bad_node_attr/variables/variables.index differ diff --git a/tensorflow/examples/android/jni/__init__.py b/tensorflow/cc/saved_model/testdata/fuzz_generated/const_with_no_value/assets/empty similarity index 100% rename from tensorflow/examples/android/jni/__init__.py rename to tensorflow/cc/saved_model/testdata/fuzz_generated/const_with_no_value/assets/empty diff --git a/tensorflow/cc/saved_model/testdata/fuzz_generated/const_with_no_value b/tensorflow/cc/saved_model/testdata/fuzz_generated/const_with_no_value/saved_model.pb similarity index 100% rename from tensorflow/cc/saved_model/testdata/fuzz_generated/const_with_no_value rename to tensorflow/cc/saved_model/testdata/fuzz_generated/const_with_no_value/saved_model.pb diff --git a/tensorflow/cc/saved_model/testdata/fuzz_generated/const_with_no_value/variables/variables.data-00000-of-00001 b/tensorflow/cc/saved_model/testdata/fuzz_generated/const_with_no_value/variables/variables.data-00000-of-00001 new file mode 100644 index 00000000000..3fd3ba2223d Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/fuzz_generated/const_with_no_value/variables/variables.data-00000-of-00001 differ diff --git a/tensorflow/cc/saved_model/testdata/fuzz_generated/const_with_no_value/variables/variables.index b/tensorflow/cc/saved_model/testdata/fuzz_generated/const_with_no_value/variables/variables.index new file mode 100644 index 00000000000..7357e8d57ed Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/fuzz_generated/const_with_no_value/variables/variables.index differ diff --git a/tensorflow/examples/tutorials/__init__.py b/tensorflow/cc/saved_model/testdata/fuzz_generated/negative_shape/assets/empty similarity index 100% rename from tensorflow/examples/tutorials/__init__.py rename to tensorflow/cc/saved_model/testdata/fuzz_generated/negative_shape/assets/empty diff --git a/tensorflow/cc/saved_model/testdata/fuzz_generated/negative_shape b/tensorflow/cc/saved_model/testdata/fuzz_generated/negative_shape/saved_model.pb similarity index 100% rename from tensorflow/cc/saved_model/testdata/fuzz_generated/negative_shape rename to 
tensorflow/cc/saved_model/testdata/fuzz_generated/negative_shape/saved_model.pb diff --git a/tensorflow/cc/saved_model/testdata/fuzz_generated/negative_shape/variables/variables.data-00000-of-00001 b/tensorflow/cc/saved_model/testdata/fuzz_generated/negative_shape/variables/variables.data-00000-of-00001 new file mode 100644 index 00000000000..3fd3ba2223d Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/fuzz_generated/negative_shape/variables/variables.data-00000-of-00001 differ diff --git a/tensorflow/cc/saved_model/testdata/fuzz_generated/negative_shape/variables/variables.index b/tensorflow/cc/saved_model/testdata/fuzz_generated/negative_shape/variables/variables.index new file mode 100644 index 00000000000..7357e8d57ed Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/fuzz_generated/negative_shape/variables/variables.index differ diff --git a/tensorflow/cc/saved_model/testdata/generate_saved_models.py b/tensorflow/cc/saved_model/testdata/generate_saved_models.py index 5f39ae0651d..2b64cf52096 100644 --- a/tensorflow/cc/saved_model/testdata/generate_saved_models.py +++ b/tensorflow/cc/saved_model/testdata/generate_saved_models.py @@ -29,9 +29,13 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_spec from tensorflow.python.module import module +from tensorflow.python.ops import io_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables +from tensorflow.python.platform import test from tensorflow.python.saved_model import save_options from tensorflow.python.saved_model import saved_model +from tensorflow.python.training.tracking import tracking class VarsAndArithmeticObjectGraph(module.Module): @@ -68,9 +72,42 @@ class CyclicModule(module.Module): self.child = ReferencesParent(self) +class AssetModule(module.Module): + + def __init__(self): + self.asset = tracking.Asset( + test.test_src_dir_path("cc/saved_model/testdata/test_asset.txt")) + + @def_function.function(input_signature=[]) + def read_file(self): + return io_ops.read_file(self.asset) + + +class StaticHashTableModule(module.Module): + """A module with an Asset, StaticHashTable, and a lookup function.""" + + def __init__(self): + self.asset = tracking.Asset( + test.test_src_dir_path( + "cc/saved_model/testdata/static_hashtable_asset.txt")) + self.table = lookup_ops.StaticHashTable( + lookup_ops.TextFileInitializer(self.asset, dtypes.string, + lookup_ops.TextFileIndex.WHOLE_LINE, + dtypes.int64, + lookup_ops.TextFileIndex.LINE_NUMBER), + -1) + + @def_function.function( + input_signature=[tensor_spec.TensorSpec(shape=None, dtype=dtypes.string)]) + def lookup(self, word): + return self.table.lookup(word) + + MODULE_CTORS = { "VarsAndArithmeticObjectGraph": VarsAndArithmeticObjectGraph, "CyclicModule": CyclicModule, + "AssetModule": AssetModule, + "StaticHashTableModule": StaticHashTableModule, } diff --git a/tensorflow/cc/saved_model/testdata/static_hashtable_asset.txt b/tensorflow/cc/saved_model/testdata/static_hashtable_asset.txt new file mode 100644 index 00000000000..e79f591665f --- /dev/null +++ b/tensorflow/cc/saved_model/testdata/static_hashtable_asset.txt @@ -0,0 +1,4 @@ +foo +bar +baz +wombat diff --git a/tensorflow/cc/saved_model/testdata/test_asset.txt b/tensorflow/cc/saved_model/testdata/test_asset.txt new file mode 100644 index 00000000000..40d69b1aac4 --- /dev/null +++ b/tensorflow/cc/saved_model/testdata/test_asset.txt @@ -0,0 +1 @@ +TEST ASSET FILE CONTENTS diff 
--git a/tensorflow/cc/tools/BUILD b/tensorflow/cc/tools/BUILD index a192c4bdb18..e8e128f9a16 100644 --- a/tensorflow/cc/tools/BUILD +++ b/tensorflow/cc/tools/BUILD @@ -1,6 +1,7 @@ # Description: # TensorFlow cc tools. +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", "tf_cc_test", @@ -11,8 +12,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - cc_library( name = "freeze_saved_model", srcs = ["freeze_saved_model.cc"], diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index ff255dd9cc1..4a41caf1d40 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") load("//tensorflow/core/platform:build_config.bzl", "if_llvm_aarch64_available", "if_llvm_system_z_available") @@ -74,7 +75,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//llvm:Target", "@llvm-project//llvm:X86CodeGen", # fixdeps: keep - "//tensorflow/core:regexp_internal", + "//tensorflow/core/platform:regexp", ] + if_llvm_system_z_available([ "@llvm-project//llvm:SystemZCodeGen", # fixdeps: keep ]) + if_llvm_aarch64_available([ diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 5f6b3dc7101..06745de647b 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -1,3 +1,8 @@ +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "filegroup") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "genrule") load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") @@ -331,9 +336,9 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/service:hlo_profile_printer", "//tensorflow/core:lib", - "//tensorflow/core:regexp_internal", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/platform:regexp", "//third_party/eigen3", "@com_google_absl//absl/strings", ], @@ -554,9 +559,9 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/service:hlo_profile_printer", "//tensorflow/core:lib", - "//tensorflow/core:regexp_internal", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/platform:regexp", "//third_party/eigen3", "@com_google_absl//absl/strings", ], diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 29f37bf7498..742cb308b3c 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -127,7 +127,7 @@ def tf_library( "$(location " + tfcompile_tool + ")" + " --config=$(location " + config + ")" + " --dump_fetch_nodes > $@"), - tools = [tfcompile_tool], + exec_tools = [tfcompile_tool], # Run tfcompile on the build host, rather than forge, since it's # typically way faster on the local machine. 
local = 1, @@ -162,7 +162,7 @@ def tf_library( "//tensorflow/python/tools:freeze_graph)" + freeze_args ), - tools = ["//tensorflow/python/tools:freeze_graph"], + exec_tools = ["//tensorflow/python/tools:freeze_graph"], tags = tags, ) tfcompile_graph = freeze_file @@ -242,7 +242,7 @@ def tf_library( " --out_function_object=$(@D)/" + function_object_file + " " + flags + " " + profiling_flag + " " + mlir_flag + " " + traceme_flag ), - tools = [tfcompile_tool], + exec_tools = [tfcompile_tool], visibility = visibility, testonly = testonly, # Run tfcompile on the build host since it's typically faster on the @@ -281,7 +281,7 @@ def tf_library( " --out_session_module=$(@D)/" + session_module_pb + " " + flags ), - tools = [tfcompile_tool], + exec_tools = [tfcompile_tool], visibility = visibility, testonly = testonly, local = 1, diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 35c6a8b0357..deb3396d89c 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -1,5 +1,19 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + +# buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "tf_cc_test") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "if_libtpu", "tf_copts") load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "cc_header_only_library") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "filegroup") + +# buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library", "tf_jit_compilation_passes_extra_deps") load("//tensorflow/core/platform:build_config.bzl", "tf_additional_all_protos", "tf_proto_library") load("//tensorflow/core/platform:build_config_root.bzl", "tf_cuda_tests_tags") @@ -15,6 +29,7 @@ package_group( "//tensorflow/compiler/tf2xla:internal", ], packages = [ + "//tensorflow/c/...", "//tensorflow/compiler/tests/...", "//tensorflow/python/...", ], @@ -65,8 +80,10 @@ cc_library( "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", - "//tensorflow/compiler/xla/service:cpu_plugin", - ], + ] + if_libtpu( + if_false = ["//tensorflow/compiler/xla/service:cpu_plugin"], + if_true = [], + ), alwayslink = 1, ) @@ -93,15 +110,19 @@ cc_library( ":jit_compilation_passes", ":xla_device", ":xla_kernel_creator", # buildcleaner: keep + "@com_google_absl//absl/memory", "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla/kernels:xla_ops", - "//tensorflow/compiler/xla/service:cpu_plugin", # buildcleaner: keep "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", - "@com_google_absl//absl/memory", - ], + ] + if_libtpu( + if_false = [ + "//tensorflow/compiler/xla/service:cpu_plugin", # buildcleaner: keep + ], + if_true = [], + ), alwayslink = 1, ) @@ -114,17 +135,21 @@ cc_library( ":jit_compilation_passes", ":xla_device", ":xla_kernel_creator", # buildcleaner: keep + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla/kernels:xla_ops", - "//tensorflow/compiler/xla/service:gpu_plugin", # 
buildcleaner: keep "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core/common_runtime/gpu:gpu_init", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - ], + ] + if_libtpu( + if_false = [ + "//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep + ], + if_true = [], + ), alwayslink = 1, ) @@ -182,7 +207,7 @@ XLA_DEVICE_DEPS = [ "//tensorflow/core:resource_variable_ops_op_lib", "//tensorflow/core:sendrecv_ops_op_lib", "//tensorflow/core:state_ops_op_lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:fifo_queue", "//tensorflow/core/kernels:function_ops", @@ -261,6 +286,7 @@ cc_library( "//tensorflow/compiler/xla:parse_flags_from_env", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/base", "@com_google_absl//absl/strings", ], @@ -269,7 +295,7 @@ cc_library( # Header-only version of "flags" library, for linking from the shared object # without ODR violations. cc_library( - name = "flags_headers_only", + name = "flags_headers", hdrs = ["flags.h"], visibility = [":friends"], deps = [ @@ -280,6 +306,11 @@ cc_library( ], ) +cc_header_only_library( + name = "flags_headers_only", + deps = [":flags_headers"], +) + cc_library( name = "common", srcs = [ @@ -328,10 +359,17 @@ cc_library( name = "xla_compilation_cache", srcs = ["xla_compilation_cache.cc"], hdrs = ["xla_compilation_cache.h"], + copts = tf_copts(), deps = [ + ":flags", ":xla_activity_listener", ":xla_activity_proto_cc", - "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -346,13 +384,13 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:logging", - "@com_google_absl//absl/base", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", - "@com_google_absl//absl/types:span", - ], + ] + if_libtpu( + if_false = [ + "//tensorflow/compiler/mlir:array_container_utils", + "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes", + ], + if_true = [], + ), ) tf_cc_test( @@ -361,8 +399,11 @@ tf_cc_test( "xla_compilation_cache_test.cc", ], deps = [ + ":flags", ":xla_compilation_cache", + ":xla_cpu_jit", "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/xla/client:client_library", "//tensorflow/core:test", "//tensorflow/core:test_main", ], @@ -382,6 +423,72 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "get_compiler_ir", + srcs = ["get_compiler_ir.cc"], + hdrs = ["get_compiler_ir.h"], + visibility = [ + ":internal", + "//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__", + "//tensorflow/core/common_runtime/eager:__pkg__", + ], + deps = [ + ":common", + ":compilability_check_util", + ":flags", + ":xla_device_no_jit_rewrite_registration", + ":xla_launch_util", + "//tensorflow/compiler/tf2xla:xla_compiler", + 
"//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/service:hlo_graph_dumper", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core/common_runtime:core_cpu_internal", + "//tensorflow/core/common_runtime/eager:tensor_handle", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + ], + alwayslink = 1, +) + +# Header-only version of "flags" library, for linking from the shared object +# without ODR violations. +cc_library( + name = "get_compiler_ir_hdrs", + textual_hdrs = ["get_compiler_ir.h"], + visibility = [ + ":internal", + "//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__", + "//tensorflow/core/common_runtime/eager:__pkg__", + ], + deps = [ + "//tensorflow/compiler/xla:statusor", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + ], +) + +cc_header_only_library( + name = "get_compiler_ir_hdrs_only", + deps = [":get_compiler_ir_hdrs"], +) + +# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library. +cc_header_only_library( + name = "xla_jit_headers_lib", + visibility = ["//visibility:public"], + deps = [ + ":xla_cpu_device", + ":xla_cpu_jit", + ":xla_gpu_device", + ":xla_gpu_jit", + ], +) + cc_library( name = "xla_kernel_creator", srcs = [ @@ -604,7 +711,6 @@ cc_library( ":flags", ":resource_operation_safety_analysis", ":shape_inference_helpers", - ":union_find", ":xla_activity_listener", ":xla_cluster_util", "//tensorflow/cc:cc_ops", @@ -623,8 +729,8 @@ cc_library( "//tensorflow/compiler/tf2xla/cc:xla_ops", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:union_find", "//tensorflow/compiler/xla:util", - "//tensorflow/core:all_kernels", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -701,11 +807,6 @@ tf_cc_test( ], ) -cc_library( - name = "union_find", - hdrs = ["union_find.h"], -) - tf_cc_test( name = "deadness_analysis_test", size = "small", @@ -800,6 +901,7 @@ tf_cc_test( "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla:test", + "//tensorflow/core:all_kernels", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -886,7 +988,6 @@ cc_library( ":device_util", ":flags", ":resource_operation_safety_analysis", - ":union_find", ":xla_activity_listener", ":xla_activity_proto_cc", ":xla_cluster_util", @@ -895,6 +996,7 @@ cc_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:union_find", "//tensorflow/compiler/xla:util", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", @@ -918,6 +1020,7 @@ tf_cc_test( ":xla_cpu_jit", "//tensorflow/cc:cc_ops", "//tensorflow/cc:function_ops", + "//tensorflow/cc:functional_ops", "//tensorflow/cc:ops", "//tensorflow/cc:scope", "//tensorflow/compiler/tf2xla:test_util", @@ -944,11 +1047,12 @@ tf_cc_test( ":xla_cpu_jit", "//tensorflow/cc:cc_ops", "//tensorflow/cc:ops", + "//tensorflow/core:all_kernels", "//tensorflow/core:core_cpu", - "//tensorflow/core:direct_session_internal", "//tensorflow/core:framework", 
"//tensorflow/core:ops", "//tensorflow/core:test", + "//tensorflow/core/common_runtime:direct_session_internal", "//tensorflow/core/kernels:cwise_op", "//tensorflow/core/kernels:matmul_op", "//tensorflow/core/kernels:partitioned_function_ops", @@ -997,15 +1101,3 @@ cc_library( ], alwayslink = 1, ) - -# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library. -cc_header_only_library( - name = "xla_jit_headers_lib", - visibility = ["//visibility:public"], - deps = [ - ":xla_cpu_device", - ":xla_cpu_jit", - ":xla_gpu_device", - ":xla_gpu_jit", - ], -) diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc index 8463c788496..160ea83585d 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc @@ -130,17 +130,6 @@ FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) { return fdef_lib; } -FunctionDefLibrary CreateFunctionDefLibWithInt32Input(const string& name) { - FunctionDefLibrary fdef_lib; - FunctionDef func = FunctionDefHelper::Create( - /*function_name=*/name, /*in_def=*/{"in: int32"}, - /*out_def=*/{"out: int32"}, - /*attr_def=*/{}, /*node_def=*/{{{"out"}, "Identity", {"in"}}}, - /*ret_def=*/{{"out", "out:output:0"}}); - *fdef_lib.add_function() = std::move(func); - return fdef_lib; -} - TEST_F(BuildXlaOpsTest, ControlDepsPreserved) { const char* kXlaDeviceName = "/job:worker/replica:0/task:0/device:XLA_CPU:0"; Scope root = Scope::NewRootScope().WithDevice(kXlaDeviceName).ExitOnError(); @@ -269,6 +258,17 @@ TEST_F(BuildXlaOpsTest, NoExtraMergeForEdgeToSink) { } #ifdef GOOGLE_CUDA +FunctionDefLibrary CreateFunctionDefLibWithInt32Input(const string& name) { + FunctionDefLibrary fdef_lib; + FunctionDef func = FunctionDefHelper::Create( + /*function_name=*/name, /*in_def=*/{"in: int32"}, + /*out_def=*/{"out: int32"}, + /*attr_def=*/{}, /*node_def=*/{{{"out"}, "Identity", {"in"}}}, + /*ret_def=*/{{"out", "out:output:0"}}); + *fdef_lib.add_function() = std::move(func); + return fdef_lib; +} + // This tests a rewrite that only makes sense and is active in a CUDA-enabled // build. Specifically we check that we insert an IdentityN op to avoid extra // device-to-host copies. diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc index 6d4bc51f1b2..62e121420c3 100644 --- a/tensorflow/compiler/jit/compilability_check_util.cc +++ b/tensorflow/compiler/jit/compilability_check_util.cc @@ -36,7 +36,6 @@ limitations under the License. #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" -#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/jit/xla_activity.pb.h" #include "tensorflow/compiler/jit/xla_activity_listener.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" @@ -44,6 +43,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_constructor.h" @@ -84,6 +84,60 @@ Status MakeCallNodeFromAttribute(const Node& node, const std::string& attr_name, return Status::OK(); } +xla::StatusOr> MakeCallNodesFromAttribute( + const Node& node, absl::string_view attr_name, + absl::string_view call_name) { + std::vector attr_lists; + TF_RETURN_IF_ERROR(GetNodeAttr(node.attrs(), attr_name, &attr_lists)); + + std::vector out; + for (int i = 0; i < attr_lists.size(); i++) { + out.emplace_back(); + NodeDef& inserted = out.back(); + inserted.set_name(absl::StrCat(call_name, "_", i)); + inserted.set_op(attr_lists[i].name()); + *inserted.mutable_attr() = attr_lists[i].attr(); + } + return out; +} + +// Utility which searches for values in a sorted list by scanning over it once. +// No matter how many times ScanForValue is called, the list is scanned at most +// once. However, if a call to ScanForValue skips over a value, that value is +// not revisited in future calls to ScanForValue, so callers must take +// care to order their calls. +// +// Useful for merging multiple sorted lists in O(n) time. +class SinglePassSearch { + public: + // Creates a SinglePassSearch object that can be used to search in `values`. + // Does not take ownership of `values`. `values` must outlive this. + // `values` must be sorted. + explicit SinglePassSearch(absl::Span values) + : current_index_(0), values_(values) {} + + // Scans forward in the vector looking for "value", updating the internal + // position in to the vector. + // Returns true iff the vector contains the given value at or after current + // position. + // Not thread-safe. + bool ScanForValue(int value) { + while (current_index_ < values_.size() && + values_[current_index_] <= value) { + if (values_[current_index_] == value) { + current_index_++; + return true; + } + current_index_++; + } + return false; + } + + private: + int current_index_; + const absl::Span values_; +}; + } // anonymous namespace RecursiveCompilabilityChecker::UncompilableNodesMap @@ -190,6 +244,30 @@ bool RecursiveCompilabilityChecker::IsCompilableIf( return is_compilable; } +bool RecursiveCompilabilityChecker::IsCompilableCase( + const Node& case_node, FunctionLibraryRuntime* lib_runtime, + std::vector* stack_trace, + NameAttrList* encapsulating_function, + RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes) + const { + xla::StatusOr> calls = + MakeCallNodesFromAttribute(case_node, "branches", "branch"); + if (!calls.ok()) { + VLOG(2) << "Rejecting node " << case_node.name() << ": " + << "missing attribute 'branches'"; + return false; + } + + bool is_compilable = true; + + for (const NodeDef& call : *calls) { + is_compilable &= + IsCompilableCall(call, lib_runtime, stack_trace, encapsulating_function, + uncompilable_nodes); + } + return is_compilable; +} + // Tests whether 'while_node' is a completely compilable loop. // Every operator in the condition and body functions must be compilable for a // while loop to be compilable. 
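The SinglePassSearch helper added above is easy to misuse: queries must be issued in non-decreasing order, and a value skipped over is never revisited. A short hedged sketch of the intended calling pattern follows; it assumes the class defined above is in scope, and the ClassifyArguments function and index values are arbitrary illustrations.

#include <vector>

// Assumes the SinglePassSearch class defined above is visible here.
void ClassifyArguments() {
  // Both index lists must be sorted, mirroring how constant_arg_indices and
  // resource_arg_indices are used when classifying function arguments.
  std::vector<int> constant_arg_indices = {0, 3};
  std::vector<int> resource_arg_indices = {2, 5};
  SinglePassSearch constants_search(constant_arg_indices);
  SinglePassSearch resources_search(resource_arg_indices);
  for (int i = 0; i < 6; ++i) {
    // i is strictly increasing, so each list is scanned at most once across
    // the whole loop.
    const bool is_constant = constants_search.ScanForValue(i);
    const bool is_resource = resources_search.ScanForValue(i);
    (void)is_constant;
    (void)is_resource;
  }
}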
@@ -380,6 +458,13 @@ bool RecursiveCompilabilityChecker::IsCompilableNode( return false; } + if (op_filter_.require_always_compilable && node.IsCaseNode() && + !IsCompilableCase(node, lib_runtime, stack_trace, encapsulating_function, + uncompilable_nodes)) { + LogNotCompilable(node, "unsupported case"); + return false; + } + if (!op_filter_.allow_stateful_rng_ops && IsStatefulRandomOp(node.type_string())) { absl::string_view uncompilable_reason = "stateful random op"; @@ -518,23 +603,23 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter( } } +// Returns `true` iff node has a given `attr` set to `true`. Returns `false` +// both for the missing attr, and the attr set to `false`. +static bool HasBoolAttr(const NodeDef& node, const char* attr) { + const auto& it = node.attr().find(attr); + return it != node.attr().end() && it->second.b(); +} + bool CanCreateXlaKernel(const NodeDef& node_def) { - // If kXlaMustCompileAttr is set on the node_def, use its value. - const auto& it = node_def.attr().find(kXlaMustCompileAttr); - return it != node_def.attr().end() && it->second.b(); + return HasBoolAttr(node_def, kXlaMustCompileAttr); } Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr, - const NodeDef& node_def, + const NameAttrList& function, const FunctionBody** fbody, std::vector* constant_arg_indices, std::vector* resource_arg_indices) { FunctionLibraryRuntime::Handle handle; - // If node_def is not instantiable, e.g., the function does not exist, - // simply bail out. - NameAttrList function; - TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function)); - TF_RETURN_IF_ERROR( flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle)); *fbody = flr->GetFunctionBody(handle); @@ -564,4 +649,96 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr, return Status::OK(); } +tensorflow::MemoryTypeVector GetInputMemoryTypes( + const tensorflow::FunctionBody* fbody, + absl::Span constant_arg_indices, + absl::Span resource_arg_indices) { + // Set input and output memory types. + tensorflow::MemoryTypeVector input_memory_types(fbody->arg_types.size(), + tensorflow::DEVICE_MEMORY); + // These indices are used only for optimization purposes. They allow us + // to loop over constant_arg_indices and resource_arg_indices only once + // while iterating over all the function arguments checking if it is a + // resource or a constant. + // The reason we optimized this code is because functions can have a lot of + // captured arguments. For example, the backward pass of ResNet50 takes in all + // 214 variables and a similar number of activations. + SinglePassSearch constants_search(constant_arg_indices); + SinglePassSearch resources_search(resource_arg_indices); + for (size_t i = 0; i < fbody->arg_types.size(); ++i) { + if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) { + // Compile-time constants and resource handles are expected to be in + // host memory. 
+ input_memory_types[i] = tensorflow::HOST_MEMORY; + } + } + return input_memory_types; +} + +tensorflow::MemoryTypeVector GetOutputMemoryTypes( + const tensorflow::FunctionBody* fbody) { + tensorflow::MemoryTypeVector output_memory_types(fbody->ret_types.size(), + tensorflow::DEVICE_MEMORY); + for (size_t i = 0; i < fbody->ret_types.size(); ++i) { + if (fbody->ret_types[i] == tensorflow::DT_RESOURCE) { + output_memory_types[i] = tensorflow::HOST_MEMORY; + } + } + return output_memory_types; +} + +static auto const ops_triggering_xla_compilation = + new absl::flat_hash_set{"XlaBroadcastHelper", + "XlaConv", + "XlaDequantize", + "XlaDot", + "XlaDynamicSlice", + "XlaDynamicUpdateSlice", + "XlaEinsum", + "XlaGather", + "XlaIf", + "XlaKeyValueSort", + "XlaPad", + "XlaRecv", + "XlaReduce", + "XlaReduceWindow", + "XlaReplicaId", + "XlaScatter", + "XlaSelectAndScatter", + "XlaSelfAdjointEig", + "XlaSend", + "XlaSharding", + "XlaSort", + "XlaSpmdFullToShardShape", + "XlaSpmdShardToFullShape", + "XlaSvd", + "XlaWhile"}; + +static bool NodeCanTriggerXlaCompilation(const NodeDef& node) { + return node.attr().find(kXlaClusterIdAttr) != node.attr().end() || + HasBoolAttr(node, kXlaMustCompileAttr) || + HasBoolAttr(node, kXlaCompileAttr) || + HasBoolAttr(node, kXlaScopeAttr) || + HasBoolAttr(node, kXlaInternalScopeAttr) || + ops_triggering_xla_compilation->count(node.op()); +} + +bool CanTriggerXlaCompilation(const GraphDef& graph) { + for (const FunctionDef& function : graph.library().function()) { + for (const NodeDef& node : function.node_def()) { + if (NodeCanTriggerXlaCompilation(node)) { + return true; + } + } + } + + for (const NodeDef& node : graph.node()) { + if (NodeCanTriggerXlaCompilation(node)) { + return true; + } + } + + return false; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/compilability_check_util.h b/tensorflow/compiler/jit/compilability_check_util.h index 3b20784cc29..65da072483b 100644 --- a/tensorflow/compiler/jit/compilability_check_util.h +++ b/tensorflow/compiler/jit/compilability_check_util.h @@ -26,11 +26,11 @@ limitations under the License. #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" -#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_constructor.h" @@ -124,11 +124,16 @@ class RecursiveCompilabilityChecker { // Whether ops known to have numerical accuracy issues should be considered // compilable.. bool allow_inaccurate_ops = false; + + // Require the function to be always compilable, regardless whether some + // control flow branches might be dead for a given input. 
+ bool require_always_compilable = false; }; - RecursiveCompilabilityChecker(const OperationFilter* op_filter, - const DeviceType* jit_device_type) - : op_filter_(*op_filter), jit_device_type_(*jit_device_type) {} + RecursiveCompilabilityChecker(OperationFilter op_filter, + DeviceType jit_device_type) + : op_filter_(std::move(op_filter)), + jit_device_type_(std::move(jit_device_type)) {} using UncompilableNodesMap = std::map* stack_trace, + NameAttrList* encapsulating_function, + UncompilableNodesMap* uncompilable_nodes) const; + // Returns compilability of node def retrieved from `node`'s attribute with // name `attr_name`. bool ExtractNodeDefAndCheckCompilability( @@ -259,21 +272,20 @@ class RecursiveCompilabilityChecker { // Make sure we don't recurse infinitely on recursive functions. const size_t kMaxRecursionDepth = 10; - const OperationFilter& op_filter_; - const DeviceType& jit_device_type_; + const OperationFilter op_filter_; + const DeviceType jit_device_type_; }; RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter( const XlaOpRegistry::DeviceRegistration& registration); -// Given a FunctionLibraryRuntime and a NodeDef calling a function in the -// runtime, returns this function's body in `fbody` as well as the indices -// of its constant and resource arguments. +// Given a FunctionLibraryRuntime and a `function`, returns this function's body +// in `fbody` as well as the indices of its constant and resource arguments. // `fbody` is owned by `flr`. // `constant_arg_indices` and `resource_arg_indices` should be empty vector. // They are sorted in ascending order on this function's return. Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr, - const NodeDef& node_def, + const NameAttrList& function, const FunctionBody** fbody, std::vector* constant_arg_indices, std::vector* resource_arg_indices); @@ -282,6 +294,44 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr, // set. bool CanCreateXlaKernel(const NodeDef& node_def); +// Returns memory types for the input. +// `constant_arg_indices` and `resource_arg_indices` are sorted arrays of +// indices corresponding to constant and resource arguments respectively. +// +// One might wonder, about the case where a compile-time constant argument +// (which must be in host memory) is also used as an input into an op, +// e.g. `Add`, that expects its inputs in device memory. Here is how it +// works now. +// First, what do we mean by "op expects an input in XYZ memory"? +// There are two types of "ops" here: the tf2xla kernel and the HLO +// computation it builds. The tf2xla kernel needs to retrieve the actual +// numeric value of the compile-time constant tensors, so it really expects +// them to be on in host memory. However, for other inputs, it refers to them +// using xla::ComputationDataHandle, which is just a symbolic handle that +// xla::ComputationBuilder assigns. How does this handle gets assigned for +// constant arguments? Even constant arguments get an _Arg node in the graph +// instantiated for Function compilation. The tf2xla kernel for constant _Arg +// nodes takes the constant value, converts it to XlaLiteral, and feeds it +// to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This +// constant XlaLiteral is included in the HLO graph, and subsequently, in +// the actual executable, which is copied to the device before being +// executed. Thus, when this executable runs, the constant is available in +// device memory. 
+tensorflow::MemoryTypeVector GetInputMemoryTypes( + const tensorflow::FunctionBody* fbody, + absl::Span constant_arg_indices, + absl::Span resource_arg_indices); + +// Returns output memory types. +// +// XlaLaunch kernel keeps all outputs (including constants, which it copies), +// in device memory except for resources. +tensorflow::MemoryTypeVector GetOutputMemoryTypes( + const tensorflow::FunctionBody* fbody); + +// Check whether graph can trigger XLA compilation. +bool CanTriggerXlaCompilation(const GraphDef& graph); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_ diff --git a/tensorflow/compiler/jit/compilability_check_util_test.cc b/tensorflow/compiler/jit/compilability_check_util_test.cc index 3ea38e69ad9..9058b129589 100644 --- a/tensorflow/compiler/jit/compilability_check_util_test.cc +++ b/tensorflow/compiler/jit/compilability_check_util_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/functional_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -33,7 +34,16 @@ limitations under the License. namespace tensorflow { namespace { +AttrValue FuncListAttr(const absl::Span names) { + AttrValue attr; + for (const char* name : names) { + attr.mutable_list()->add_func()->set_name(name); + } + return attr; +} + constexpr char kFunctionalIfNodeName[] = "If"; +constexpr char kFunctionalCaseNodeName[] = "Case"; constexpr char kFunctionalWhileNodeName[] = "While"; constexpr char kCompilableFunctionName[] = "CompilableFn"; constexpr char kCompilableFunctionNodeName[] = "n_c"; @@ -75,8 +85,12 @@ class CompilabilityCheckUtilTest : public ::testing::Test { op_filter_.allow_inaccurate_ops = false; op_filter_.allow_slow_ops = false; - checker_ = absl::make_unique(&op_filter_, - &device_type_); + checker_ = CreateCompilabilityChecker(); + } + + std::unique_ptr CreateCompilabilityChecker() { + return absl::make_unique(op_filter_, + device_type_); } FunctionLibraryRuntime* GetFunctionLibraryRuntime() { @@ -354,5 +368,161 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) { "unsupported op")); } +TEST_F(CompilabilityCheckUtilTest, CheckFunctionalCaseNode) { + FunctionDefLibrary flib; + *flib.add_function() = FunctionDefHelper::Define( + /*Function*/ kUncompilableFunctionName, + /*Inputs*/ {"n_a:float"}, + /*Outputs*/ {"n_c_uncompilable:float"}, + /*Attributes*/ {}, + // Node info + {{{kUncompilableFunctionNodeName}, "MissingKernel", {"n_a"}}}); + *flib.add_function() = FunctionDefHelper::Define( + /*Function*/ kUncompilableFunctionTwoName, + /*Inputs*/ {"n_a:float"}, + /*Outputs*/ {"n_d_uncompilable:float"}, + /*Attribute*/ {}, + // Node info + {{{kUncompilableFunctionNodeTwoName}, "MissingKernel", {"n_a"}}}); + + Scope root = Scope::NewRootScope().ExitOnError(); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib)); + auto branch_index = ops::Placeholder(root.WithOpName("pred"), DT_INT32); + auto placeholder = ops::Placeholder(root.WithOpName("A"), DT_INT32); + std::vector inputes( + {NodeBuilder::NodeOut(placeholder.node())}); + Node* case_node; + TF_ASSERT_OK( + NodeBuilder(kFunctionalCaseNodeName, "Case", &root.graph()->flib_def()) + .Input(branch_index.node()) + .Input(inputes) + .Attr("branches", FuncListAttr({kUncompilableFunctionName, + kUncompilableFunctionTwoName})) + 
.Attr("Tout", {DT_INT32}) + .Finalize(root.graph(), &case_node)); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib)); + + auto case_node_it = std::find_if( + graph->nodes().begin(), graph->nodes().end(), + [&](const Node* n) { return n->name() == kFunctionalCaseNodeName; }); + EXPECT_NE(case_node_it, graph->nodes().end()); + auto* flib_runtime = GetFunctionLibraryRuntime(); + + op_filter_.require_always_compilable = false; + checker_ = CreateCompilabilityChecker(); + EXPECT_TRUE(checker_->IsCompilableNode(**case_node_it, flib_runtime)); + op_filter_.require_always_compilable = true; + checker_ = CreateCompilabilityChecker(); + EXPECT_FALSE(checker_->IsCompilableNode(**case_node_it, flib_runtime)); +} + +TEST_F(CompilabilityCheckUtilTest, TestCanNotTriggerXlaCompilation) { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Scope root = Scope::NewRootScope().ExitOnError(); + FunctionDefLibrary library; + + FunctionDef identity_func = FunctionDefHelper::Create( + "IdentityFunc", + /*in_def=*/{"x:float"}, + /*out_def=*/{"res:float"}, + /*attr_def=*/{}, + /*node_def=*/{{{"t0"}, "Identity", {"x"}, {{"T", DT_FLOAT}}}}, + /*ret_def*/ {{"res", "t0:output"}}); + + *library.add_function() = identity_func; + + Output in = ops::Placeholder(root, DT_FLOAT); + NameAttrList b_name_attr; + b_name_attr.set_name("IdentityFunc"); + ops::PartitionedCall call(root.WithOpName("call"), {in}, {DT_FLOAT}, + b_name_attr); + + GraphDef graph_def; + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library)); + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + + EXPECT_FALSE(CanTriggerXlaCompilation(graph_def)); +} + +TEST_F(CompilabilityCheckUtilTest, TestXlaOpsCanTriggerXlaCompilation) { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Scope root = Scope::NewRootScope().ExitOnError(); + FunctionDefLibrary library; + + FunctionDef sort_func = FunctionDefHelper::Create( + "SortFunc", + /*in_def=*/{"x:float"}, + /*out_def=*/{"res:float"}, + /*attr_def=*/{}, + /*node_def=*/{{{"t0"}, "XlaSort", {"x"}, {{"T", DT_FLOAT}}}}, + /*ret_def*/ {{"res", "t0:output"}}); + + *library.add_function() = sort_func; + + Output in = ops::Placeholder(root, DT_FLOAT); + NameAttrList b_name_attr; + b_name_attr.set_name("SortFunc"); + ops::PartitionedCall call(root.WithOpName("call"), {in}, {DT_FLOAT}, + b_name_attr); + + GraphDef graph_def; + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library)); + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + + EXPECT_TRUE(CanTriggerXlaCompilation(graph_def)); +} + +TEST_F(CompilabilityCheckUtilTest, TestCanTriggerXlaCompilation) { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Scope root = Scope::NewRootScope().ExitOnError(); + FunctionDefLibrary library; + + AttrValue true_attribute; + true_attribute.set_b(true); + + FunctionDef identity_func = FunctionDefHelper::Create( + "IdentityFunc", + /*in_def=*/{"x:float"}, + /*out_def=*/{"res:float"}, + /*attr_def=*/{}, + /*node_def=*/{{{"t0"}, "Identity", {"x"}, {{"T", DT_FLOAT}}}}, + /*ret_def*/ {{"res", "t0:output"}}); + + (*identity_func.mutable_attr())[kXlaMustCompileAttr] = true_attribute; + + FunctionDef call_identity = FunctionDefHelper::Create( + "CallIdentity", + /*in_def=*/{"x:float"}, + /*out_def=*/{"z:float"}, /*attr_def=*/{}, + /*node_def=*/ + {{{"func_call"}, + "PartitionedCall", + {"x"}, + {{"Tin", DataTypeSlice({DT_FLOAT})}, + {"Tout", DataTypeSlice({DT_FLOAT})}, + {"f", + 
FunctionDefHelper::FunctionRef("IdentityRef", {{"T", DT_FLOAT}})}, + {kXlaMustCompileAttr, true}}}}, + /*ret_def=*/{{"z", "func_call:output:0"}}); + + *library.add_function() = identity_func; + *library.add_function() = call_identity; + + Output in = ops::Placeholder(root, DT_FLOAT); + NameAttrList b_name_attr; + b_name_attr.set_name("CallIdentity"); + ops::PartitionedCall call(root.WithOpName("call"), {in}, {DT_FLOAT}, + b_name_attr); + + GraphDef graph_def; + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library)); + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + + EXPECT_TRUE(CanTriggerXlaCompilation(graph_def)); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/defs.cc b/tensorflow/compiler/jit/defs.cc index 4bea71e8fc1..84e1e36bcf6 100644 --- a/tensorflow/compiler/jit/defs.cc +++ b/tensorflow/compiler/jit/defs.cc @@ -28,4 +28,6 @@ const char* const kXlaScopeAttr = "_XlaScope"; // only when auto_jit is ON. const char* const kXlaInternalScopeAttr = "_XlaInternalScope"; +const char* const kXlaClusterIdAttr = "_xla_compile_id"; + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/defs.h b/tensorflow/compiler/jit/defs.h index 9eb4c2ca2e8..fa983db8df8 100644 --- a/tensorflow/compiler/jit/defs.h +++ b/tensorflow/compiler/jit/defs.h @@ -35,6 +35,9 @@ extern const char* const kXlaCompileAttr; // "_XlaCompile" extern const char* const kXlaScopeAttr; // "_XlaScope" extern const char* const kXlaInternalScopeAttr; // "_XlaInternalScope" +// The id of the compiled cluster. +extern const char* const kXlaClusterIdAttr; // "_xla_compile_id" + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_DEFS_H_ diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc index ed25baa62ff..4a5c79c02d9 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" +#include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -34,9 +35,6 @@ limitations under the License. namespace tensorflow { -const char* const EncapsulateXlaComputationsPass::kXlaClusterAttr = - "_xla_compile_id"; - namespace { const char* const kXlaClusterOutput = "XlaClusterOutput"; @@ -45,10 +43,7 @@ bool IsCpuGpuCompile(const Graph* graph) { for (Node* n : graph->nodes()) { string name; // Only consider nodes being compiled. - if (!GetNodeAttr(n->attrs(), - EncapsulateXlaComputationsPass::kXlaClusterAttr, &name) - .ok()) - continue; + if (!GetNodeAttr(n->attrs(), kXlaClusterIdAttr, &name).ok()) continue; // Early return for any node with a device that is not a CPU or GPU. DeviceNameUtils::ParsedName parsed; if (DeviceNameUtils::ParseFullName(n->requested_device(), &parsed)) { @@ -180,8 +175,7 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, retvals[i]->AddAttr("index", i); } - AddNodeAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, call_def->name(), - call_def); + AddNodeAttr(kXlaClusterIdAttr, call_def->name(), call_def); AddNodeAttr("_variable_start_index", variable_start_index, call_def); // Uniquify the function name. 
@@ -216,8 +210,8 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, // O(n) pass over the edges. for (const Edge* e : (*graph)->edges()) { if (!e->IsControlEdge() && - e->src()->attrs().Find(kXlaClusterAttr) != nullptr && - e->dst()->attrs().Find(kXlaClusterAttr) == nullptr && + e->src()->attrs().Find(kXlaClusterIdAttr) != nullptr && + e->dst()->attrs().Find(kXlaClusterIdAttr) == nullptr && e->dst()->type_string() != kXlaClusterOutput) { return errors::InvalidArgument( "Undeclared output of XLA computation. Some common causes of this " @@ -232,9 +226,9 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, auto output = absl::make_unique((*graph)->op_registry()); TF_RETURN_WITH_CONTEXT_IF_ERROR( - EncapsulateSubgraphsInFunctions(kXlaClusterAttr, **graph, RewriteSubgraph, - /*reuse_existing_functions=*/true, - &output, flib_def), + EncapsulateSubgraphsInFunctions( + kXlaClusterIdAttr, **graph, RewriteSubgraph, + /*reuse_existing_functions=*/true, &output, flib_def), "EncapsulateXlaComputationsPass failed"); graph->swap(output); return Status::OK(); @@ -246,7 +240,7 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, // while iterating. std::vector launch_nodes; for (Node* n : graph->nodes()) { - const string& name = GetNodeAttrString(n->attrs(), kXlaClusterAttr); + const string& name = GetNodeAttrString(n->attrs(), kXlaClusterIdAttr); if (!name.empty()) { launch_nodes.push_back(n); } diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h index 3057e4c7469..9931b23fa41 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h @@ -34,8 +34,6 @@ namespace tensorflow { // XlaLaunch operators. class EncapsulateXlaComputationsPass : public GraphOptimizationPass { public: - static const char* const kXlaClusterAttr; // _xla_compile_id - Status Run(const GraphOptimizationPassOptions& options) override; // The following methods are public only for unit tests. diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc index cc177036591..61c9a3ff9c0 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h" #include "tensorflow/compiler/tf2xla/test_util.h" @@ -46,19 +47,18 @@ static std::unique_ptr MakeOuterGraph( auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE); NodeDef def; - TF_CHECK_OK( - NodeDefBuilder("launch0", function, &flib_def) - .Input(a.node()->name(), 0, DT_INT32) - .Input(b.node()->name(), 0, DT_FLOAT) - .Input(c.node()->name(), 0, DT_INT32) - .Input(d.node()->name(), 0, DT_FLOAT) - .Input(u.node()->name(), 0, DT_RESOURCE) - .Input(v.node()->name(), 0, DT_RESOURCE) - .Input(w.node()->name(), 0, DT_RESOURCE) - .Device("/gpu:0") - .Attr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0") - .Attr("_variable_start_index", 4) - .Finalize(&def)); + TF_CHECK_OK(NodeDefBuilder("launch0", function, &flib_def) + .Input(a.node()->name(), 0, DT_INT32) + .Input(b.node()->name(), 0, DT_FLOAT) + .Input(c.node()->name(), 0, DT_INT32) + .Input(d.node()->name(), 0, DT_FLOAT) + .Input(u.node()->name(), 0, DT_RESOURCE) + .Input(v.node()->name(), 0, DT_RESOURCE) + .Input(w.node()->name(), 0, DT_RESOURCE) + .Device("/gpu:0") + .Attr(kXlaClusterIdAttr, "launch0") + .Attr("_variable_start_index", 4) + .Finalize(&def)); Status status; Node* launch = scope.graph()->AddNode(def, &status); @@ -107,7 +107,7 @@ static std::unique_ptr MakeBodyGraph() { auto arg6 = ops::_Arg(scope.WithOpName("w_0_arg"), DT_RESOURCE, 6); auto add_attrs = [](Node* node) { - node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0"); + node->AddAttr(kXlaClusterIdAttr, "launch0"); node->set_requested_device("/gpu:0"); }; @@ -155,8 +155,7 @@ TEST(EncapsulateXlaComputations, DeterministicEncapsulate) { : ops::Add(scope.WithOpName("E"), a1, a0); auto add_attrs = [](Node* node) { - node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, - "launch0"); + node->AddAttr(kXlaClusterIdAttr, "launch0"); }; add_attrs(e.node()); @@ -216,7 +215,7 @@ TEST(EncapsulateXlaComputations, Encapsulate) { auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE); auto add_attrs = [](Node* node) { - node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0"); + node->AddAttr(kXlaClusterIdAttr, "launch0"); node->set_requested_device("/gpu:0"); }; diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index a4a750bae0d..683acd0bae9 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -167,8 +167,8 @@ void AllocateAndParseFlags() { jitter_flags = new IntroduceFloatingPointJitterPassFlags; jitter_flags->jitter_amount = 1e-5; - mlir_flags = new MlirCommonFlags; - mlir_flags->tf_mlir_enable_mlir_bridge = false; + bool enable_mlir_bridge = false; + bool enable_mlir_bridge_flag_updated = false; auto setter_for_jitter_tensor_names = [](string sequence) { jitter_flags->tensor_names = absl::StrSplit(sequence, ','); @@ -217,12 +217,24 @@ void AllocateAndParseFlags() { "The amount of jitter to introduce. 
This amount is added to each " "element in the tensors named in `tensor_names."), - Flag("tf_mlir_enable_mlir_bridge", - &mlir_flags->tf_mlir_enable_mlir_bridge, - "Enables experimental MLIR-Based TensorFlow Compiler Bridge.")}); + Flag("tf_mlir_enable_mlir_bridge", &enable_mlir_bridge, + "Enables experimental MLIR-Based TensorFlow Compiler Bridge.", + &enable_mlir_bridge_flag_updated)}); AppendMarkForCompilationPassFlagsInternal(flag_list); xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list); + + mlir_flags = new MlirCommonFlags; + if (!enable_mlir_bridge_flag_updated) { + mlir_flags->tf_mlir_enable_mlir_bridge = + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED; + } else if (enable_mlir_bridge) { + mlir_flags->tf_mlir_enable_mlir_bridge = + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED; + } else { + mlir_flags->tf_mlir_enable_mlir_bridge = + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED; + } } } // namespace @@ -268,4 +280,10 @@ void AppendMarkForCompilationPassFlags(std::vector* flag_list) { AppendMarkForCompilationPassFlagsInternal(flag_list); } +static std::atomic xla_compilation_disabled(false); + +void DisableXlaCompilation() { xla_compilation_disabled = true; } + +bool FailOnXlaCompilation() { return xla_compilation_disabled; } + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h index 6c54fc8825e..a0860da7b04 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/util/command_line_flags.h" namespace tensorflow { @@ -135,7 +136,7 @@ struct IntroduceFloatingPointJitterPassFlags { // Flags for common MLIR configurations. struct MlirCommonFlags { - bool tf_mlir_enable_mlir_bridge; + ConfigProto::Experimental::MlirBridgeRollout tf_mlir_enable_mlir_bridge; }; // Return a pointer to the DumpGraphFlags struct; @@ -162,6 +163,13 @@ MlirCommonFlags* GetMlirCommonFlags(); void AppendMarkForCompilationPassFlags( std::vector* flag_list); +// Disables XLA compilation, forces it to return an error message instead. Can +// be used by a server to ensure that JIT compilation is opt-in. +void DisableXlaCompilation(); + +// Returns `false` unless `DisableXlaCompilation` was called. +bool FailOnXlaCompilation(); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_FLAGS_H_ diff --git a/tensorflow/compiler/jit/force_xla_constants_on_host_pass.cc b/tensorflow/compiler/jit/force_xla_constants_on_host_pass.cc index 3ba32f07506..3692d1f3aba 100644 --- a/tensorflow/compiler/jit/force_xla_constants_on_host_pass.cc +++ b/tensorflow/compiler/jit/force_xla_constants_on_host_pass.cc @@ -38,10 +38,12 @@ Status ForceXlaConstantsOnHostPass::Run( std::vector constant_arg_indices; std::vector resource_arg_indices; + NameAttrList function; + TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node->def(), &function)); + // Force all constants to be on the host memory. 
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources( - flr, node->def(), &fbody, &constant_arg_indices, - &resource_arg_indices)); + flr, function, &fbody, &constant_arg_indices, &resource_arg_indices)); VLOG(3) << "Found constant arg indices: " << absl::StrJoin(constant_arg_indices, ", "); diff --git a/tensorflow/compiler/jit/get_compiler_ir.cc b/tensorflow/compiler/jit/get_compiler_ir.cc new file mode 100644 index 00000000000..08b3bea1084 --- /dev/null +++ b/tensorflow/compiler/jit/get_compiler_ir.cc @@ -0,0 +1,158 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/get_compiler_ir.h" + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "tensorflow/compiler/jit/compilability_check_util.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/compiler/jit/xla_launch_util.h" +#include "tensorflow/compiler/jit/xla_platform_info.h" +#include "tensorflow/compiler/tf2xla/const_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/core/common_runtime/eager/tensor_handle.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace tensorflow { + +static xla::StatusOr GetLocalExecutable( + const XlaCompiler::Options& options, + const XlaCompiler::CompileOptions& compile_options, + const NameAttrList& function, XlaCompilationCache* cache, + absl::Span args, const XlaCompiler& compiler) { + const XlaCompiler::CompilationResult* compilation_result = nullptr; + xla::LocalExecutable* executable = nullptr; + TF_RETURN_IF_ERROR(cache->Compile(options, function, args, compile_options, + XlaCompilationCache::CompileMode::kStrict, + &compilation_result, &executable)); + return executable; +} + +xla::StatusOr GetCompilerIr( + IrExportStage stage, ProcessFunctionLibraryRuntime* pflr, + absl::string_view func_name, Device* dev, EagerContext* context, + absl::Span inputs_handles) { + NameAttrList function; + function.set_name(std::string{func_name}); + + FunctionLibraryRuntime* flr = pflr->GetFLR(dev->name()); + ResourceMgr* rmgr = dev->resource_manager(); + + const FunctionBody* fbody = nullptr; + std::vector constant_arg_indices; + std::vector resource_arg_indices; + TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources( + flr, function, &fbody, &constant_arg_indices, &resource_arg_indices)); + + MemoryTypeVector input_memory_types = + GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices); + MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody); + + std::deque inputs_storage; + std::vector inputs; + inputs.reserve(inputs_handles.size()); + for (int i = 0; i < inputs_handles.size(); i++) { + const TensorHandle* th = 
inputs_handles[i]; + const Tensor* t; + // Handle owns the tensor. + TF_RETURN_IF_ERROR(th->Tensor(&t)); + if (absl::c_binary_search(constant_arg_indices, i)) { + // Need to make sure it's on the host. + inputs_storage.emplace_back(t->dtype(), t->shape()); + TF_RETURN_IF_ERROR( + th->CopyToDevice(*context, /*d=*/nullptr, &inputs_storage.back())); + inputs.push_back(&inputs_storage.back()); + } else { + inputs.push_back(t); + } + } + + std::vector variable_infos; + TF_RETURN_IF_ERROR(GetVariableInfosFromInputs( + rmgr, dev, inputs, resource_arg_indices, &variable_infos)); + TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos))); + + XlaPlatformInfo platform_info = XlaPlatformInfoFromDevice(dev); + + XlaCompilationCache* cache; + TF_RETURN_IF_ERROR(rmgr->LookupOrCreate( + rmgr->default_container(), "xla_cache", &cache, + [&](XlaCompilationCache** cache_write_into) { + return BuildXlaCompilationCache(dev, platform_info, cache_write_into); + })); + core::ScopedUnref cache_ref(cache); + + absl::optional tf_allocator_adapter; + + XlaCompiler::Options options = + GenerateCompilerOptions(*cache, *flr, dev, + /*stream=*/nullptr, platform_info, + /*has_ref_vars=*/false, &tf_allocator_adapter); + + XlaCompiler::CompileOptions compile_options; + compile_options.always_return_tuple = false; + compile_options.alias_resource_update = true; + + XlaCompiler compiler(options); + + xla::StatusOr> args = + XlaComputationLaunchContext::BuildXlaCompilerArguments( + constant_arg_indices, inputs, variable_infos); + TF_RETURN_IF_ERROR(args.status()); + + switch (stage) { + case IrExportStage::HLO: { + XlaCompiler::CompilationResult result; + TF_RETURN_IF_ERROR( + compiler.CompileFunction(compile_options, function, *args, &result)); + + TF_ASSIGN_OR_RETURN(xla::ProgramShape program_shape, + result.computation->GetProgramShape()); + xla::HloModuleConfig config(program_shape); + TF_ASSIGN_OR_RETURN( + std::unique_ptr new_module, + xla::HloModule::CreateFromProto(result.computation->proto(), config)); + + return new_module->ToString(); + } + case IrExportStage::OPTIMIZED_HLO: { + xla::StatusOr executable = GetLocalExecutable( + options, compile_options, function, cache, *args, compiler); + TF_RETURN_IF_ERROR(executable.status()); + return (*executable)->executable()->module().ToString(); + } + case IrExportStage::OPTIMIZED_HLO_DOT: { + xla::StatusOr executable = GetLocalExecutable( + options, compile_options, function, cache, *args, compiler); + TF_RETURN_IF_ERROR(executable.status()); + xla::StatusOr graph = xla::RenderGraph( + *(*executable)->executable()->module().entry_computation(), + "Visualization", + /*debug_options=*/{}, xla::RenderedGraphFormat::kDot, + /*hlo_execution_profile=*/nullptr, + /*hlo_render_options=*/{}); + TF_RETURN_IF_ERROR(graph.status()); + return *graph; + } + } +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/get_compiler_ir.h b/tensorflow/compiler/jit/get_compiler_ir.h new file mode 100644 index 00000000000..0a0a1a44271 --- /dev/null +++ b/tensorflow/compiler/jit/get_compiler_ir.h @@ -0,0 +1,41 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_JIT_GET_COMPILER_IR_H_ +#define TENSORFLOW_COMPILER_JIT_GET_COMPILER_IR_H_ + +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace tensorflow { + +class ProcessFunctionLibraryRuntime; +class Device; +class Tensor; +class TensorHandle; +class EagerContext; + +enum class IrExportStage { HLO, OPTIMIZED_HLO, OPTIMIZED_HLO_DOT }; + +// Returns HLO text for a given function `func_name` using library runtime +// `runtime` on a device `dev` with given `inputs`. +xla::StatusOr GetCompilerIr( + IrExportStage stage, ProcessFunctionLibraryRuntime* pflr, + absl::string_view func_name, Device* dev, EagerContext* context, + absl::Span inputs); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_GET_COMPILER_IR_H_ diff --git a/tensorflow/compiler/jit/graphcycles/BUILD b/tensorflow/compiler/jit/graphcycles/BUILD index 61d0c0de35f..23d994c27c5 100644 --- a/tensorflow/compiler/jit/graphcycles/BUILD +++ b/tensorflow/compiler/jit/graphcycles/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") package( diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index eb9ad8a2e85..1f400137f5b 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -1,3 +1,5 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + package( default_visibility = [ "//tensorflow/compiler/tf2xla:internal", @@ -32,7 +34,7 @@ XLA_OPS_DEPS = [ "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:state_ops_op_lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/stream_executor:tf_allocator_adapter", ] diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index de462928c46..0f0f43cbad6 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -158,12 +158,13 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, constants_(constants), resources_(resources), function_(function), - platform_info_(XlaPlatformInfoFromContext(ctx)), + platform_info_(XlaPlatformInfoFromDevice(ctx->device())), has_ref_vars_(has_ref_vars) {} static Status CompileToLocalExecutable( OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars, const XlaPlatformInfo& platform_info, + absl::Span inputs, absl::Span variable_infos, absl::Span constants, bool lazy, bool may_alias_resource_update, xla::LocalClient** client, @@ -180,7 +181,7 @@ static Status CompileToLocalExecutable( TF_RETURN_IF_ERROR(rm->LookupOrCreate( rm->default_container(), "xla_cache", &cache, [&](XlaCompilationCache** cache) { - return BuildXlaCompilationCache(ctx, platform_info, cache); + return BuildXlaCompilationCache(ctx->device(), platform_info, cache); })); // Hold the reference to the JIT 
during evaluation. (We could probably // free it sooner because the ResourceMgr will retain a reference, but @@ -191,12 +192,9 @@ static Status CompileToLocalExecutable( absl::optional tf_allocator_adapter; XlaCompiler::Options options = GenerateCompilerOptions( - *cache, ctx, platform_info, has_ref_vars, &tf_allocator_adapter); - - std::map constant_args; - for (int i : constants) { - constant_args.insert({i, ctx->input(i)}); - } + *cache, *ctx->function_library(), ctx->device(), + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr, + platform_info, has_ref_vars, &tf_allocator_adapter); XlaCompiler::CompileOptions compile_options; compile_options.is_entry_computation = true; @@ -207,10 +205,11 @@ static Status CompileToLocalExecutable( !platform_info.is_on_xla_device() && may_alias_resource_update; - std::vector args; - TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments( - constant_args, variable_infos, ctx, &args)); - return cache->Compile(options, function, args, compile_options, + xla::StatusOr> args = + XlaComputationLaunchContext::BuildXlaCompilerArguments(constants, inputs, + variable_infos); + TF_RETURN_IF_ERROR(args.status()); + return cache->Compile(options, function, *args, compile_options, lazy ? XlaCompilationCache::CompileMode::kLazy : XlaCompilationCache::CompileMode::kStrict, compilation_result, executable); @@ -220,6 +219,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { VLOG(1) << "XlaLocalLaunchOpBase::Compute " << Canonicalize(function_.name(), AttrSlice(&function_.attr())); + std::vector inputs = InputsFromContext(ctx); xla::LocalClient* client; const XlaCompiler::CompilationResult* compilation_result; xla::LocalExecutable* executable; @@ -227,10 +227,11 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { std::vector variable_infos; { OP_REQUIRES_OK( - ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos)); + ctx, GetVariableInfosFromInputs(ctx->resource_manager(), ctx->device(), + inputs, resources_, &variable_infos)); OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos))); Status s = CompileToLocalExecutable( - ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_, + ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_, inputs, variable_infos, constants_, /*lazy=*/false, /*may_alias_resource_update=*/true, &client, &compilation_result, &executable); @@ -248,8 +249,10 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { VLOG(1) << "Executing XLA Computation..."; absl::optional tf_allocator_adapter; - se::DeviceMemoryAllocator* allocator = - GetAllocator(&tf_allocator_adapter, ctx, platform_info_); + se::DeviceMemoryAllocator* allocator = GetAllocator( + &tf_allocator_adapter, ctx->device(), + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr, + platform_info_); int device_ordinal = stream ? stream->parent()->device_ordinal() : client->default_device_ordinal(); XlaComputationLaunchContext launch_context( @@ -271,18 +274,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { run_options.set_allocator(allocator); run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); run_options.set_rng_seed(GetXLARandomSeed()); - xla::ThenExecuteFunction then_execute; - if (ctx->op_device_context()) { - then_execute = [&](se::Stream* stream, std::function fn) { - Status status = ctx->op_device_context()->ThenExecute( - down_cast(ctx->device()), stream, std::move(fn)); - if (!status.ok()) { - // This should never happen. 
- LOG(ERROR) << "ThenExecute failed " << status; - } - }; - run_options.set_then_execute_function(&then_execute); - } Env* env = Env::Default(); auto start_time = env->NowMicros(); @@ -373,7 +364,7 @@ XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx) constants_(ConstantsVector(ctx)), resources_(ResourcesVector(ctx)), function_(FunctionAttr(ctx)), - platform_info_(XlaPlatformInfoFromContext(ctx)), + platform_info_(XlaPlatformInfoFromDevice(ctx->device())), must_compile_(MustCompileAttr(ctx)), has_ref_vars_(HasRefVars(ctx)) {} @@ -385,6 +376,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { xla::LocalExecutable* executable; ResourceVarsSnapshot variables; + std::vector inputs = InputsFromContext(ctx); bool cannot_compile_cluster; { mutex_lock guard(cannot_compile_cluster_mu_); @@ -397,13 +389,14 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { } else { std::vector variable_infos; OP_REQUIRES_OK( - ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos)); + ctx, GetVariableInfosFromInputs(ctx->resource_manager(), ctx->device(), + inputs, resources_, &variable_infos)); OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos))); // Do not alias resource updates as locking variables in XlaCompile and // unlocking them in XlaRun may lead to deadlocks. Status status = CompileToLocalExecutable( - ctx, function_, has_ref_vars_, platform_info_, variable_infos, + ctx, function_, has_ref_vars_, platform_info_, inputs, variable_infos, constants_, /*lazy=*/!must_compile_, /*may_alias_resource_update=*/false, &client, &kernel, &executable); @@ -461,7 +454,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { } XlaRunOp::XlaRunOp(OpKernelConstruction* ctx) - : OpKernel(ctx), platform_info_(XlaPlatformInfoFromContext(ctx)) {} + : OpKernel(ctx), platform_info_(XlaPlatformInfoFromDevice(ctx->device())) {} void XlaRunOp::Compute(OpKernelContext* ctx) { VLOG(3) << "XlaRunOp " << def().name(); @@ -472,8 +465,10 @@ void XlaRunOp::Compute(OpKernelContext* ctx) { XlaExecutableClosureStore::Global()->Consume(key); absl::optional tf_allocator_adapter; - se::DeviceMemoryAllocator* allocator = - GetAllocator(&tf_allocator_adapter, ctx, platform_info_); + se::DeviceMemoryAllocator* allocator = GetAllocator( + &tf_allocator_adapter, ctx->device(), + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr, + platform_info_); se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; int device_ordinal = stream ? stream->parent()->device_ordinal() @@ -515,18 +510,6 @@ void XlaRunOp::Compute(OpKernelContext* ctx) { run_options.set_allocator(allocator); run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); run_options.set_rng_seed(GetXLARandomSeed()); - xla::ThenExecuteFunction then_execute; - if (ctx->op_device_context()) { - then_execute = [&](se::Stream* stream, std::function fn) { - Status status = ctx->op_device_context()->ThenExecute( - down_cast(ctx->device()), stream, std::move(fn)); - if (!status.ok()) { - // This should never happen. 
- LOG(ERROR) << "ThenExecute failed " << status; - } - }; - run_options.set_then_execute_function(&then_execute); - } Env* env = Env::Default(); auto start_time = env->NowMicros(); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 19eb61b6f72..ada7766fcbb 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -32,12 +32,12 @@ limitations under the License. #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" -#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_constructor.h" @@ -1196,13 +1196,14 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() { continue; } - DeviceType jit_device_type(registration->compilation_device_name); - - RecursiveCompilabilityChecker::OperationFilter op_filter = + RecursiveCompilabilityChecker::OperationFilter filter = CreateOperationFilter(*registration); + filter.require_always_compilable = true; - if (!RecursiveCompilabilityChecker{&op_filter, &jit_device_type} - .IsCompilableNode(*node, lib_runtime)) { + RecursiveCompilabilityChecker checker( + filter, DeviceType{registration->compilation_device_name}); + + if (!checker.IsCompilableNode(*node, lib_runtime)) { continue; } @@ -1711,40 +1712,6 @@ std::atomic* GetPointerToFuel(int64 initial_value) { } } // anonymous namespace -bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef, - RecursiveCompilabilityChecker::UncompilableNodesMap* - uncompilable_node_info) { - Device* device = flr->device(); - const XlaOpRegistry::DeviceRegistration* registration; - CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(), - ®istration)); - DeviceType jit_device_type(registration->compilation_device_name); - - // We can always *compile* resource operations, stateful RNGs and dummy ops, - // even if we are sometimes unable to auto-cluster them. - RecursiveCompilabilityChecker::OperationFilter op_filter; - op_filter.allow_resource_ops_in_called_functions = true; - op_filter.allow_stack_ops = true; - op_filter.allow_tensor_array_ops = true; - op_filter.allow_stateful_rng_ops = true; - op_filter.allow_control_trigger = true; - op_filter.allow_eliding_assert_and_checknumerics_ops = true; - op_filter.allow_ops_producing_or_consuming_variant = true; - op_filter.allow_slow_ops = true; - op_filter.allow_inaccurate_ops = true; - - RecursiveCompilabilityChecker checker{&op_filter, &jit_device_type}; - if (!uncompilable_node_info) { - // We do not need uncompilable node info. Just return the result. 
- return checker.IsCompilableCall(ndef, flr); - } - - RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_node_result = - checker.FindUncompilableNodes(ndef, flr); - uncompilable_node_info->swap(uncompilable_node_result); - return uncompilable_node_info->empty(); -} - Status MarkForCompilationPass::Run( const GraphOptimizationPassOptions& options) { MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); @@ -1837,7 +1804,9 @@ absl::flat_hash_map>* GetAllowlistTable() { "ConcatOffset", "Const", "MirrorPad", "Pack", "Pad", "PadV2", "Reverse", "ReverseV2", "ReverseSequence", "Slice", "Split", "SplitV", "StridedSlice", "StridedSliceGrad", "ResourceStridedSliceAssign", - "Tile", "Transpose", "InvertPermutation", "Unpack", "DeviceIndex"}}}; + "Tile", "Transpose", "InvertPermutation", "Unpack", "DeviceIndex", + "TensorStridedSliceUpdate", + }}}; // clang-format on return result; } @@ -1952,6 +1921,7 @@ absl::flat_hash_set GetKnownXLAAllowlistOp() { "ParallelDynamicStitch", "ParameterizedTruncatedNormal", "PartitionedCall", + "Polygamma", "PopulationCount", "Qr", "QuantizeAndDequantizeV2", @@ -1996,6 +1966,8 @@ absl::flat_hash_set GetKnownXLAAllowlistOp() { "ResourceScatterNdUpdate", "ResourceScatterSub", "ResourceScatterUpdate", + "RngReadAndSkip", + "RngSkip", "Roll", "ScatterNd", "SelfAdjointEigV2", @@ -2018,11 +1990,17 @@ absl::flat_hash_set GetKnownXLAAllowlistOp() { "StatelessCase", "StatelessIf", "StatelessMultinomial", + "StatelessRandomGetKeyCounterAlg", "StatelessRandomNormal", + "StatelessRandomNormalV2", "StatelessRandomUniform", + "StatelessRandomUniformV2", "StatelessRandomUniformInt", + "StatelessRandomUniformIntV2", "StatelessRandomUniformFullInt", + "StatelessRandomUniformFullIntV2", "StatelessTruncatedNormal", + "StatelessTruncatedNormalV2", "StatelessWhile", "Svd", "SymbolicGradient", @@ -2049,6 +2027,8 @@ absl::flat_hash_set GetKnownXLAAllowlistOp() { "TensorListSplit", "TensorListStack", "TensorScatterAdd", + "TensorScatterMax", + "TensorScatterMin", "TensorScatterSub", "TensorScatterUpdate", "TridiagonalSolve", @@ -2080,12 +2060,15 @@ absl::flat_hash_set GetKnownXLAAllowlistOp() { "XlaSelectAndScatter", "XlaSelfAdjointEig", "XlaSend", + "XlaSetBound", "XlaSharding", "XlaSort", "XlaSpmdFullToShardShape", "XlaSpmdShardToFullShape", "XlaSvd", + "XlaVariadicReduce", "XlaWhile", + "Zeta", "_Arg", "_ArrayToList", "_ListToArray", diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.h b/tensorflow/compiler/jit/mark_for_compilation_pass.h index 0e9a64e7f28..810ebf38b5c 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.h +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.h @@ -50,14 +50,6 @@ class MarkForCompilationPass : public GraphOptimizationPass { friend class MarkForCompilationPassTestHelper; }; -// Returns true iff 'ndef' is a call to a function that is compilable. A -// function is compilable iff every operator in the function body is -// compilable. If 'ndef' is not compilable and 'uncompilable_node_info' is not -// null, we will populate 'uncompilable_node_info' with uncompilable node info. 
-bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef, - RecursiveCompilabilityChecker::UncompilableNodesMap* - uncompilable_node_info = nullptr); - absl::flat_hash_map>* GetAllowlistTable(); namespace testing { diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD index afc96a8e68c..6ca8fd0e34a 100644 --- a/tensorflow/compiler/jit/ops/BUILD +++ b/tensorflow/compiler/jit/ops/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") package( diff --git a/tensorflow/compiler/jit/tests/BUILD b/tensorflow/compiler/jit/tests/BUILD index 412dfefb9b7..88ce43902fd 100644 --- a/tensorflow/compiler/jit/tests/BUILD +++ b/tensorflow/compiler/jit/tests/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") licenses(["notice"]) # Apache 2.0 diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 971a5383f6b..461a6692c84 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -20,9 +20,9 @@ limitations under the License. #include "absl/base/call_once.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/xla_activity.pb.h" #include "tensorflow/compiler/jit/xla_activity_listener.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -47,6 +47,11 @@ limitations under the License. #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/dump_graph.h" +#if !defined(LIBTPU_ON_GCE) +#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" +#include "tensorflow/compiler/mlir/utils/array_container_utils.h" +#endif + namespace tensorflow { constexpr int64 XlaCompilationCache::kDefaultCompilationThreshold; @@ -278,23 +283,39 @@ Status XlaCompilationCache::CompileSingleOp( const NodeDef& node_def = ctx->op_kernel().def(); TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes)); + // TODO(b/155596779): Support TensorList args. bool has_tensor_list_arg = absl::c_any_of(args, [](const XlaCompiler::Argument arg) { return arg.kind == XlaCompiler::Argument::kTensorList; }); const ConfigProto* config = ctx->function_library()->config_proto(); - bool use_mlir = config && config->experimental().enable_mlir_bridge(); - // TODO(b/155596779): Support TensorList args. - if (!use_mlir || !has_tensor_list_arg) { + // TODO(b/171039585): Support tf.VarIsInitializedOp using MLIR. 
+ bool use_mlir = config && config->experimental().enable_mlir_bridge() && + !has_tensor_list_arg && + node_def.op() != "VarIsInitializedOp"; +#ifdef LIBTPU_ON_GCE + if (use_mlir) { + LOG(WARNING) << "MLIR is not supported in this environment."; + } + return compiler->CompileGraph(compile_options, node_def.name(), + std::move(graph), args, result); +#else + if (!use_mlir) { return compiler->CompileGraph(compile_options, node_def.name(), std::move(graph), args, result); } + VLOG(1) << "Using MLIR bridge"; GraphDebugInfo debug_info; + std::vector control_rets; + if (result_dtypes.empty()) { + control_rets.push_back(node_def.name()); + } return CompileGraphToXlaHlo( - *graph, {args.data(), args.size()}, options.device_type.type_string(), - compile_options.use_tuple_arg, *options.flib_def, debug_info, - options.shape_representation_fn, result); + *graph, mlir::SpanToArrayRef(args), control_rets, + options.device_type.type_string(), compile_options.use_tuple_arg, + *options.flib_def, debug_info, options.shape_representation_fn, result); +#endif }; return CompileImpl(options, name, args, compile_op, /*compile_threshold=*/absl::nullopt, @@ -323,6 +344,10 @@ Status XlaCompilationCache::CompileImpl( absl::optional compile_threshold, const XlaCompiler::CompilationResult** out_compilation_result, xla::LocalExecutable** out_executable) { + if (FailOnXlaCompilation()) { + return errors::Internal("XLA compilation disabled"); + } + DCHECK_NE(out_executable, nullptr); VLOG(2) << "XlaCompilationCache::Compile " << DebugString(); diff --git a/tensorflow/compiler/jit/xla_compilation_cache_test.cc b/tensorflow/compiler/jit/xla_compilation_cache_test.cc index 7227615d2bb..5578925b790 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache_test.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache_test.cc @@ -15,7 +15,9 @@ limitations under the License. 
#include "tensorflow/compiler/jit/xla_compilation_cache.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" @@ -52,6 +54,30 @@ TEST(XlaCompilationCacheTest, SignatureEquality) { } } +TEST(XlaCompilationCacheTest, TestDisabledXlaCompilation) { + NameAttrList fn; + fn.set_name("afunction"); + + DisableXlaCompilation(); + + xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie(); + DeviceType device_type = DeviceType(DEVICE_CPU_XLA_JIT); + + const XlaCompiler::CompilationResult* compilation_result; + xla::LocalExecutable* executable; + + auto cache = new XlaCompilationCache(client, device_type); + core::ScopedUnref cache_ref(cache); + + Status status = cache->Compile(XlaCompiler::Options{}, fn, {}, + XlaCompiler::CompileOptions{}, + XlaCompilationCache::CompileMode::kStrict, + &compilation_result, &executable); + EXPECT_FALSE(status.ok()); + EXPECT_TRUE( + absl::StrContains(status.error_message(), "XLA compilation disabled")); +} + static void BM_BuildSignature(int iters, int n_args) { NameAttrList fn; fn.set_name("afunction"); diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index da251c2c8f3..d092508eccf 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -49,8 +49,10 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, xla::LocalClient* client = static_cast(cache->client()); absl::optional tf_allocator_adapter; - se::DeviceMemoryAllocator* allocator = - GetAllocator(&tf_allocator_adapter, ctx, platform_info_); + se::DeviceMemoryAllocator* allocator = GetAllocator( + &tf_allocator_adapter, ctx->device(), + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr, + platform_info_); XlaComputationLaunchContext launch_context( client, allocator, client->default_device_ordinal(), /*allocate_xla_tensors=*/platform_info_.xla_device_metadata() != nullptr, @@ -101,53 +103,16 @@ Status XlaCompileOnDemandOp::Compile( OpKernelContext* ctx, const XlaCompiler::CompilationResult** result, XlaCompilationCache** cache, ResourceVarsSnapshot* variable_args, xla::LocalExecutable** executable) { - std::map constant_arguments; std::vector constant_input_indices; TF_RETURN_IF_ERROR(GetCompileTimeConstInputs( &ctx->op_kernel(), &constant_input_indices, ctx->function_library())); - CHECK(absl::c_is_sorted(constant_input_indices)); - - for (int64 i = 0; i < ctx->num_inputs(); ++i) { - const Tensor& device_tensor = ctx->input(i); - - if (const XlaTensor* xla_tensor = XlaTensor::FromTensor(&device_tensor)) { - if (xla_tensor->has_host_tensor()) { - if (absl::c_binary_search(constant_input_indices, i)) { - constant_arguments[i] = xla_tensor->host_tensor(); - } - } - } - - if (!constant_arguments.count(i)) { - if (absl::c_binary_search(constant_input_indices, i)) { - if (ctx->input_memory_type(i) != HOST_MEMORY && - ctx->op_device_context()) { - // Slow path; the argument is not available as a host constant so we - // must fetch it synchronously. 
- Tensor host_tensor; - AllocatorAttributes attrs; - attrs.set_on_host(true); - TF_RETURN_IF_ERROR(ctx->allocate_temp(device_tensor.dtype(), - device_tensor.shape(), - &host_tensor, attrs)); - Status status = ctx->op_device_context()->CopyDeviceTensorToCPUSync( - &device_tensor, "ConstantArgument", - reinterpret_cast(ctx->device()), &host_tensor); - if (!status.ok()) { - LOG(ERROR) << "Copying tensor of shape " - << device_tensor.shape().DebugString() << " from " - << ctx->device()->name() << "to CPU failed with " - << status.ToString(); - return status; - } - constant_arguments[i] = host_tensor; - } else { - constant_arguments[i] = device_tensor; - } - } - } + if (!absl::c_all_of(constant_input_indices, [&](int idx) { + return ctx->input_memory_type(idx) == HOST_MEMORY; + })) { + return errors::Internal("Unexpected device placement for a constant input"); } + std::vector inputs = InputsFromContext(ctx); // We store information about the JIT-compiled XLA computation // in the ResourceMgr. @@ -157,13 +122,16 @@ Status XlaCompileOnDemandOp::Compile( TF_RETURN_IF_ERROR(rm->LookupOrCreate( rm->default_container(), "xla_cache", cache, [&](XlaCompilationCache** write_into_cache) { - return BuildXlaCompilationCache(ctx, platform_info_, write_into_cache); + return BuildXlaCompilationCache(ctx->device(), platform_info_, + write_into_cache); })); absl::optional tf_allocator_adapter; - XlaCompiler::Options options = - GenerateCompilerOptions(**cache, ctx, platform_info_, - /*has_ref_vars=*/true, &tf_allocator_adapter); + XlaCompiler::Options options = GenerateCompilerOptions( + **cache, *ctx->function_library(), ctx->device(), + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr, + platform_info_, + /*has_ref_vars=*/true, &tf_allocator_adapter); XlaCompiler::CompileOptions compile_options; compile_options.is_entry_computation = true; @@ -172,19 +140,23 @@ Status XlaCompileOnDemandOp::Compile( compile_options.always_return_tuple = false; std::vector variables_indices = GetResourceVariableIndices(ctx); - std::vector args; + xla::StatusOr> args; { std::vector variable_infos; TF_RETURN_IF_ERROR( - GetVariableInfosFromCtxInputs(ctx, variables_indices, &variable_infos)); + GetVariableInfosFromInputs(ctx->resource_manager(), ctx->device(), + inputs, variables_indices, &variable_infos)); + TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos))); TF_RETURN_IF_ERROR(SnapshotResourceVariables( ctx, variables_indices, variable_infos, variable_args)); - TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments( - constant_arguments, variable_infos, ctx, &args)); + + args = XlaComputationLaunchContext::BuildXlaCompilerArguments( + constant_input_indices, inputs, variable_infos); + TF_RETURN_IF_ERROR(args.status()); } - return (*cache)->CompileSingleOp(options, args, ctx, compile_options, result, + return (*cache)->CompileSingleOp(options, *args, ctx, compile_options, result, executable); } diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.h b/tensorflow/compiler/jit/xla_compile_on_demand_op.h index 095d3427d41..bb8ab889ce9 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.h +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.h @@ -37,7 +37,8 @@ namespace tensorflow { class XlaCompileOnDemandOp : public OpKernel { public: explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx) - : OpKernel(ctx), platform_info_(XlaPlatformInfoFromContext(ctx)) {} + : OpKernel(ctx), + platform_info_(XlaPlatformInfoFromDevice(ctx->device())) {} void 
Compute(OpKernelContext* ctx) override; private: diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index 446cd8944de..dd1ddb616f5 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -51,7 +51,7 @@ Status XlaCpuDeviceFactory::CreateDevices( std::vector>* devices) { XlaDeviceFlags* flags = GetXlaDeviceFlags(); if (!flags->tf_xla_enable_xla_devices) { - LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set"; + VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set"; return Status::OK(); } bool compile_on_demand = flags->tf_xla_compile_on_demand; diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index c47c9a29c1a..089d22dca03 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -573,8 +573,7 @@ XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device, // Any op assigned to the device that isn't rewritten by the graph rewriter // gets executed by an XlaCompileOnDemandOp, which compiles it and executes // it just-in-time. - OpKernel* (*factory)(OpKernelConstruction*) = - [](OpKernelConstruction* context) -> OpKernel* { + auto factory = [](OpKernelConstruction* context) -> OpKernel* { return new XlaCompileOnDemandOp(context); }; XlaOpRegistry::RegisterCompilationKernels(); @@ -583,6 +582,13 @@ XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device, jit_device, /*include_compilation_only_kernels=*/false)) { KernelDef* def = new KernelDef(*jit_def); + const std::unordered_set* constant_inputs = + XlaOpRegistry::CompileTimeConstantInputArgNames(def->op()); + + for (const std::string& arg_name : *constant_inputs) { + def->add_host_memory_arg(arg_name); + } + def->set_device_type(device); registrations->op_kernel_registrars.emplace_back( new kernel_factory::OpKernelRegistrar(def, "XlaCompileOnDemandOp", diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index f7e7ee9cf95..6d6086ce0fa 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -94,6 +94,11 @@ class XlaDevice : public LocalDevice { static Status GetMetadata(OpKernelConstruction* ctx, const Metadata** metadata); + // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by + // `device`. + static Status GetMetadataFromDevice(DeviceBase* device, + const XlaDevice::Metadata** metadata); + struct Options { // The StreamExecutor platform. Not owned. Must be non-null. 
se::Platform* platform = nullptr; @@ -196,8 +201,6 @@ class XlaDevice : public LocalDevice { xla::StatusOr> GetDeviceContextLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); - static Status GetMetadataFromDevice(DeviceBase* device, - const XlaDevice::Metadata** metadata); Status MakeTensorFromProto(XlaDeviceContext* device_context, const TensorProto& tensor_proto, diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index e1cef25e33e..7bdd0aecb34 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -294,4 +294,12 @@ se::Stream* XlaDeviceContext::GetDeviceToDeviceStream() { return device_to_device_stream(stream); } +Status XlaDeviceContext::ThenExecute(Device* device, + stream_executor::Stream* stream, + std::function func) { + VLOG(2) << "XlaDeviceContext::ThenExecute"; + stream->ThenDoHostCallback(std::move(func)); + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index 05d8dfa7556..5689e815a99 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -86,6 +86,9 @@ class XlaDeviceContext : public DeviceContext { // Returns a device-to-device stream, in round-robin fashion. se::Stream* GetDeviceToDeviceStream(); + Status ThenExecute(Device* device, stream_executor::Stream* stream, + std::function func) override; + private: bool UseMultipleStreams() const { return stream_ != host_to_device_stream_; } diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index 16f496d51a3..99ba5658819 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -66,7 +66,7 @@ class XlaGpuDeviceFactory : public DeviceFactory { Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector* devices) { XlaDeviceFlags* flags = GetXlaDeviceFlags(); if (!flags->tf_xla_enable_xla_devices) { - LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set"; + VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set"; return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_kernel_creator.cc b/tensorflow/compiler/jit/xla_kernel_creator.cc index 3a6345afe9f..b90f8b7b990 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator.cc @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/kernels/xla_ops.h" -#include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/common_runtime/function.h" @@ -30,53 +29,51 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/util/ptr_util.h" -namespace { +namespace tensorflow { -// Utility which searches for values in a sorted list by scanning over it once. -// No matter how many times ScanForValue is called, the list is scanned at most -// once. However, if a call to ScanForValue skips over a value, that value is -// not revisited in future calls to ScanForValue, so callers must take -// care to order their calls. -// -// Useful for merging multiple sorted lists in O(n) time. 
-class SinglePassSearch { - public: - // Creates a SinglePassSearch object that can be used to search in `values`. - // Does not take ownership of `values`. `values` must outlive this. - // `values` must be sorted. - explicit SinglePassSearch(const std::vector* values) - : current_index_(0), values_(values) {} +// Returns true iff 'ndef' is a call to a function that is compilable. A +// function is compilable iff every operator in the function body is +// compilable. If 'ndef' is not compilable and 'uncompilable_node_info' is not +// null, we will populate 'uncompilable_node_info' with uncompilable node info. +static bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef, + RecursiveCompilabilityChecker::UncompilableNodesMap* + uncompilable_node_info) { + Device* device = flr->device(); + const XlaOpRegistry::DeviceRegistration* registration; + CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(), + ®istration)); - // Scans forward in the vector looking for "value", updating the internal - // position in to the vector. - // Returns true iff the vector contains the given value at or after current - // position. - // Not thread-safe. - bool ScanForValue(int value) { - while (current_index_ < values_->size() && - (*values_)[current_index_] <= value) { - if ((*values_)[current_index_] == value) { - current_index_++; - return true; - } - current_index_++; - } - return false; + // We can always *compile* resource operations, stateful RNGs and dummy ops, + // even if we are sometimes unable to auto-cluster them. + RecursiveCompilabilityChecker::OperationFilter op_filter; + op_filter.allow_resource_ops_in_called_functions = true; + op_filter.allow_stack_ops = true; + op_filter.allow_tensor_array_ops = true; + op_filter.allow_stateful_rng_ops = true; + op_filter.allow_control_trigger = true; + op_filter.allow_eliding_assert_and_checknumerics_ops = true; + op_filter.allow_ops_producing_or_consuming_variant = true; + op_filter.allow_slow_ops = true; + op_filter.allow_inaccurate_ops = true; + + RecursiveCompilabilityChecker checker{ + op_filter, DeviceType{registration->compilation_device_name}}; + if (!uncompilable_node_info) { + // We do not need uncompilable node info. Just return the result. + return checker.IsCompilableCall(ndef, flr); } - private: - int current_index_; - const std::vector* values_; -}; - -} // end namespace - -namespace tensorflow { + RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_node_result = + checker.FindUncompilableNodes(ndef, flr); + uncompilable_node_info->swap(uncompilable_node_result); + return uncompilable_node_info->empty(); +} bool XlaKernelCreator::CanCreateKernel( const FunctionLibraryRuntime& flr, const std::shared_ptr& props) const { - return CanCreateXlaKernel(props->node_def); + return CanCreateXlaKernel(props->node_def) && + !XlaOpRegistry::IsCompilationDevice(flr.device()->device_type()); } static Status CreateXlaKernel(FunctionLibraryRuntime* flr, @@ -92,7 +89,8 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr, XlaOpRegistry::RegisterCompilationKernels(); // Only check for compilability if the MLIR bridge is not enabled. 
- if (!GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) { + if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge != + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) { RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map; if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) { std::vector @@ -122,62 +120,19 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr, } // Get function body, constant args, and resource args. + NameAttrList function; + TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function)); const FunctionBody* fbody = nullptr; std::vector constant_arg_indices; std::vector resource_arg_indices; TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources( - flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices)); + flr, function, &fbody, &constant_arg_indices, &resource_arg_indices)); - // Set input and output memory types. - MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY); - // These indices are used only for optimization purposes. They allow us - // to loop over constant_arg_indices and resource_arg_indices only once - // while iterating over all the function arguments checking if it is a - // resource or a constant. - // The reason we optimized this code is because functions can have a lot of - // captured arguments. For example, the backward pass of ResNet50 takes in all - // 214 variables and a similar number of activations. - SinglePassSearch constants_search(&constant_arg_indices); - SinglePassSearch resources_search(&resource_arg_indices); - for (size_t i = 0; i < fbody->arg_types.size(); ++i) { - if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) { - // Compile-time constants and resource handles are expected to be in - // host memory. - input_memory_types[i] = HOST_MEMORY; - } - } - // One might wonder, about the case where a compile-time constant argument - // (which must be in host memory) is also used as an input into an op, - // e.g. Add, that expects its inputs in device memory. Here is how it - // works now. - // First, what do we mean by "op expects an input in XYZ memory"? - // There are two types of "ops" here: the tf2xla kernel and the HLO - // computation it builds. The tf2xla kernel needs to retrieve the actual - // numeric value of the compile-time constant tensors, so it really expects - // them to be on in host memory. However, for other inputs, it refers to them - // using xla::ComputationDataHandle, which is just a symbolic handle that - // xla::ComputationBuilder assigns. How does this handle gets assigned for - // constant arguments? Even constant arguments get an _Arg node in the graph - // instantiated for Function compilation. The tf2xla kernel for constant _Arg - // nodes takes the constant value, converts it to XlaLiteral, and feeds it - // to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This - // constant XlaLiteral is included in the HLO graph, and subsequently, in - // the actual executable, which is copied to the device before being - // executed. Thus, when this executable runs, the constant is available in - // device memory. - - // XlaLaunch kernel keeps all outputs (including constants, which it copies), - // in device memory except for resources. 
- MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY); - for (size_t i = 0; i < fbody->ret_types.size(); ++i) { - if (fbody->ret_types[i] == DT_RESOURCE) { - output_memory_types[i] = HOST_MEMORY; - } - } + MemoryTypeVector input_memory_types = + GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices); + MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody); // Create the kernel. - NameAttrList function; - TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function)); Device* dev = flr->device(); Status s; auto props = std::make_shared( diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 19e2b5a2bb5..a0e60b1eafe 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -44,12 +44,6 @@ namespace { using xla::ScopedShapedBuffer; using xla::ShapedBuffer; -const char kPossibleNonVariableResourceHintMessage[] = - "If the error is similar to `Trying to access resource using the wrong " - "type`, this is likely because XLA only accepts Resource Variables as " - "inputs by snapshotting their values. Other TensorFlow resource types like " - "TensorList/TensorArray/Stack are not supported. Try removing non-variable " - "resource inputs to XLA."; } // anonymous namespace VariableInfo::VariableInfo(int index, absl::string_view name, Var* var) @@ -85,19 +79,22 @@ VariableInfo::~VariableInfo() { } } -// Returns a vector of VariableInfo instances for the resource variable inputs -// to the kernel with context `ctx`. The input indices for the resource -// variable inputs are in `variable_indices`. -Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx, - absl::Span variable_indices, - std::vector* result) { +Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev, + absl::Span inputs, + absl::Span variable_indices, + std::vector* result) { result->clear(); result->reserve(variable_indices.size()); for (int var_idx : variable_indices) { Var* variable = nullptr; - ResourceHandle handle = HandleFromInput(ctx, var_idx); - TF_RETURN_IF_ERROR( - LookupOrCreateResource(ctx, handle, &variable, [&](Var** ptr) { + ResourceHandle handle = inputs[var_idx]->flat()(0); + if (handle.device() != dev->attributes().name()) { + return errors::InvalidArgument( + "Trying to access resource ", handle.name(), " located in device ", + handle.device(), " from device ", dev->attributes().name()); + } + TF_RETURN_IF_ERROR(rm->LookupOrCreate( + handle.container(), handle.name(), &variable, [](Var** ptr) { // This var is uninitialized for now. 
*ptr = new Var(DT_INVALID); return Status::OK(); @@ -107,6 +104,15 @@ Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx, return Status::OK(); } +std::vector InputsFromContext(OpKernelContext* ctx) { + std::vector inputs; + inputs.reserve(ctx->num_inputs()); + for (int input_idx = 0; input_idx < ctx->num_inputs(); input_idx++) { + inputs.push_back(&ctx->input(input_idx)); + } + return inputs; +} + Status LockVariables(absl::Span variables) { std::vector lock_order(variables.size()); std::iota(lock_order.begin(), lock_order.end(), 0); @@ -358,9 +364,6 @@ static Status SetOutputForConstant( ctx->set_output(output_num, const_tensor); output_tensor = ctx->mutable_output(output_num); } - if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) { - xla_tensor->set_host_tensor(const_tensor); - } return Status::OK(); } @@ -557,11 +560,14 @@ Status XlaComputationLaunchContext::PopulateOutputs( return Status::OK(); } -Status XlaComputationLaunchContext::BuildXlaCompilerArguments( - const std::map& must_be_constant_args, - absl::Span variable_args, OpKernelContext* ctx, - std::vector* args) { - args->resize(ctx->num_inputs()); +xla::StatusOr> +XlaComputationLaunchContext::BuildXlaCompilerArguments( + absl::Span must_be_constant_idxs, + absl::Span inputs, + absl::Span variable_args) { + CHECK(absl::c_is_sorted(must_be_constant_idxs)); + std::vector out; + out.resize(inputs.size()); absl::flat_hash_map variable_info_lookup; for (const VariableInfo& info : variable_args) { @@ -571,33 +577,20 @@ Status XlaComputationLaunchContext::BuildXlaCompilerArguments( variable_info_lookup.emplace(info.index(), &info); } - for (int64 input_num = 0; input_num < ctx->num_inputs(); ++input_num) { - XlaCompiler::Argument& arg = (*args)[input_num]; + for (int64 input_num = 0; input_num < inputs.size(); ++input_num) { + const Tensor* input = inputs[input_num]; - if (must_be_constant_args.count(input_num) > 0) { + XlaCompiler::Argument& arg = out[input_num]; + if (absl::c_binary_search(must_be_constant_idxs, input_num)) { // Handles compile-time constants. - const Tensor& input = must_be_constant_args.at(input_num); - TF_RET_CHECK(input.dtype() != DT_RESOURCE); + TF_RET_CHECK(input->dtype() != DT_RESOURCE); arg.kind = XlaCompiler::Argument::kConstant; - arg.type = input.dtype(); - arg.shape = input.shape(); - arg.constant_value = input; - } else if (variable_info_lookup.count(input_num) == 0) { - // Handles the non-constant arguments. - const Tensor& input = ctx->input(input_num); - TF_RET_CHECK(input.dtype() != DT_RESOURCE); - if (input.NumElements() > 0) { - arg.kind = XlaCompiler::Argument::kParameter; - } else { - arg.kind = XlaCompiler::Argument::kConstant; - arg.constant_value = input; - } - arg.type = input.dtype(); - arg.shape = input.shape(); - } else { + arg.type = input->dtype(); + arg.shape = input->shape(); + arg.constant_value = *input; + } else if (variable_info_lookup.count(input_num)) { // Handles resource variables. - const Tensor& input = ctx->input(input_num); - TF_RET_CHECK(input.dtype() == DT_RESOURCE); + TF_RET_CHECK(input->dtype() == DT_RESOURCE); const VariableInfo& variable = *variable_info_lookup[input_num]; arg.name = std::string(variable.name()); arg.kind = XlaCompiler::Argument::kResource; @@ -616,10 +609,21 @@ Status XlaComputationLaunchContext::BuildXlaCompilerArguments( arg.type = DT_INVALID; arg.shape = TensorShape(); } + } else { + // Normal inputs. 
+ TF_RET_CHECK(input->dtype() != DT_RESOURCE); + if (input->NumElements() > 0) { + arg.kind = XlaCompiler::Argument::kParameter; + } else { + arg.kind = XlaCompiler::Argument::kConstant; + arg.constant_value = *input; + } + arg.type = input->dtype(); + arg.shape = input->shape(); } } - return Status::OK(); + return out; } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index b34b3059a4f..ac085a022c8 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -109,12 +109,16 @@ Status SnapshotResourceVariables(OpKernelContext* ctx, Status LockVariables(absl::Span variables) TF_EXCLUSIVE_LOCK_FUNCTION(); -// Returns a vector of VariableInfo instances for the resource variable inputs -// to the kernel with context `ctx`. The input indices for the resource +// Returns a vector of VariableInfo instances for the resource variable inputs, +// given that *all* inputs are in `inputs`. The input indices for the resource // variable inputs are in `variable_indices`. -Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx, - absl::Span variable_indices, - std::vector* result); +Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev, + absl::Span inputs, + absl::Span variable_indices, + std::vector* result); + +// Returns pointers to inputs stored in `ctx`. +std::vector InputsFromContext(OpKernelContext* ctx); // Helper class to perform the marshalling of TensorFlow inputs and outputs to // ShapedBuffers suitable for passing to an XLA computation. @@ -136,10 +140,10 @@ class XlaComputationLaunchContext { // Builds a XlaCompiler::Argument vector from the arguments to an XlaLaunch // op. // Precondition: variables in `variable_args` are locked. - static Status BuildXlaCompilerArguments( - const std::map& constant_args, - absl::Span variable_args, OpKernelContext* ctx, - std::vector* args); + static xla::StatusOr> + BuildXlaCompilerArguments(absl::Span must_be_constant_idxs, + absl::Span inputs, + absl::Span variable_args); // Add all inputs within `ctx` as XLA arguments (returned by arguments()). // `variables` is a map from TensorFlow argument number to resource variable. 
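The xla_launch_util refactor above replaces the OpKernelContext-centric helpers with functions that take the inputs, ResourceMgr, and DeviceBase explicitly, and BuildXlaCompilerArguments now returns a StatusOr instead of filling an output parameter. As a reading aid, here is a minimal sketch of how the new helpers chain together, assembled from the call sites in this patch; the wrapper function name `BuildArgumentsForCompilation` is hypothetical, and the template parameters are reconstructed from the surrounding signatures, so treat them as an approximation rather than the exact upstream spelling.

```c++
#include "absl/types/span.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

// Hypothetical helper mirroring the new flow in XlaCompileOnDemandOp::Compile.
Status BuildArgumentsForCompilation(
    OpKernelContext* ctx, const std::vector<int>& constant_input_indices,
    std::vector<XlaCompiler::Argument>* out_args) {
  // Inputs are now passed around explicitly instead of being pulled from the
  // kernel context inside each helper.
  std::vector<const Tensor*> inputs = InputsFromContext(ctx);
  std::vector<int> variables_indices = GetResourceVariableIndices(ctx);

  // Variable lookup takes the ResourceMgr and device rather than the context,
  // and the VariableInfos must be locked before building arguments.
  std::vector<VariableInfo> variable_infos;
  TF_RETURN_IF_ERROR(GetVariableInfosFromInputs(
      ctx->resource_manager(), ctx->device(), inputs, variables_indices,
      &variable_infos));
  TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));

  // The argument builder now returns a StatusOr<std::vector<...>>.
  xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
      XlaComputationLaunchContext::BuildXlaCompilerArguments(
          constant_input_indices, inputs, variable_infos);
  TF_RETURN_IF_ERROR(args.status());
  *out_args = std::move(*args);
  return Status::OK();
}

}  // namespace tensorflow
```

Taking DeviceBase, ResourceMgr, and Stream parameters instead of an OpKernelContext is what lets the same helpers serve both the regular launch path and XlaCompileOnDemandOp in this patch.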
diff --git a/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc b/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc index 82510a4926b..6c6c490e032 100644 --- a/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc +++ b/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc @@ -38,18 +38,29 @@ namespace tensorflow { XlaCompileOnDemandOp); \ REGISTER_KERNEL_BUILDER(Name("XlaDot").Device(DEVICE), \ XlaCompileOnDemandOp); \ - REGISTER_KERNEL_BUILDER(Name("XlaDynamicSlice").Device(DEVICE), \ - XlaCompileOnDemandOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("XlaDynamicSlice").HostMemory("size_indices").Device(DEVICE), \ + XlaCompileOnDemandOp); \ REGISTER_KERNEL_BUILDER(Name("XlaDynamicUpdateSlice").Device(DEVICE), \ XlaCompileOnDemandOp); \ REGISTER_KERNEL_BUILDER(Name("XlaIf").Device(DEVICE), XlaCompileOnDemandOp); \ - REGISTER_KERNEL_BUILDER(Name("XlaPad").Device(DEVICE), \ + REGISTER_KERNEL_BUILDER(Name("XlaPad") \ + .HostMemory("padding_low") \ + .HostMemory("padding_high") \ + .HostMemory("padding_interior") \ + .Device(DEVICE), \ XlaCompileOnDemandOp); \ REGISTER_KERNEL_BUILDER(Name("XlaRecv").Device(DEVICE), \ XlaCompileOnDemandOp); \ REGISTER_KERNEL_BUILDER(Name("XlaReduce").Device(DEVICE), \ XlaCompileOnDemandOp); \ - REGISTER_KERNEL_BUILDER(Name("XlaReduceWindow").Device(DEVICE), \ + REGISTER_KERNEL_BUILDER(Name("XlaReduceWindow") \ + .HostMemory("window_dimensions") \ + .HostMemory("window_strides") \ + .HostMemory("base_dilations") \ + .HostMemory("window_dilations") \ + .HostMemory("padding") \ + .Device(DEVICE), \ XlaCompileOnDemandOp); \ REGISTER_KERNEL_BUILDER(Name("XlaSelectAndScatter") \ .HostMemory("window_dimensions") \ @@ -75,11 +86,9 @@ namespace tensorflow { XlaCompileOnDemandOp); \ REGISTER_KERNEL_BUILDER(Name("XlaReplicaId").Device(DEVICE), \ XlaCompileOnDemandOp); \ - REGISTER_KERNEL_BUILDER(Name("XlaGather") \ - .HostMemory("start_indices") \ - .HostMemory("slice_sizes") \ - .Device(DEVICE), \ - XlaCompileOnDemandOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("XlaGather").HostMemory("slice_sizes").Device(DEVICE), \ + XlaCompileOnDemandOp); \ REGISTER_KERNEL_BUILDER(Name("XlaScatter").Device(DEVICE), \ XlaCompileOnDemandOp); diff --git a/tensorflow/compiler/jit/xla_platform_info.cc b/tensorflow/compiler/jit/xla_platform_info.cc index a5e12b37563..b38bf9282b1 100644 --- a/tensorflow/compiler/jit/xla_platform_info.cc +++ b/tensorflow/compiler/jit/xla_platform_info.cc @@ -19,7 +19,7 @@ limitations under the License. 
namespace tensorflow { -Status BuildXlaCompilationCache(OpKernelContext* ctx, +Status BuildXlaCompilationCache(DeviceBase* device, const XlaPlatformInfo& platform_info, XlaCompilationCache** cache) { if (platform_info.xla_device_metadata()) { @@ -59,7 +59,7 @@ Status BuildXlaCompilationCache(OpKernelContext* ctx, xla::LocalClientOptions client_options; client_options.set_platform(platform.ValueOrDie()); client_options.set_intra_op_parallelism_threads( - ctx->device()->tensorflow_cpu_worker_threads()->num_threads); + device->tensorflow_cpu_worker_threads()->num_threads); auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options); if (!client.ok()) { return client.status(); @@ -75,21 +75,21 @@ Status BuildXlaCompilationCache(OpKernelContext* ctx, return Status::OK(); } -XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx) { - DeviceType device_type = ctx->device_type(); +XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device_base) { + auto device = static_cast(device_base); se::Platform::Id platform_id = nullptr; const XlaDevice::Metadata* xla_device_metadata = nullptr; se::DeviceMemoryAllocator* custom_allocator = nullptr; - if (ctx->device_type() == DeviceType(DEVICE_CPU)) { + if (device->device_type() == DEVICE_CPU) { platform_id = se::host::kHostPlatformId; - } else if (ctx->device_type() == DeviceType(DEVICE_GPU)) { - platform_id = ctx->device() - ->tensorflow_gpu_device_info() + } else if (device->device_type() == DEVICE_GPU) { + platform_id = device->tensorflow_gpu_device_info() ->stream->parent() ->platform() ->id(); - } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata).ok()) { + } else if (XlaDevice::GetMetadataFromDevice(device, &xla_device_metadata) + .ok()) { // If we are on an XlaDevice, use the underlying XLA platform's allocator // directly. We could use the StreamExecutor's allocator which may // theoretically be more correct, but XLA returns a nice OOM message in a @@ -104,47 +104,46 @@ XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx) { xla_device_metadata->client()->backend().memory_allocator(); } - return XlaPlatformInfo(device_type, platform_id, xla_device_metadata, - custom_allocator); + return XlaPlatformInfo(DeviceType(device->device_type()), platform_id, + xla_device_metadata, custom_allocator); } se::DeviceMemoryAllocator* GetAllocator( absl::optional* tf_allocator_adapter, - OpKernelContext* ctx, const XlaPlatformInfo& platform_info) { + DeviceBase* device, se::Stream* stream, + const XlaPlatformInfo& platform_info) { if (platform_info.custom_allocator()) { return platform_info.custom_allocator(); } - if (!ctx->op_device_context()) { + if (!stream) { // Stream is not set for the host platform. 
se::Platform* platform = se::MultiPlatformManager::PlatformWithId(platform_info.platform_id()) .ValueOrDie(); - tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}), platform); + tf_allocator_adapter->emplace(device->GetAllocator({}), platform); return &tf_allocator_adapter->value(); } - tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}), - ctx->op_device_context()->stream()); + tf_allocator_adapter->emplace(device->GetAllocator({}), stream); return &tf_allocator_adapter->value(); } XlaCompiler::Options GenerateCompilerOptions( - const XlaCompilationCache& cache, OpKernelContext* ctx, - const XlaPlatformInfo& platform_info, bool has_ref_vars, + const XlaCompilationCache& cache, + const FunctionLibraryRuntime& function_library, DeviceBase* device, + se::Stream* stream, const XlaPlatformInfo& platform_info, bool has_ref_vars, absl::optional* tf_allocator_adapter) { - CHECK(ctx->function_library()); XlaCompiler::Options options; options.client = static_cast(cache.client()); - if (ctx->op_device_context() != nullptr) { - options.device_ordinal = - ctx->op_device_context()->stream()->parent()->device_ordinal(); + if (stream != nullptr) { + options.device_ordinal = stream->parent()->device_ordinal(); } options.device_type = cache.device_type(); - options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); - options.graph_def_version = ctx->function_library()->graph_def_version(); + options.flib_def = function_library.GetFunctionLibraryDefinition(); + options.graph_def_version = function_library.graph_def_version(); options.allow_cpu_custom_calls = (platform_info.platform_id() == se::host::kHostPlatformId); options.device_allocator = - GetAllocator(tf_allocator_adapter, ctx, platform_info); + GetAllocator(tf_allocator_adapter, device, stream, platform_info); if (platform_info.xla_device_metadata()) { options.shape_representation_fn = platform_info.xla_device_metadata()->shape_representation_fn(); diff --git a/tensorflow/compiler/jit/xla_platform_info.h b/tensorflow/compiler/jit/xla_platform_info.h index d58b32a996f..bfb438cc398 100644 --- a/tensorflow/compiler/jit/xla_platform_info.h +++ b/tensorflow/compiler/jit/xla_platform_info.h @@ -80,27 +80,31 @@ class XlaPlatformInfo { }; // Returns created XLA compilation cache. -Status BuildXlaCompilationCache(OpKernelContext* ctx, +Status BuildXlaCompilationCache(DeviceBase* dev, const XlaPlatformInfo& platform_info, XlaCompilationCache** cache); // Returns information about the platform from kernel context. -XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx); +XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device); // Returns allocator from platform info if non-null, or populate and return a // pointer to the allocator adapter with allocator from context. // // This is necessary because for XLA devices the underlying TF allocator returns // dummy tensors. +// +// `stream` parameter is nullable when running on host. se::DeviceMemoryAllocator* GetAllocator( absl::optional* tf_allocator_adapter, - OpKernelContext* ctx, const XlaPlatformInfo& platform_info); + DeviceBase* device, se::Stream* stream, + const XlaPlatformInfo& platform_info); // Returns created options for the XLA compiler, and writes the used allocator // into `tf_allocator_adapter`. 
XlaCompiler::Options GenerateCompilerOptions( - const XlaCompilationCache& cache, OpKernelContext* ctx, - const XlaPlatformInfo& platform_info, bool has_ref_vars, + const XlaCompilationCache& cache, + const FunctionLibraryRuntime& function_library, DeviceBase* device, + se::Stream* stream, const XlaPlatformInfo& platform_info, bool has_ref_vars, absl::optional* tf_allocator_adapter); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h index dc358760534..2da1501819c 100644 --- a/tensorflow/compiler/jit/xla_tensor.h +++ b/tensorflow/compiler/jit/xla_tensor.h @@ -71,18 +71,6 @@ class XlaTensor { shaped_buffer_ = std::move(shaped_buffer); } - // Some tensors on the device may have known values on the host. We use these - // in on-demand mode to avoid re-copying values from the device if we know the - // host value already. - - // Return true if this XlaTensor contains a host tensor. - bool has_host_tensor() const { return host_tensor_.has_value(); } - // Return the contained host tensor. - // REQUIRES: has_host_tensor() - const Tensor& host_tensor() const { return *host_tensor_; } - // Sets the contained host tensor. - void set_host_tensor(const Tensor& tensor) { host_tensor_.emplace(tensor); } - // Adds synchronization events to 'stream' that wait for this tensor to be // defined on 'stream'. Does nothing if the tensor is already defined on that // stream. diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index fd9953de1e2..18d05bdaace 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -1,6 +1,8 @@ # Description: # TensorFlow/TensorFlow Lite/XLA MLIR dialects and tools. +load("//tensorflow:tensorflow.bzl", "filegroup") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") package( @@ -24,11 +26,40 @@ filegroup( srcs = glob(["**/*.td"]), ) +cc_library( + name = "string_container_utils", + hdrs = ["utils/string_container_utils.h"], + deps = [ + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + ], +) + +cc_library( + name = "array_container_utils", + hdrs = ["utils/array_container_utils.h"], + deps = [ + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + ], +) + +cc_library( + name = "name_utils", + srcs = ["utils/name_utils.cc"], + hdrs = ["utils/name_utils.h"], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) + cc_library( name = "op_or_arg_name_mapper", srcs = ["op_or_arg_name_mapper.cc"], hdrs = ["op_or_arg_name_mapper.h"], deps = [ + ":name_utils", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", @@ -40,17 +71,15 @@ cc_library( srcs = ["tf_mlir_opt_main.cc"], deps = [ ":init_mlir", + "//tensorflow/compiler/mlir/hlo:all_passes", + "//tensorflow/compiler/mlir/hlo:hlo_dialect_registration", "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops", "//tensorflow/core:lib", - "//tensorflow/core/platform:logging", - "@llvm-project//llvm:Support", "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", - "@llvm-project//mlir:IR", "@llvm-project//mlir:MlirOptLib", - "@llvm-project//mlir:Pass", "@llvm-project//mlir:Shape", - "@llvm-project//mlir:Support", ], ) @@ -67,14 +96,13 @@ cc_library( # xla-legalize-tf-with-tf2xla pass. 
"//tensorflow/compiler/jit", "//tensorflow/compiler/mlir/lite:tensorflow_lite", - "//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration", "//tensorflow/compiler/mlir/lite:tensorflow_lite_legalize_tf", "//tensorflow/compiler/mlir/lite:tensorflow_lite_optimize", "//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize", "//tensorflow/compiler/mlir/lite/quantization:quantization_passes", "//tensorflow/compiler/mlir/lite/quantization/tensorflow:tf_to_quant", "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration", + "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_pass", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow:tensorflow_test_passes", "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", @@ -141,10 +169,9 @@ tf_cc_binary( srcs = ["tf_mlir_translate_main.cc"], deps = [ ":init_mlir", - "//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration", "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration", + "//tensorflow/compiler/mlir/tensorflow:tf_xla_mlir_translate", "//tensorflow/compiler/mlir/tensorflow:translate_cl_options", "//tensorflow/compiler/mlir/tensorflow:translate_lib", "//tensorflow/compiler/mlir/tensorflow:translate_registration", @@ -157,7 +184,7 @@ tf_cc_binary( "//tensorflow/stream_executor/lib", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", - "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", "@llvm-project//mlir:Translation", diff --git a/tensorflow/compiler/mlir/README.md b/tensorflow/compiler/mlir/README.md index cbb0b08503a..c415edceb8c 100644 --- a/tensorflow/compiler/mlir/README.md +++ b/tensorflow/compiler/mlir/README.md @@ -9,3 +9,31 @@ dialects and utilities for 3. TF Lite See [MLIR's website](https://mlir.llvm.org) for complete documentation. + +## Getting started + +Building dialects and utilities here follow the standard approach using +`bazel` as the rest of TensorFlow. + +### Using local LLVM repo + +To develop across MLIR core and TensorFlow, it is useful to override the repo +to use a local version instead of fetching from head. This can be achieved as +below but note, the BUILD files are not automatically generated from or CMake +used, so if your change requires a BUILD file change (or you are using a +different version of LLVM than set in tensorflow/workspace.bzl's LLVM_COMMIT) +then manual BUILD file changes may be required. + +```sh +LLVM_SRC=... + +# Create basic workspace file +echo 'workspace(name = "llvm-project")' > $LLVM_SRC/WORKSPACE +# and copy over the bazel BUILD files. 
+cp third_party/llvm/llvm.autogenerated.BUILD $LLVM_SRC/llvm/BUILD +cp third_party/mlir/BUILD $LLVM_SRC/mlir +cp third_party/mlir/test.BUILD $LLVM_SRC/mlir/test/BUILD + +bazel build --override_repository=llvm-project=$LLVM_SRC \ + -c opt tensorflow/compiler/mlir:tf-opt +``` diff --git a/tensorflow/compiler/mlir/hlo/BUILD b/tensorflow/compiler/mlir/hlo/BUILD index 7be39aef9da..1636bbb89ee 100644 --- a/tensorflow/compiler/mlir/hlo/BUILD +++ b/tensorflow/compiler/mlir/hlo/BUILD @@ -1,3 +1,10 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "filegroup") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "get_compatible_with_cloud") load("//third_party/mlir:tblgen.bzl", "gentbl") # TODO(b/160617323): Decouple MLIR HLO from TensorFlow/XLA @@ -17,6 +24,7 @@ package_group( "//learning/brain/experimental/mlir/...", "//learning/brain/google/xla/kernels/...", "//learning/brain/google/xla/mlir/...", + "//learning/deepmind/partir/...", "//learning/pathways/data_parallel/tf2xla/...", "//platforms/xla/...", "//tensorflow/compiler/mlir/...", @@ -37,10 +45,13 @@ filegroup( "include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td", + "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td", "include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td", "include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td", + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td", "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:include/mlir/Interfaces/CopyOpInterface.td", "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", "@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td", "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td", @@ -57,7 +68,8 @@ filegroup( gentbl( name = "MhloPassIncGen", - strip_include_prefix = "include/mlir-hlo/Dialect/mhlo/transforms/", + compatible_with = get_compatible_with_cloud(), + strip_include_prefix = "include", tbl_outs = [ ( "-gen-pass-decls -name MHLO", @@ -73,7 +85,8 @@ gentbl( gentbl( name = "LmhloPassIncGen", - strip_include_prefix = "include/mlir-hlo/Dialect/mhlo/transforms/", + compatible_with = get_compatible_with_cloud(), + strip_include_prefix = "include", tbl_outs = [ ( "-gen-pass-decls -name LMHLO", @@ -89,6 +102,7 @@ gentbl( gentbl( name = "chlo_ops_inc_gen", + compatible_with = get_compatible_with_cloud(), strip_include_prefix = "include", tbl_outs = [ ("-gen-op-decls", "include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h.inc"), @@ -104,12 +118,11 @@ gentbl( gentbl( name = "hlo_ops_inc_gen", + compatible_with = get_compatible_with_cloud(), strip_include_prefix = "include", tbl_outs = [ ("-gen-op-decls", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc"), ("-gen-op-defs", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"), - ("-gen-struct-attr-decls", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.h.inc"), - ("-gen-struct-attr-defs", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.cc.inc"), ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td", @@ -128,6 +141,7 @@ gentbl( gentbl( name = "hlo_ops_base_inc_gen", + compatible_with = get_compatible_with_cloud(), strip_include_prefix = "include", tbl_outs = [ ("-gen-op-decls", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.h.inc"), @@ -135,11 +149,30 @@ gentbl( ], tblgen = 
"@llvm-project//mlir:mlir-tblgen", td_file = "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td", + td_relative_includes = [ + "include", + ], + td_srcs = [":hlo_ops_td_files"], +) + +gentbl( + name = "hlo_ops_base_structs_inc_gen", + compatible_with = get_compatible_with_cloud(), + tbl_outs = [ + ("-gen-struct-attr-decls", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h.inc"), + ("-gen-struct-attr-defs", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.cc.inc"), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td", + td_relative_includes = [ + "include", + ], td_srcs = [":hlo_ops_td_files"], ) gentbl( name = "hlo_ops_pattern_gen", + compatible_with = get_compatible_with_cloud(), strip_include_prefix = "lib/Dialect/mhlo/IR/", tbl_outs = [ ( @@ -162,6 +195,7 @@ gentbl( gentbl( name = "lhlo_ops_inc_gen", + compatible_with = get_compatible_with_cloud(), strip_include_prefix = "include", tbl_outs = [ ("-gen-op-decls", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"), @@ -177,9 +211,67 @@ gentbl( td_srcs = [":hlo_ops_td_files"], ) +gentbl( + name = "lhlo_gpu_ops_structs_inc_gen", + compatible_with = get_compatible_with_cloud(), + strip_include_prefix = "include", + tbl_outs = [ + ("-gen-struct-attr-decls", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h.inc"), + ("-gen-struct-attr-defs", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc.inc"), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td", + td_relative_includes = [ + "include", + ], + td_srcs = [ + ":hlo_ops_td_files", + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td", + ], +) + +cc_library( + name = "lhlo_gpu_ops_structs", + srcs = [ + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc.inc", + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h.inc", + "lib/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc", + ], + hdrs = [ + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h", + ], + includes = ["include"], + deps = [ + ":lhlo_gpu_ops_structs_inc_gen", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +gentbl( + name = "lhlo_gpu_ops_inc_gen", + compatible_with = get_compatible_with_cloud(), + strip_include_prefix = "include", + tbl_outs = [ + ("-gen-op-decls", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h.inc"), + ("-gen-op-defs", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.cc.inc"), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td", + td_relative_includes = [ + "include", + ], + td_srcs = [ + ":hlo_ops_td_files", + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td", + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td", + ], +) + #TODO(aminim): revisit the naming and grouping of these rules post-move. 
gentbl( name = "canonicalize_inc_gen", + compatible_with = get_compatible_with_cloud(), strip_include_prefix = "lib/Dialect/mhlo/IR/", tbl_outs = [ ("-gen-rewriters", "lib/Dialect/mhlo/IR/mhlo_canonicalize.inc"), @@ -194,6 +286,7 @@ gentbl( gentbl( name = "infer_fusibility_op_interface_gen", + compatible_with = get_compatible_with_cloud(), tbl_outs = [ ( "-gen-op-interface-decls", @@ -232,6 +325,23 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "hlo_ops_base_structs", + srcs = [ + "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h.inc", + "lib/Dialect/mhlo/IR/hlo_ops_base_structs.cc", + ], + hdrs = [ + "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h", + ], + includes = ["include"], + deps = [ + ":hlo_ops_base_structs_inc_gen", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "convert_op_folder", srcs = ["lib/utils/convert_op_folder.cc"], @@ -265,6 +375,7 @@ cc_library( ":chlo_ops_inc_gen", ":convert_op_folder", ":hlo_ops_base_inc_gen", + ":hlo_ops_base_structs", ":hlo_ops_inc_gen", ":infer_fusibility_op_interface", "@llvm-project//llvm:Support", @@ -295,9 +406,11 @@ cc_library( includes = ["include"], deps = [ ":hlo_ops_base_inc_gen", + ":hlo_ops_base_structs", ":lhlo_ops_inc_gen", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:CopyOpInterface", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SideEffects", @@ -311,12 +424,34 @@ cc_library( ) cc_library( - name = "hlo_dialect_force_registration", - srcs = ["lib/Dialect/mhlo/IR/dialect_registration.cc"], + name = "lhlo_gpu", + srcs = [ + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.cc.inc", + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h.inc", + "lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc", + ], + hdrs = [ + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h", + ], + includes = ["include"], deps = [ ":hlo", - ":lhlo", + ":hlo_ops_base_structs", + ":infer_fusibility_op_interface", + ":lhlo_gpu_ops_inc_gen", + ":lhlo_gpu_ops_structs", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:CopyOpInterface", "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SideEffects", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:ViewLikeInterface", ], alwayslink = 1, ) @@ -328,24 +463,45 @@ cc_library( deps = [ ":hlo", ":lhlo", + ":lhlo_gpu", "@llvm-project//mlir:IR", ], ) cc_library( name = "sink_constants_to_control_flow", - srcs = ["lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc"], + srcs = [ + "lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc", + ], + hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"], + deps = [ + ":hlo", + ":pass_details", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], + alwayslink = 1, +) + +cc_library( + name = "mhlo_control_flow_to_scf", + srcs = ["lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc"], hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"], deps = [ ":hlo", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:StandardOps", 
"@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", ], - alwayslink = 1, ) cc_library( @@ -356,6 +512,7 @@ cc_library( ":lhlo", ":map_hlo_to_lhlo_op", "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", "@llvm-project//mlir:StandardOps", ], ) @@ -420,7 +577,10 @@ cc_library( cc_library( name = "legalize_to_linalg", srcs = ["lib/Dialect/mhlo/transforms/legalize_to_linalg.cc"], - hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"], + hdrs = [ + "include/mlir-hlo/Dialect/mhlo/transforms/passes.h", + "include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h", + ], deps = [ ":hlo", ":lhlo", @@ -439,9 +599,13 @@ cc_library( cc_library( name = "transform_unranked_hlo", srcs = ["lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc"], - hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"], + hdrs = [ + "include/mlir-hlo/Dialect/mhlo/transforms/passes.h", + "include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h", + ], deps = [ ":hlo", + "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Shape", @@ -459,6 +623,7 @@ cc_library( ":lhlo", ":map_lmhlo_to_scalar_op", "@llvm-project//llvm:Support", + "@llvm-project//mlir:Affine", "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgOps", @@ -477,27 +642,15 @@ cc_library( deps = [ ":lhlo", "@llvm-project//llvm:Support", + "@llvm-project//mlir:Affine", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", - ], - alwayslink = 1, -) - -cc_library( - name = "lhlo_copy_removal", - srcs = ["lib/Dialect/mhlo/transforms/lhlo_copy_removal.cc"], - hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"], - deps = [ - ":lhlo", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:StandardOps", - "@llvm-project//mlir:Support", + "@llvm-project//mlir:ViewLikeInterface", ], alwayslink = 1, ) @@ -516,6 +669,8 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Shape", + "@llvm-project//mlir:ShapeTransforms", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", @@ -564,6 +719,7 @@ cc_library( gentbl( name = "legalize_to_standard_inc_gen", + compatible_with = get_compatible_with_cloud(), strip_include_prefix = "lib/Dialect/mhlo/transforms/", tbl_outs = [ ("-gen-rewriters", "lib/Dialect/mhlo/transforms/generated_legalize_to_standard.inc"), @@ -601,8 +757,8 @@ cc_library( hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"], deps = [ ":hlo", - ":legalize_tanh_to_approximation", ":legalize_to_standard_inc_gen", + ":legalize_trigonometric_to_approximation", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", @@ -632,8 +788,8 @@ cc_library( ) cc_library( - name = "legalize_tanh_to_approximation", - srcs = ["lib/Dialect/mhlo/transforms/legalize_tanh_to_approximation.cc"], + name = "legalize_trigonometric_to_approximation", + srcs = ["lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc"], hdrs = [ "include/mlir-hlo/Dialect/mhlo/transforms/passes.h", "include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h", @@ -652,6 +808,7 @@ cc_library( gentbl( name = "lower_complex_inc_gen", + compatible_with = get_compatible_with_cloud(), 
strip_include_prefix = "lib/Dialect/mhlo/transforms/", tbl_outs = [ ("-gen-rewriters", "lib/Dialect/mhlo/transforms/generated_lower_complex.inc"), @@ -682,7 +839,6 @@ cc_library( ], deps = [ ":hlo", - ":hlo_dialect_force_registration", ":lower_complex_inc_gen", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", @@ -733,6 +889,7 @@ cc_library( srcs = ["lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc"], hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"], deps = [ + ":chlo_legalize_to_hlo_inc_gen", ":hlo", "@llvm-project//mlir:IR", "@llvm-project//mlir:SCFDialect", @@ -742,6 +899,40 @@ cc_library( ], ) +gentbl( + name = "chlo_legalize_to_hlo_inc_gen", + compatible_with = get_compatible_with_cloud(), + strip_include_prefix = "lib/Dialect/mhlo/transforms/", + tbl_outs = [ + ( + "-gen-rewriters", + "lib/Dialect/mhlo/transforms/generated_chlo_legalize_to_hlo.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td", + td_relative_includes = [ + "include", + ], + td_srcs = [ + ":hlo_ops_td_files", + ], +) + +cc_library( + name = "pass_details", + hdrs = [ + "include/mlir-hlo/Dialect/mhlo/transforms/PassDetail.h", + ], + visibility = [ + "//visibility:private", # This target is a private detail of pass implementations + ], + deps = [ + ":MhloPassIncGen", + "@llvm-project//mlir:Pass", + ], +) + cc_library( name = "test_passes", srcs = [ @@ -759,6 +950,7 @@ cc_library( ":lhlo", ":lhlo_legalize_to_llvm", # build-cleaner: keep ":materialize_broadcasts", # build-cleaner: keep + ":pass_details", ":unfuse_batch_norm", # build-cleaner: keep "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", @@ -788,15 +980,15 @@ cc_library( ":hlo_legalize_to_lhlo", ":legalize_control_flow", ":legalize_gather_to_torch_index_select", - ":legalize_tanh_to_approximation", ":legalize_to_linalg", ":legalize_to_standard", + ":legalize_trigonometric_to_approximation", ":lhlo", - ":lhlo_copy_removal", ":lhlo_fuse_linalg", ":lhlo_legalize_to_affine", ":lhlo_legalize_to_gpu", ":lhlo_legalize_to_parallel_loops", + ":mhlo_control_flow_to_scf", ":mhlo_fusion", ":mhlo_to_mhlo_lowering_patterns", ":sink_constants_to_control_flow", @@ -815,6 +1007,7 @@ cc_binary( ":all_passes", ":hlo", ":lhlo", + ":lhlo_gpu", "@llvm-project//llvm:Support", "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", "@llvm-project//mlir:IR", diff --git a/tensorflow/compiler/mlir/hlo/README.md b/tensorflow/compiler/mlir/hlo/README.md index 9eaa14031fd..61517cd9fca 100644 --- a/tensorflow/compiler/mlir/hlo/README.md +++ b/tensorflow/compiler/mlir/hlo/README.md @@ -106,7 +106,7 @@ pipeline using MLIR: * `mhlo`: "meta"-HLO dialect ; similar to `xla_hlo`, but with extensions for dynamic shape support. * `lmhlo`: "late"-"meta"-HLO, it is the IR after buffer allocation is - performed. In XLA the buffer allocation is a side-datastructure which keeps + performed. In XLA the buffer allocation is a side-data structure which keeps track of these informations, while this separate dialect materializes it in the IR. @@ -114,7 +114,7 @@ We describe these in more details below. ### HLO Client Dialect: `chlo`. 
-* It was originaly designed to map the +* It was originally designed to map the [XLA client APIs](https://www.tensorflow.org/xla/operation_semantics) (e.g., ops supports implicit broadcast and roughly modeled on XlaBuilder API) modulo support for dynamic shapes and additional ops required to support diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt index 09bdca84cd3..3fa2b908d9c 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt @@ -25,7 +25,22 @@ function(add_mlir_hlo_dialect dialect dialect_namespace) endfunction() add_mlir_hlo_dialect(chlo_ops chlo) -add_mlir_hlo_dialect(hlo_ops mhlo) add_mlir_hlo_dialect(lhlo_ops lmhlo) +set(LLVM_TARGET_DEFINITIONS hlo_ops.td) +mlir_tablegen(hlo_ops.h.inc -gen-op-decls) +mlir_tablegen(hlo_ops.cc.inc -gen-op-defs) +mlir_tablegen(hlo_ops_base_structs.h.inc -gen-struct-attr-decls) +mlir_tablegen(hlo_ops_base_structs.cc.inc -gen-struct-attr-defs) +add_public_tablegen_target(MLIRhlo_opsIncGen) + +set(LLVM_TARGET_DEFINITIONS lhlo_gpu_ops.td) +mlir_tablegen(lhlo_gpu_ops.h.inc -gen-op-decls) +mlir_tablegen(lhlo_gpu_ops.cc.inc -gen-op-defs) +set(LLVM_TARGET_DEFINITIONS lhlo_gpu_ops_structs.td) +mlir_tablegen(lhlo_gpu_ops_structs.h.inc -gen-struct-attr-decls) +mlir_tablegen(lhlo_gpu_ops_structs.cc.inc -gen-struct-attr-defs) +add_public_tablegen_target(MLIRlhlo_gpu_opsIncGen) +add_dependencies(mlir-headers MLIRlhlo_gpu_opsIncGen) + add_mlir_interface(infer_fusibility_op_interface) diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h index 9704f34a4d6..05b22770401 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h @@ -44,11 +44,18 @@ class HloClientDialect : public Dialect { static StringRef getDialectNamespace() { return "chlo"; } }; +} // namespace chlo +} // namespace mlir + #define GET_OP_CLASSES #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h.inc" +namespace mlir { +namespace chlo { + template -static Value getConstantLike(OpBuilder& b, T constant, Value val) { +static Value getConstantLike(OpBuilder& b, Location loc, T constant, + Value val) { Type ty = getElementTypeOrSelf(val.getType()); auto getAttr = [&]() -> Attribute { @@ -56,8 +63,7 @@ static Value getConstantLike(OpBuilder& b, T constant, Value val) { if (ty.isa()) return b.getFloatAttr(ty, constant); llvm_unreachable("unhandled element type"); }; - // TODO(jpienaar): Add ability to pass loc via native call and update. 
- return b.create(b.getUnknownLoc(), getAttr(), val); + return b.create(loc, getAttr(), val); } } // namespace chlo diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td index 2f3bbefb5ab..13d5f02368b 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td @@ -37,7 +37,7 @@ include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td" def HLOClient_Dialect : Dialect { let name = "chlo"; - let cppNamespace = "chlo"; + let cppNamespace = "::mlir::chlo"; let summary = [{ Client HLO Ops }]; @@ -79,7 +79,8 @@ class HLOClient_BroadcastBinaryElementwiseOp< string mnemonic, list traits> : HLOClient_Op])> { + DeclareOpInterfaceMethods])> { let arguments = (ins HLO_Tensor:$lhs, HLO_Tensor:$rhs, @@ -99,13 +100,6 @@ class HLOClient_BroadcastBinaryElementwiseOp< $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type(results) }]; - - let extraClassDeclaration = [{ - // TODO(laurenzo): It isn't clear to me why reifyReturnShapes does not - // have its declaration generated by DeclareOpInterfaceMethods. - LogicalResult reifyReturnTypeShapes( - OpBuilder& builder, SmallVectorImpl& reifiedReturnShapes); - }]; } def HLOClient_BroadcastAddOp : HLOClient_BroadcastBinaryElementwiseOp<"broadcast_add", @@ -344,14 +338,16 @@ def HLOClient_BroadcastComplexOp : HLOClient_BroadcastBinaryElementwiseOp< //===----------------------------------------------------------------------===// class HLOClient_UnaryElementwiseOp traits, - Type TensorType>: HLOClient_Op { + Type TensorType> : HLOClient_Op { let arguments = (ins TensorType:$operand); - let results = (outs TensorType); + let results = (outs TensorType:$result); + + let assemblyFormat = "$operand attr-dict `:` type($operand)"; } -def HLOClient_AcosOp: HLOClient_UnaryElementwiseOp<"acos", - [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> { +def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", [], + HLO_FpOrComplexTensor> { let summary = "Acos operator"; let description = [{ @@ -364,7 +360,47 @@ def HLOClient_AcosOp: HLOClient_UnaryElementwiseOp<"acos", }]; } -def HLOClient_ConstantLikeOp: HLOClient_Op<"constant_like", +def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [], + HLO_FpOrComplexTensor> { + let summary = "Atan operator"; + + let description = [{ + Returns `Atan(operand)` element-wise. + + $$ + \atan(x) = \atan2(x, 1) + $$ + }]; +} + +def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh", [], + HLO_FpOrComplexTensor> { + let summary = "Sinh operation"; + + let description = [{ + Returns `Sinh(operand)` element-wise. + + $$ + \sinh(x) = (e^x - e^-x) / 2 if |x| < 1 + = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise. + $$ + }]; +} + +def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan", [], + HLO_FpOrComplexTensor> { + let summary = "Tan operation"; + + let description = [{ + Returns `Tan(operand)` element-wise. 
+ + $$ + \tan(x) = \sin(x) / \cos(x) + $$ + }]; +} + +def HLOClient_ConstantLikeOp : HLOClient_Op<"constant_like", [NoSideEffect, SameOperandsAndResultShape, InferTypeOpInterface, DeclareOpInterfaceMethods, diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h index 4286c837a24..b354189c12a 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h @@ -19,7 +19,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_ #include "llvm/ADT/StringRef.h" -#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" @@ -32,11 +32,14 @@ limitations under the License. #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +// clang-format off +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h" +#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" +// clang-format on + namespace mlir { class OpBuilder; -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.h.inc" - namespace mhlo { class MhloDialect : public Dialect { @@ -77,10 +80,10 @@ LogicalResult deriveShapeFromFirstOperand( OpBuilder *builder, Operation *op, SmallVectorImpl *reifiedReturnShapes); -#define GET_OP_CLASSES -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc" - } // end namespace mhlo } // end namespace mlir +#define GET_OP_CLASSES +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc" + #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index d0abbe043ea..3defb65adf8 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -25,11 +25,6 @@ include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td" include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td" -def HLO_Dialect : Dialect { - let name = "mhlo"; - let cppNamespace = "mhlo"; -} - class HLO_Op traits> : Op { // Whether this operation has a custom conversion to HLO or not. 
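// Illustrative only, not part of the patch: with the assembly format
// "$operand attr-dict `:` type($operand)" declared for the CHLO unary
// elementwise ops, the chlo.atan, chlo.sinh and chlo.tan ops added in
// chlo_ops.td above print as, e.g. (shapes arbitrary):
//   %0 = chlo.atan %x : tensor<2x2xf32>
//   %1 = chlo.sinh %0 : tensor<2x2xf32>
//   %2 = chlo.tan %1 : tensor<2x2xf32>
// Only the operand type is printed; the result type matches it.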
@@ -136,8 +131,8 @@ class HLO_UnaryElementwiseOp traits, } LogicalResult reifyReturnTypeShapes( OpBuilder& builder, SmallVectorImpl& reifiedReturnShapes) { - return deriveShapeFromFirstOperand(&builder, getOperation(), - &reifiedReturnShapes); + return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(), + &reifiedReturnShapes); } bool inferInputOutputShapeEquality(int input, int output) { return true; @@ -153,10 +148,13 @@ def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs", [NoSideEffect, SameOperandsAndResultShape], TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, BASE_HLO_AbsOp { let builders = [OpBuilder< - "OpBuilder &builder, OperationState &result, Value operand" + "Value operand" >]; } +def HLO_CbrtOp: HLO_UnaryElementwiseOp<"cbrt", + [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_CbrtOp; + def HLO_CeilOp: HLO_UnaryElementwiseOp<"ceil", [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_CeilOp; @@ -165,8 +163,7 @@ def HLO_ConvertOp : HLO_UnaryElementwiseOp< BASE_HLO_ConvertOp { let builders = [OpBuilder< - "OpBuilder &, OperationState &tblgen_state, Value operand, " - "Type result_element_ty" + "Value operand, Type result_element_ty" >]; let hasFolder = 1; @@ -193,12 +190,10 @@ def HLO_Expm1Op: HLO_UnaryElementwiseOp<"exponential_minus_one", def HLO_FloorOp: HLO_UnaryElementwiseOp<"floor", [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_FloorOp; -def HLO_ImagOp: HLO_Op< - "imag", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_ImagOp { - let builders = [OpBuilder< - "OpBuilder &, OperationState &tblgen_state, Value val">]; - - let arguments = (ins HLO_ComplexTensor); +def HLO_ImagOp: HLO_UnaryElementwiseOp<"imag", + [NoSideEffect, SameOperandsAndResultShape, + DeclareOpInterfaceMethods], + HLO_ComplexTensor>, BASE_HLO_ImagOp { let results = (outs HLO_FpTensor); let hasFolder = 1; } @@ -224,22 +219,23 @@ def HLO_LogisticOp: HLO_UnaryElementwiseOp<"logistic", def HLO_NotOp: HLO_UnaryElementwiseOp<"not", [NoSideEffect, SameOperandsAndResultType], HLO_PredOrIntTensor>, - BASE_HLO_NotOp; + BASE_HLO_NotOp { +} def HLO_NegOp: HLO_UnaryElementwiseOp<"negate", [NoSideEffect, SameOperandsAndResultType], HLO_IntFpOrComplexTensor>, - BASE_HLO_NegOp; + BASE_HLO_NegOp { + let hasFolder = 1; +} def HLO_PopulationCountOp: HLO_UnaryElementwiseOp<"popcnt", [NoSideEffect, SameOperandsAndResultType], HLO_IntTensor>, BASE_HLO_PopulationCountOp; -def HLO_RealOp: HLO_Op< - "real", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_RealOp { - let builders = [OpBuilder< - "OpBuilder &, OperationState &tblgen_state, Value val">]; - - let arguments = (ins HLO_ComplexTensor); +def HLO_RealOp: HLO_UnaryElementwiseOp<"real", + [NoSideEffect, SameOperandsAndResultShape, + DeclareOpInterfaceMethods], + HLO_ComplexTensor>, BASE_HLO_RealOp { let results = (outs HLO_FpTensor); let hasFolder = 1; } @@ -262,7 +258,9 @@ def HLO_SinOp: HLO_UnaryElementwiseOp<"sine", def HLO_SqrtOp: HLO_UnaryElementwiseOp<"sqrt", [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>, - BASE_HLO_SqrtOp; + BASE_HLO_SqrtOp { + let hasFolder = 1; +} def HLO_TanhOp: HLO_UnaryElementwiseOp<"tanh", [NoSideEffect, SameOperandsAndResultType], @@ -289,8 +287,8 @@ class HLO_BinaryElementwiseOp traits> : } LogicalResult reifyReturnTypeShapes( OpBuilder& builder, SmallVectorImpl& reifiedReturnShapes) { - return deriveShapeFromFirstOperand(&builder, getOperation(), - &reifiedReturnShapes); + return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(), + 
&reifiedReturnShapes); } bool inferInputsShapeEquality(int lhs, int rhs) { return true; @@ -316,12 +314,10 @@ def HLO_AddOp : HLO_BinaryElementwiseOp<"add", def HLO_Atan2Op : HLO_BinaryElementwiseOp<"atan2", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_Atan2Op; -def HLO_ComplexOp: HLO_Op<"complex", - [NoSideEffect, SameOperandsAndResultShape]>, +def HLO_ComplexOp: HLO_BinaryElementwiseOp<"complex", + [NoSideEffect, SameOperandsAndResultShape, + DeclareOpInterfaceMethods]>, BASE_HLO_ComplexOp { - let builders = [OpBuilder< - "OpBuilder &, OperationState &tblgen_state, Value lhs, Value rhs">]; - let arguments = (ins HLO_FpTensor:$lhs, HLO_FpTensor:$rhs); let results = (outs HLO_ComplexTensor); let hasFolder = 1; @@ -351,7 +347,9 @@ def HLO_PowOp : HLO_BinaryElementwiseOp<"power", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_PowOp; def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder", - [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp; + [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp { + let hasFolder = 1; +} def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftLeftOp; @@ -379,6 +377,8 @@ class HLO_BinaryLogicalElementwiseOp : HLO_PredOrIntTensor:$lhs, HLO_PredOrIntTensor:$rhs ); + + let hasFolder = 1; } def HLO_AndOp: HLO_BinaryLogicalElementwiseOp<"and">, BASE_HLO_AndOp; @@ -452,7 +452,7 @@ def HLO_SendOp : HLO_Op<"send", []> { let arguments = (ins HLO_TensorOrTuple:$operand, HLO_Token:$token, - ChannelHandle:$channel_id, + ChannelHandle:$channel_id, DefaultValuedAttr:$is_host_transfer ); @@ -477,7 +477,7 @@ def HLO_RecvOp : HLO_Op<"recv", []> { let arguments = (ins HLO_Token:$token, - ChannelHandle:$channel_id, + ChannelHandle:$channel_id, DefaultValuedAttr:$is_host_transfer ); @@ -491,9 +491,7 @@ def HLO_RecvOp : HLO_Op<"recv", []> { def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect]>, BASE_HLO_ReplicaIdOp { - // TODO(prakalps): The output should unsigned 32-bit integer but mlir does - // not differentiate between signed and unsigned int. - let results = (outs I32Tensor); + let results = (outs TensorOf<[UI32]>); } //===----------------------------------------------------------------------===// @@ -583,7 +581,7 @@ def HLO_AllReduceOp : HLO_Op<"all_reduce", let arguments = (ins HLO_Tensor:$operand, I64ElementsAttr:$replica_groups, - OptionalAttr>:$channel_id + OptionalAttr:$channel_id ); let regions = (region SizedRegion<1>:$computation); let results = (outs HLO_Tensor); @@ -673,9 +671,10 @@ def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp { let hasCanonicalizer = 1; } -def HLO_CompareOp: HLO_Op<"compare", - [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]>, - BASE_HLO_CompareOp { +def HLO_CompareOp: HLO_Op<"compare", [NoSideEffect, SameTypeOperands, + SameOperandsAndResultShape, + DeclareOpInterfaceMethods]>, BASE_HLO_CompareOp { let arguments = (ins HLO_Tensor:$lhs, HLO_Tensor:$rhs, @@ -683,6 +682,8 @@ def HLO_CompareOp: HLO_Op<"compare", ); let results = (outs HLO_PredTensor); + let hasFolder = 1; + let builders = [OpBuilder< "OpBuilder &builder, OperationState &result, Value lhs, Value rhs, " "StringAttr comparison_direction" @@ -905,39 +906,12 @@ def HLO_CollectivePermuteOp: HLO_Op<"collective_permute", let results = (outs HLO_Tensor); } -// TODO(hinsu): Make this struct dialect independent so that it can be shared -// between HLO and LHLO dialect. 
-def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", HLO_Dialect, [ - StructFieldAttr<"input_batch_dimension",I64Attr>, - StructFieldAttr<"input_feature_dimension", I64Attr>, - StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>, - StructFieldAttr<"kernel_input_feature_dimension", I64Attr>, - StructFieldAttr<"kernel_output_feature_dimension", I64Attr>, - StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>, - StructFieldAttr<"output_batch_dimension", I64Attr>, - StructFieldAttr<"output_feature_dimension", I64Attr>, - StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > { - - let description = "Structure of dimension information for conv op"; -} - def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp { - let arguments = (ins - HLO_Tensor:$lhs, - HLO_Tensor:$rhs, - // Default value: one for each of the spatial dimension. - OptionalAttr:$window_strides, - // Default value: zero for each of the spatial dimension. - OptionalAttr:$padding, - // Default value: one for each of the spatial dimension. - OptionalAttr:$lhs_dilation, - // Default value: one for each of the spatial dimension. - OptionalAttr:$rhs_dilation, - ConvDimensionNumbers:$dimension_numbers, - I64Attr:$feature_group_count, - I64Attr:$batch_group_count, - HLO_PrecisionConfigAttr:$precision_config - ); + let arguments = !con( + (ins + HLO_Tensor:$lhs, + HLO_Tensor:$rhs), + ConvolutionAttributes.attributes); let results = (outs HLO_Tensor); } @@ -979,15 +953,6 @@ def HLO_DotOp: HLO_Op<"dot", [NoSideEffect]>, BASE_HLO_DotOp { let results = (outs HLO_Tensor); } -def DotDimensionNumbers : StructAttr<"DotDimensionNumbers", HLO_Dialect, [ - StructFieldAttr<"lhs_batching_dimensions", I64ElementsAttr>, - StructFieldAttr<"rhs_batching_dimensions", I64ElementsAttr>, - StructFieldAttr<"lhs_contracting_dimensions", I64ElementsAttr>, - StructFieldAttr<"rhs_contracting_dimensions", I64ElementsAttr> - ]> { - let description = "Structure of dimension information for dot product"; -} - def HLO_DotGeneralOp: HLO_Op<"dot_general", [NoSideEffect]>, BASE_HLO_DotGeneralOp { let arguments = (ins HLO_Tensor:$lhs, @@ -1049,14 +1014,6 @@ def HLO_FftOp: HLO_Op<"fft", [NoSideEffect]>, BASE_HLO_FftOp { let results = (outs HLO_Tensor); } -def GatherDimensionNumbers : StructAttr<"GatherDimensionNumbers", HLO_Dialect, - [StructFieldAttr<"offset_dims", I64ElementsAttr>, - StructFieldAttr<"collapsed_slice_dims", I64ElementsAttr>, - StructFieldAttr<"start_index_map", I64ElementsAttr>, - StructFieldAttr<"index_vector_dim", I64Attr>]> { - let description = "Structure of dimension information for gather"; -} - def HLO_GatherOp: HLO_Op<"gather", [NoSideEffect]>, BASE_HLO_GatherOp { let arguments = (ins HLO_Tensor:$operand, @@ -1067,6 +1024,8 @@ def HLO_GatherOp: HLO_Op<"gather", [NoSideEffect]>, BASE_HLO_GatherOp { ); let results = (outs HLO_Tensor); + + let hasCanonicalizer = 1; } def HLO_GetDimensionSizeOp: HLO_Op<"get_dimension_size", [NoSideEffect]>, @@ -1079,6 +1038,8 @@ def HLO_GetDimensionSizeOp: HLO_Op<"get_dimension_size", [NoSideEffect]>, // XLA semantics is available. This limitation is because of the current XLA // implementation. 
let results = (outs I32Tensor); + + let hasFolder = 1; } def HLO_MapOp: HLO_Op<"map", @@ -1130,7 +1091,7 @@ def HLO_ScatterOp: HLO_Op<"scatter", [RecursiveSideEffects]>, HLO_Tensor:$operand, HLO_Tensor:$scatter_indices, HLO_Tensor:$updates, - ScatterDimensionNumbers:$scatter_dimension_numbers, + ScatterDimensionNumbers:$scatter_dimension_numbers, DefaultValuedAttr:$indices_are_sorted, DefaultValuedAttr:$unique_indices ); @@ -1140,10 +1101,15 @@ def HLO_ScatterOp: HLO_Op<"scatter", [RecursiveSideEffects]>, let results = (outs HLO_Tensor); let hasCustomHLOConverter = 1; + + let hasFolder = 1; } // TODO(jpienaar): Add broadcastable trait. -def HLO_SelectOp: HLO_Op<"select", [NoSideEffect, DeclareOpInterfaceMethods]>, BASE_HLO_SelectOp { +def HLO_SelectOp: HLO_Op<"select", [NoSideEffect, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, + ]>, BASE_HLO_SelectOp { let arguments = (ins HLO_PredTensor:$pred, HLO_Tensor:$on_true, @@ -1151,6 +1117,8 @@ def HLO_SelectOp: HLO_Op<"select", [NoSideEffect, DeclareOpInterfaceMethods, let results = (outs HLO_Tensor); } -def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects]>, BASE_HLO_SortOp { +def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects, SameOperandsAndResultShape]>, BASE_HLO_SortOp { let arguments = (ins Variadic:$operands, DefaultValuedAttr:$dimension, DefaultValuedAttr:$is_stable ); - let results = (outs HLO_TensorOrTuple); + let results = (outs Variadic); let regions = (region SizedRegion<1>:$comparator); @@ -1412,4 +1380,21 @@ def HLO_FusionOp : HLO_Op<"fusion", []> { let hasCustomHLOConverter = 1; } +// This is an op for purposes internal to XLA/GPU. +def HLO_BitcastOp : HLO_Op<"bitcast", [NoSideEffect]>, BASE_HLO_BitcastOp { + let arguments = (ins HLO_Tensor:$operand); + let results = (outs HLO_Tensor); + let hasCustomHLOConverter = 1; +} + +def HLO_ReducePrecisionOp: HLO_Op<"reduce_precision", [SameOperandsAndResultShape]>, + BASE_HLO_ReducePrecisionOp { + let arguments = (ins + HLO_FpTensor:$operand, + I32Attr:$exponent_bits, + I32Attr:$mantissa_bits + ); + let results = (outs HLO_FpTensor:$output); +} + #endif // HLO_OPS diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td index 2f80545ad19..da8c921a47b 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td @@ -18,6 +18,13 @@ limitations under the License. include "mlir/IR/OpBase.td" +def HLO_Dialect : Dialect { + let name = "mhlo"; + let cppNamespace = "::mlir::mhlo"; +} + +include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td" + def HLO_Pred : TypeAlias; // TODO(hinsu): Use signed integers instead of signless integer which is being @@ -45,9 +52,7 @@ def HLO_Token : Type()">, "token">; def HLO_IntTensor : TensorOf<[HLO_Int]>; // Any integer tensor type with rank 0 (i.e. representing a single integer). -def HLO_ScalarIntTensor : ShapedContainerType< - [HLO_Int], And<[IsTensorTypePred, HasAnyRankOfPred<[0]>]>, - "a 0-dim integer tensor">; +def HLO_ScalarIntTensor : 0DTensorOf<[HLO_Int]>; // Any floating-point tensor types def HLO_FpTensor : TensorOf<[AnyFloat]>; @@ -67,10 +72,7 @@ def HLO_TensorOrTokenOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Token, HLO_Tuple]>; def HLO_DimensionValue : AnyTypeOf<[Index, HLO_Pred, HLO_Int]>; // Dynamic representation of a shape vector as a tensor. 
-def HLO_DimensionTensor : ShapedContainerType< - [HLO_DimensionValue], - And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>, - "a 1D tensor of dimensions">; +def HLO_DimensionTensor : 1DTensorOf<[HLO_DimensionValue]>; // In general, static shaped tensor constraints should be avoided unless // it is for a legacy op which is only correct with static shapes. @@ -132,6 +134,17 @@ class BASE_HLO_AbsOp { }]; } +class BASE_HLO_CbrtOp { + string summary = "Cubic root operator"; + + string description = [{ + Returns element-wise cubic root of the operand. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} + class BASE_HLO_CeilOp { string summary = "Ceil operator"; @@ -608,15 +621,6 @@ class BASE_HLO_CaseOp { // XLA parallelism related op definitions. //===----------------------------------------------------------------------===// -// Represents a unique identifier for each Send/Recv instruction pair or -// optionally for collective instructions (AllReduce, CollectivePermute, -// AllToAll). Non-positive channel_id handle is equivalent to no channel id. -class ChannelHandle : StructAttr<"ChannelHandle", dialect, [ - StructFieldAttr<"handle", I64Attr>, - StructFieldAttr<"type", I64Attr>]> { - let description = "two 64-bit integers 'handle' and 'type'"; -} - class BASE_HLO_ReplicaIdOp { string summary = "ReplicaId operator"; @@ -706,6 +710,7 @@ def HLO_PrecisionConfigAttr: OptionalAttr< TypedArrayAttrBase>; + //===----------------------------------------------------------------------===// // Fast Fourier Transform Type enum definitions. //===----------------------------------------------------------------------===// @@ -1001,6 +1006,27 @@ class BASE_HLO_ConcatenateOp { }]; } +//===----------------------------------------------------------------------===// +// Common convolution attributes +//===----------------------------------------------------------------------===// + +class ConvolutionAttributes { + dag attributes = (ins + // Default value: one for each of the spatial dimension. + OptionalAttr:$window_strides, + // Default value: zero for each of the spatial dimension. + OptionalAttr:$padding, + // Default value: one for each of the spatial dimension. + OptionalAttr:$lhs_dilation, + // Default value: one for each of the spatial dimension. + OptionalAttr:$rhs_dilation, + ConvDimensionNumbers:$dimension_numbers, + I64Attr:$feature_group_count, + I64Attr:$batch_group_count, + HLO_PrecisionConfigAttr:$precision_config + ); +} + class BASE_HLO_ConvOp { string summary = "Convolution operator"; @@ -1122,15 +1148,6 @@ class BASE_HLO_ReshapeOp { }]; } -class ScatterDimensionNumbers : StructAttr< - "ScatterDimensionNumbers", dialect, [ - StructFieldAttr<"update_window_dims", I64ElementsAttr>, - StructFieldAttr<"inserted_window_dims", I64ElementsAttr>, - StructFieldAttr<"scatter_dims_to_operand_dims", I64ElementsAttr>, - StructFieldAttr<"index_vector_dim", I64Attr>]> { - let description = "Structure of dimension information for scatter"; -} - class BASE_HLO_ScatterOp { string summary = "Scatter operator"; @@ -1341,4 +1358,17 @@ class BASE_HLO_WhileOp { }]; } +class BASE_HLO_BitcastOp { + string summary = "Bitcast operator"; + + string description = [{ + This op changes the shape of the input in the way that the physical + arranggment of elements are unchanged. + + However, the op needs layout information to make sense of "physical + arrangement of elements". Layout support in MHLO is currently under + exploration. 
+ }]; +} + #endif // HLO_OPS_BASE diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h new file mode 100644 index 00000000000..3b78ff8a367 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines structures used in MHLO and LMHLO. + +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_STRUCTS_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_STRUCTS_H_ + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Identifier.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" + +// Order matters, this .inc header is not self-contained, and relies on the +// #includes above. +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_STRUCTS_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td new file mode 100644 index 00000000000..d25eb5104c6 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td @@ -0,0 +1,73 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef HLO_OPS_BASE_STRUCTS +#define HLO_OPS_BASE_STRUCTS + +//===----------------------------------------------------------------------===// +// Dot dimensions enum definitions. 
+//===----------------------------------------------------------------------===// + +def DotDimensionNumbers : StructAttr<"DotDimensionNumbers", HLO_Dialect, [ + StructFieldAttr<"lhs_batching_dimensions", I64ElementsAttr>, + StructFieldAttr<"rhs_batching_dimensions", I64ElementsAttr>, + StructFieldAttr<"lhs_contracting_dimensions", I64ElementsAttr>, + StructFieldAttr<"rhs_contracting_dimensions", I64ElementsAttr> + ]> { + let description = "Structure of dimension information for dot product"; +} + +def ScatterDimensionNumbers : StructAttr< + "ScatterDimensionNumbers", HLO_Dialect, [ + StructFieldAttr<"update_window_dims", I64ElementsAttr>, + StructFieldAttr<"inserted_window_dims", I64ElementsAttr>, + StructFieldAttr<"scatter_dims_to_operand_dims", I64ElementsAttr>, + StructFieldAttr<"index_vector_dim", I64Attr>]> { + let description = "Structure of dimension information for scatter"; +} + +def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", HLO_Dialect, [ + StructFieldAttr<"input_batch_dimension",I64Attr>, + StructFieldAttr<"input_feature_dimension", I64Attr>, + StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>, + StructFieldAttr<"kernel_input_feature_dimension", I64Attr>, + StructFieldAttr<"kernel_output_feature_dimension", I64Attr>, + StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>, + StructFieldAttr<"output_batch_dimension", I64Attr>, + StructFieldAttr<"output_feature_dimension", I64Attr>, + StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > { + + let description = "Structure of dimension information for conv op"; +} + +def GatherDimensionNumbers : StructAttr<"GatherDimensionNumbers", HLO_Dialect, + [StructFieldAttr<"offset_dims", I64ElementsAttr>, + StructFieldAttr<"collapsed_slice_dims", I64ElementsAttr>, + StructFieldAttr<"start_index_map", I64ElementsAttr>, + StructFieldAttr<"index_vector_dim", I64Attr>]> { + let description = "Structure of dimension information for gather"; +} + + +// Represents a unique identifier for each Send/Recv instruction pair or +// optionally for collective instructions (AllReduce, CollectivePermute, +// AllToAll). Non-positive channel_id handle is equivalent to no channel id. +def ChannelHandle : StructAttr<"ChannelHandle", HLO_Dialect, [ + StructFieldAttr<"handle", I64Attr>, + StructFieldAttr<"type", I64Attr>]> { + let description = "two 64-bit integers 'handle' and 'type'"; +} + +#endif // HLO_OPS_BASE_STRUCTS diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td index c201aeff8ec..32940cbc623 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td @@ -28,7 +28,7 @@ class ConstantSplat : NativeCodeCall< "hlo::getSplat(&$_builder, $0, " # value # ")">; class HLO_ConstantLike : NativeCodeCall< - "chlo::getConstantLike($_builder, " # value # ", $0)">; + "chlo::getConstantLike($_builder, $_loc, " # value # ", $0)">; def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">; diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h new file mode 100644 index 00000000000..effa9ecc83b --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h @@ -0,0 +1,59 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the operations used in the LHLO dialect. + +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_H_ + +#include "llvm/ADT/StringRef.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h" +#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" +#include "mlir/Interfaces/CopyOpInterface.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/ViewLikeInterface.h" + +namespace mlir { +class OpBuilder; +} // namespace mlir + + +namespace mlir { +namespace lmhlo_gpu { + +class LmhloGpuDialect : public Dialect { + public: + explicit LmhloGpuDialect(MLIRContext *context); + static StringRef getDialectNamespace() { return "lmhlo_gpu"; } +}; + +} // namespace lmhlo_gpu +} // end namespace mlir + +#define GET_OP_CLASSES +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td new file mode 100644 index 00000000000..b3708bf4ff1 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td @@ -0,0 +1,210 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is the operation definition file for LHMLO level GPU operations. +// Because these are LMHLO level operations, they operate on memrefs. 
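// Illustrative only: the ops defined below declare no custom assembly format,
// so they print in the generic op form. A gemm over pre-allocated buffers
// might look roughly like this (buffer shapes and attribute values are
// hypothetical):
//   "lmhlo_gpu.gemm"(%lhs, %rhs, %output) {
//     dot_dimension_numbers = {
//       lhs_batching_dimensions = dense<> : tensor<0xi64>,
//       rhs_batching_dimensions = dense<> : tensor<0xi64>,
//       lhs_contracting_dimensions = dense<[1]> : tensor<1xi64>,
//       rhs_contracting_dimensions = dense<[0]> : tensor<1xi64>},
//     alpha = 1.0 : f64, batch_size = 1 : i64, algorithm = 0 : i64
//   } : (memref<16x16xf32>, memref<16x16xf32>, memref<16x16xf32>) -> ()
// The op writes into %output in place instead of returning a tensor result.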
+ +#ifndef LHLO_GPU_OPS +#define LHLO_GPU_OPS + +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td" +include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td" +include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td" + + +class LHLOGPU_Op traits = []> : + Op], traits)>; + +// Type for scratch buffers used by GPU library calls (memref) +def UntypedBuffer : MemRefRankOf<[I8], [1]>; + +// Cholesky info output buffer type. +def I32Buffer : MemRefOf<[I32]>; + +//===----------------------------------------------------------------------===// +// LMHLO ops representing batch norm library functions. +//===----------------------------------------------------------------------===// + +// Note: these are semantically different from similar LHLO as the GPU library +// calls generate or consume standard deviation, whereas LHLO ops generate or +// consume variance (= std-dev ^ 2). + +def LHLOGPU_BatchNormGradOp : LHLOGPU_Op<"batch_norm_grad">, + BASE_HLO_BatchNormGradOp { + let arguments = (ins + Arg:$operand, + Arg:$scale, + Arg:$mean, + Arg:$stddev, + Arg:$grad_output, + Arg:$grad_operand, // gradient of $operand. + Arg:$grad_scale, + Arg:$grad_offset, + F32Attr:$epsilon, + I64Attr:$feature_index + ); +} + +def LHLOGPU_BatchNormInferenceOp : LHLOGPU_Op<"batch_norm_inference">, + BASE_HLO_BatchNormInferenceOp { + let arguments = (ins + Arg:$operand, + Arg:$scale, + Arg:$offset, + Arg:$mean, + Arg:$stddev, + Arg:$output, + F32Attr:$epsilon, + I64Attr:$feature_index); +} + +def LHLOGPU_BatchNormTrainingOp : LHLOGPU_Op<"batch_norm_training">, + BASE_HLO_BatchNormTrainingOp { + + let arguments = (ins + Arg:$operand, + Arg:$scale, + Arg:$offset, + Arg:$output, + Arg:$batch_mean, + Arg:$batch_stddev, + F32Attr:$epsilon, + I64Attr:$feature_index + ); +} + +//===----------------------------------------------------------------------===// +// LMHLO ops representing convolution library functions. 
+//===----------------------------------------------------------------------===// + +def ActivationModeNone : StrEnumAttrCase<"None">; +def ActivationModeSigmoid : StrEnumAttrCase<"Sigmoid">; +def ActivationModeTanh : StrEnumAttrCase<"Relu">; +def ActivationModeRelu : StrEnumAttrCase<"Relu">; +def ActivationModeRelu6 : StrEnumAttrCase<"Relu6">; +def ActivationModeReluX : StrEnumAttrCase<"ReluX">; +def ActivationModeBandPass : StrEnumAttrCase<"BandPass">; + +def ActivationAttr : StrEnumAttr<"Activation", + "Activation applied with fused convolution", + [ActivationModeNone, ActivationModeSigmoid, ActivationModeTanh, + ActivationModeRelu, ActivationModeRelu6, ActivationModeReluX, + ActivationModeBandPass]>; + +def GpuConvolutionAttributes { + dag attributes = !con( + ConvolutionAttributes.attributes, + (ins F64Attr:$result_scale), + (ins ConvolutionBackendConfigAttr:$backend_config)); +} + +def GpuFusedConvolutionAttributes { + dag attributes = !con( + ConvolutionAttributes.attributes, + (ins F64Attr:$result_scale, + ActivationAttr:$activation_mode, + F64Attr:$side_input_scale), + (ins ConvolutionBackendConfigAttr:$backend_config)); +} + +def LHLOGPU_ConvForwardOp : LHLOGPU_Op<"conv_forward"> { + let arguments = !con( + (ins + Arg:$input, + Arg:$filter, + Arg:$output, + Arg:$scratch), + GpuConvolutionAttributes.attributes); +} + +def LHLOGPU_ConvBackwardInputOp : LHLOGPU_Op<"conv_backwardinput"> { + let arguments = !con( + (ins + Arg:$d_output, + Arg:$filter, + Arg:$d_input, + Arg:$scratch), + GpuConvolutionAttributes.attributes); +} + +def LHLOGPU_ConvBackwardFilterOp : LHLOGPU_Op<"conv_backwardfilter"> { + let arguments = !con( + (ins + Arg:$input, + Arg:$d_output, + Arg:$d_filter, + Arg:$scratch), + GpuConvolutionAttributes.attributes); +} + +// output = activation(result_scale * conv(input, filter) + +// side_input * side_input_scale + +// bias) +def LHLOGPU_ConvForwardFusedOp : LHLOGPU_Op<"conv_forward_fused"> { + let arguments = !con( + (ins + Arg:$input, + Arg:$filter, + Arg:$bias, + Arg:$side_input, + Arg:$output, + Arg:$scratch), + GpuFusedConvolutionAttributes.attributes); +} + +//===----------------------------------------------------------------------===// +// LMHLO ops representing other library functions. +//===----------------------------------------------------------------------===// + +// output = alpha * (lhs * rhs) +// Verify: beta = 0.0 +def LHLOGPU_GEMMOp : LHLOGPU_Op<"gemm"> { + let arguments = (ins + Arg:$lhs, + Arg:$rhs, + Arg:$output, + DotDimensionNumbers:$dot_dimension_numbers, + F64Attr:$alpha, + I64Attr:$batch_size, + I64Attr:$algorithm); +} + +// output = alpha(lhs * rhs) + beta * bias +def LHLOGPU_GEMM_BiasOp : LHLOGPU_Op<"gemm_bias"> { + let arguments = (ins + Arg:$lhs, + Arg:$rhs, + Arg:$bias, + Arg:$output, + DotDimensionNumbers:$dot_dimension_numbers, + F64Attr:$alpha, + F64Attr:$beta, + I64Attr:$batch_size, + I64Attr:$algorithm); +} + +def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> { + let arguments = (ins + Arg:$input, + Arg:$output, + Arg:$scratch, + Arg:$info, + BoolAttr:$is_upper); +} + +#endif // LHLO_GPU_OPS diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td new file mode 100644 index 00000000000..820e4ce64b0 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td @@ -0,0 +1,28 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// We define the dialect here so that both structs and ops can refer to it. + +#ifndef LHLO_GPU_OPS_BASE +#define LHLO_GPU_OPS_BASE + +include "mlir/IR/OpBase.td" + +def LHLO_GPU_Dialect : Dialect { + let name = "lmhlo_gpu"; + let cppNamespace = "::mlir::lmhlo_gpu"; +} + +#endif // LHLO_GPU_OPS_BASE diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h new file mode 100644 index 00000000000..ff642b82c22 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ==============================================================================*/ + +// This file defines structures used in the LMHLO_GPU dialect. + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_STRUCTS_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_STRUCTS_H_ + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Identifier.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" + +// Order matters, this .inc header is not self-contained, and relies on the +// #includes above. +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h.inc" + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_STRUCTS_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td new file mode 100644 index 00000000000..2236fc38e29 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td @@ -0,0 +1,29 @@ + +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef LHLO_GPU_OPS_STRUCTS +#define LHLO_GPU_OPS_STRUCTS + +include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td" + +def ConvolutionBackendConfigAttr : StructAttr<"ConvolutionBackendConfig", + LHLO_GPU_Dialect, [ + StructFieldAttr<"algorithm", I64Attr>, + StructFieldAttr<"tensor_ops_enabled", BoolAttr>]> { + let description = "GPU Convolution backend configuration"; +} + +#endif // LHLO_GPU_OPS_STRUCTS diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h index bb9b29096f3..9dc6d7aa0c0 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h @@ -13,12 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This file defines the operations used in the LXLA dialect. +// This file defines the operations used in the LHLO dialect. #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_ #include "llvm/ADT/StringRef.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Location.h" @@ -27,14 +28,12 @@ limitations under the License. #include "mlir/IR/Operation.h" #include "mlir/IR/StandardTypes.h" #include "mlir/IR/Types.h" +#include "mlir/Interfaces/CopyOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" namespace mlir { class OpBuilder; - -#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h.inc" - namespace lmhlo { class LmhloDialect : public Dialect { @@ -43,10 +42,10 @@ class LmhloDialect : public Dialect { static StringRef getDialectNamespace() { return "lmhlo"; } }; -#define GET_OP_CLASSES -#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc" - } // namespace lmhlo } // end namespace mlir +#define GET_OP_CLASSES +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc" + #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index 750cce65b62..25d5e50af7d 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -16,58 +16,34 @@ limitations under the License. // This is the operation definition file for LMHLO, the "late" MHLO variant of // the dialect, which operates on buffers instead of tensors. // -// This file largely overlaps with mhlo_ops.td at a logic level. It's tempting to -// merge these two files together, but we need to consider the following +// This file largely overlaps with hlo_ops.td at a logical level. It's tempting +// to merge these two files together, but we need to consider the following // obstacles: // * We need to have a common representation for arguments. That is to say, -// HLO_Array translates to HLO_Tensor in HLO dialect, and -// Arg, "", [Mem(Read|Write)]> in LHLO. 
Array types within tuples -// also need to be transformed. +// HLO_Array translates to HLO_Tensor in HLO dialect, and +// Arg, "", [Mem(Read|Write)]> in LHLO. Array types within +// tuples also need to be transformed. // * As of now, TableGen's dag functions are not sufficient to accomplish the -// one above. -// * Traits aren't identical, but need to be coped. For example, -// SameOperandAndResultType in HLO corresponds to SameTypeOperands in LHLO. +// one above. +// * Traits aren't identical, but need to be copied. For example, +// SameOperandAndResultType in HLO corresponds to SameTypeOperands in LHLO. // * Also, currently HLO describes the API in XLA's client side, not service -// side. LHLO aims for the service side. +// side. LHLO aims for the service side. #ifndef LHLO_OPS #define LHLO_OPS include "mlir/IR/OpBase.td" +include "mlir/Interfaces/CopyOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" -include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" +include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td" def LHLO_Dialect : Dialect { let name = "lmhlo"; - let cppNamespace = "lmhlo"; + let cppNamespace = "::mlir::lmhlo"; } -//===----------------------------------------------------------------------===// -// LMHLO type definitions. -//===----------------------------------------------------------------------===// - -// Any integer tensor types -def LHLO_IntBuffer : MemRefOf<[HLO_Int]>; - -// Any floating-point tensor types -def LHLO_FpBuffer : MemRefOf<[AnyFloat]>; - -def LHLO_ComplexBuffer : MemRefOf<[AnyComplex]>; - -def LHLO_FpOrComplexBuffer : MemRefOf<[AnyFloat, AnyComplex]>; - -def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>; - -// Any integer or floating-point tensor types -def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>; - -def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>; - -def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>; - -def LHLO_ExtentBuffer : MemRefRankOf<[AnySignlessInteger, Index], [1]>; - //===----------------------------------------------------------------------===// // LMHLO nullary op definitions. //===----------------------------------------------------------------------===// @@ -288,6 +264,16 @@ def LHLO_WhileOp: LHLO_Op<"while", [SameVariadicOperandSize]>, let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body); } +def LHLO_CustomCallOp : LHLO_Op<"custom_call", []>, BASE_HLO_CustomCallOp { + let arguments = (ins + Arg, "", [MemRead]>:$args, + Arg:$output, + StrAttr:$call_target_name, + DefaultValuedAttr:$has_side_effect, + DefaultValuedAttr:$backend_config + ); +} + //===----------------------------------------------------------------------===// // LMHLO tuple op definitions. //===----------------------------------------------------------------------===// @@ -334,10 +320,11 @@ def HLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []> { def HLO_StaticMemRefCastOp: Op]> { let summary = [{ - "modifies the offset, sizes and strides of a statically shaped memref. + modifies the offset, sizes and strides of a statically shaped memref }]; let description = [{ - Allows to modify the offset, sizes and strides of a statically shaped memref. + Casts the statically shaped memref operand to a memref with optionally + modified offsets, sizes and strides. 
Example: ```mlir @@ -353,12 +340,11 @@ def HLO_StaticMemRefCastOp: Op:$operand); let results = (outs Res:$result); - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &result, MemRefType resultType, " # - "Value operand", [{ - result.addOperands(operand); - result.types.push_back(resultType); - }]>]; + let builders = [OpBuilder<"MemRefType resultType, Value operand", + [{ + $_state.addOperands(operand); + $_state.types.push_back(resultType); + }]>]; let extraClassDeclaration = [{ MemRefType getType() { return getResult().getType().cast(); } @@ -399,13 +385,13 @@ def HLO_DynamicMemRefCastOp: Op:$result); - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &result, MemRefType resultType, " # - "Value operand, ValueRange sizes, ValueRange strides", [{ - result.addOperands(operand); - result.addOperands(sizes); - result.addOperands(strides); - result.types.push_back(resultType); + let builders = [ + OpBuilder<"MemRefType resultType, Value operand, ValueRange sizes, " + "ValueRange strides", [{ + $_state.addOperands(operand); + $_state.addOperands(sizes); + $_state.addOperands(strides); + $_state.types.push_back(resultType); }]>]; let extraClassDeclaration = [{ @@ -476,7 +462,8 @@ def ReshapeMemRefCastOp: Op(); } + BaseMemRefType getType() { + return getResult().getType().cast(); } }]; let verifier = [{ return Verify(*this); }]; @@ -580,53 +567,32 @@ def LHLO_ConcatenateOp : LHLO_Op<"concatenate", []>, BASE_HLO_ConcatenateOp { ); } -// TODO(bondhugula): Make this struct dialect independent so that it can be -// shared between the HLO and LHLO dialects. -def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", LHLO_Dialect, [ - StructFieldAttr<"input_batch_dimension",I64Attr>, - StructFieldAttr<"input_feature_dimension", I64Attr>, - StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>, - StructFieldAttr<"kernel_input_feature_dimension", I64Attr>, - StructFieldAttr<"kernel_output_feature_dimension", I64Attr>, - StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>, - StructFieldAttr<"output_batch_dimension", I64Attr>, - StructFieldAttr<"output_feature_dimension", I64Attr>, - StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > { - - let description = "Structure of dimension information for conv op"; -} - def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp { - let arguments = (ins - Arg:$lhs, - Arg:$rhs, - Arg:$output, - // Default value: one for each of the spatial dimension. - OptionalAttr:$window_strides, - // Default value: zero for each of the spatial dimension. - OptionalAttr:$padding, - // Default value: one for each of the spatial dimension. - OptionalAttr:$lhs_dilation, - // Default value: one for each of the spatial dimension. 
- OptionalAttr:$rhs_dilation, - ConvDimensionNumbers:$dimension_numbers, - I64Attr:$feature_group_count, - I64Attr:$batch_group_count, - HLO_PrecisionConfigAttr:$precision_config - ); + let arguments = !con( + (ins + Arg:$lhs, + Arg:$rhs, + Arg:$output), + ConvolutionAttributes.attributes); } -def LHLO_CopyOp: LHLO_Op<"copy", []>, BASE_HLO_CopyOp { +def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]>, BASE_HLO_CopyOp { let arguments = (ins Arg:$operand, Arg:$output ); + + let extraClassDeclaration = [{ + Value getSource() { return operand();} + Value getTarget() { return output(); } + }]; } def LHLO_DotOp: LHLO_Op<"dot", []>, BASE_HLO_DotOp { let arguments = (ins Arg:$lhs, Arg:$rhs, + DotDimensionNumbers:$dot_dimension_numbers, HLO_PrecisionConfigAttr:$precision_config, Arg:$output ); @@ -658,7 +624,7 @@ def LHLO_ScatterOp: LHLO_Op<"scatter", []>, BASE_HLO_ScatterOp { Arg:$scatter_indices, Arg:$updates, Arg:$output, - ScatterDimensionNumbers:$scatter_dimension_numbers, + ScatterDimensionNumbers:$scatter_dimension_numbers, DefaultValuedAttr:$indices_are_sorted, DefaultValuedAttr:$unique_indices ); @@ -734,7 +700,7 @@ def LHLO_AllReduceOp : LHLO_Op<"all_reduce", [SameTypeOperands]>, Arg:$output, I64ElementsAttr:$replica_groups, DefaultValuedAttr:$constrain_layout, - OptionalAttr>:$channel_id, + OptionalAttr:$channel_id, DefaultValuedAttr:$use_global_device_ids ); let regions = (region SizedRegion<1>:$computation); @@ -747,7 +713,7 @@ def LHLO_CollectivePermuteOp: LHLO_Op<"collective_permute", [SameTypeOperands]>, Arg:$operand, Arg:$output, I64ElementsAttr:$source_target_pairs, - OptionalAttr>:$channel_id + OptionalAttr:$channel_id ); } @@ -849,9 +815,8 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">] let skipDefaultBuilders = 1; let builders = [ - OpBuilder<"OpBuilder &builder, OperationState &result, " - "ArrayRef attributes"> - ]; + OpBuilder<"ArrayRef attributes"> + ]; } def TerminatorOp : @@ -860,9 +825,8 @@ def TerminatorOp : let description = [{ Terminator operation for the LHLO dialect. }]; - let builders = [OpBuilder< - "OpBuilder &b, OperationState &result, ValueRange operands", - [{ build(b, result, llvm::None, operands, llvm::None); }] + let builders = [OpBuilder<"ValueRange operands", + [{ build($_builder, $_state, llvm::None, operands, llvm::None); }] >]; } diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td new file mode 100644 index 00000000000..9cd77417ffd --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td @@ -0,0 +1,47 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef LHLO_OPS_BASE +#define LHLO_OPS_BASE + +include "mlir/IR/OpBase.td" +include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" + +//===----------------------------------------------------------------------===// +// LMHLO type definitions. +//===----------------------------------------------------------------------===// + +// Any integer tensor types +def LHLO_IntBuffer : MemRefOf<[HLO_Int]>; + +// Any floating-point tensor types +def LHLO_FpBuffer : MemRefOf<[AnyFloat]>; + +def LHLO_ComplexBuffer : MemRefOf<[AnyComplex]>; + +def LHLO_FpOrComplexBuffer : MemRefOf<[AnyFloat, AnyComplex]>; + +def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>; + +// Any integer or floating-point tensor types +def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>; + +def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>; + +def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>; + +def LHLO_ExtentBuffer : MemRefRankOf<[AnySignlessInteger, Index], [1]>; + +#endif // LHLO_OPS_BASE diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h index 90ff6c99751..cb0af3a159d 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h @@ -20,8 +20,6 @@ namespace mlir { class DialectRegistry; namespace mhlo { -void registerAllDialects(); - // Add chlo, mhlo, lmhlo dialects to the provided registry. void registerAllMhloDialects(DialectRegistry ®istry); } diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/PassDetail.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/PassDetail.h new file mode 100644 index 00000000000..5f18eeb6ecc --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/PassDetail.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSDETAIL_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSDETAIL_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace mhlo { + +#define GEN_PASS_CLASSES +#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" + +} // end namespace mhlo +} // end namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSDETAIL_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td index 963ff5dbacf..39b4ca65043 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td @@ -15,12 +15,6 @@ limitations under the License. include "mlir/Pass/PassBase.td" -def LhloCopyRemovalPass : Pass<"lhlo-copy-removal", "FuncOp"> { - let summary = "Removes redundant LHLO copy operations."; - let constructor = "createLhloCopyRemovalPass()"; -} - - def LhloLegalizeToLinalgPass : Pass<"lhlo-legalize-to-linalg", "FuncOp"> { let summary = "Legalize from LHLO dialect to Linalg dialect."; let constructor = "createLegalizeLhloToLinalgPass()"; diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h index c51bcfcfe89..ac67619e6e3 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h @@ -40,6 +40,7 @@ using HloToLhloOp = typename HloToLhloOpImpl::Type; MAP_HLO_TO_LHLO(AbsOp); MAP_HLO_TO_LHLO(AddOp); MAP_HLO_TO_LHLO(AndOp); +MAP_HLO_TO_LHLO(Atan2Op); MAP_HLO_TO_LHLO(BroadcastInDimOp); MAP_HLO_TO_LHLO(CeilOp); MAP_HLO_TO_LHLO(ConstOp); @@ -49,17 +50,21 @@ MAP_HLO_TO_LHLO(ConvOp); MAP_HLO_TO_LHLO(ConvertOp); MAP_HLO_TO_LHLO(CopyOp); MAP_HLO_TO_LHLO(CosOp); +MAP_HLO_TO_LHLO(CustomCallOp); MAP_HLO_TO_LHLO(DivOp); MAP_HLO_TO_LHLO(DotOp); MAP_HLO_TO_LHLO(ExpOp); +MAP_HLO_TO_LHLO(FloorOp); MAP_HLO_TO_LHLO(GatherOp); MAP_HLO_TO_LHLO(ImagOp); MAP_HLO_TO_LHLO(IotaOp); +MAP_HLO_TO_LHLO(IsFiniteOp); MAP_HLO_TO_LHLO(LogOp); MAP_HLO_TO_LHLO(MaxOp); MAP_HLO_TO_LHLO(MinOp); MAP_HLO_TO_LHLO(MulOp); MAP_HLO_TO_LHLO(NegOp); +MAP_HLO_TO_LHLO(NotOp); MAP_HLO_TO_LHLO(RealOp); MAP_HLO_TO_LHLO(ReduceOp); MAP_HLO_TO_LHLO(ReshapeOp); @@ -68,9 +73,11 @@ MAP_HLO_TO_LHLO(RsqrtOp); MAP_HLO_TO_LHLO(SelectOp); MAP_HLO_TO_LHLO(SignOp); MAP_HLO_TO_LHLO(SinOp); +MAP_HLO_TO_LHLO(SliceOp); MAP_HLO_TO_LHLO(SqrtOp); MAP_HLO_TO_LHLO(SubOp); MAP_HLO_TO_LHLO(TanhOp); +MAP_HLO_TO_LHLO(TransposeOp); #undef MAP_HLO_TO_LHLO diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index 2bb5ab2888d..d59dfd43d1b 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -22,6 +22,7 @@ limitations under the License. 
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/TypeUtilities.h" namespace mlir { namespace lmhlo { @@ -96,7 +97,7 @@ template struct MapLhloOpToStdScalarOpImpl { Value operator()(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - Type element_type = args.front().getType(); + Type element_type = getElementTypeOrSelf(args.front().getType()); if (element_type.isa()) { return b->template create(loc, result_types, args, mlir::None); @@ -120,7 +121,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - Type element_type = args.front().getType(); + Type element_type = getElementTypeOrSelf(args.front().getType()); if (element_type.isa()) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); @@ -130,8 +131,11 @@ inline Value MapLhloOpToStdScalarOp(Location loc, Value lhs = args[0]; auto integer_type = element_type.dyn_cast(); - auto zero_intval = + Value zero_intval = b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); + if (VectorType vec_type = args.front().getType().dyn_cast()) { + zero_intval = b->create<::mlir::SplatOp>(loc, vec_type, zero_intval); + } auto lhs_gt_zero = b->create>(loc, CmpIPredicate::sge, lhs, zero_intval); auto neg_val = b->create>(loc, zero_intval, lhs); @@ -149,6 +153,15 @@ inline Value MapLhloOpToStdScalarOp(Location loc, loc, result_types, args, b); } +template <> +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); +} + template inline Optional getCmpPredicate(StringRef comparison_direction) { return llvm::None; @@ -187,7 +200,7 @@ inline Value MapCompareOpToStdScalarOp(Location loc, ArrayRef args, OpBuilder* b) { const auto& lhs = args[0]; const auto& rhs = args[1]; - Type element_type = lhs.getType(); + Type element_type = getElementTypeOrSelf(lhs.getType()); if (element_type.isSignlessInteger()) { Optional predicate = getCmpPredicate(comparison_direction); @@ -259,8 +272,8 @@ template <> inline Value MapLhloOpToStdScalarOp( Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - Type sourceType = args.front().getType(); - Type targetType = result_types.front(); + Type sourceType = getElementTypeOrSelf(args.front().getType()); + Type targetType = getElementTypeOrSelf(result_types.front()); if (mlir::SIToFPOp::areCastCompatible(sourceType, targetType)) { return b->create(loc, result_types, args, mlir::None); @@ -336,6 +349,31 @@ inline Value MapLhloOpToStdScalarOp(Location loc, loc, result_types, args, b); } +template <> +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + if (args[0].getType().isa()) { + auto pos_inf = APFloat::getInf( + args[0].getType().cast().getFloatSemantics()); + auto const_pos_inf = + b->create(loc, b->getFloatAttr(args[0].getType(), pos_inf)); + Value abs_x = b->create<::mlir::AbsFOp>(loc, args[0]); + return b->create<::mlir::CmpFOp>(loc, CmpFPredicate::ONE, abs_x, + const_pos_inf); + } + return nullptr; +} + /// Implements the conversion of HLO op to scalar op (to use within region of a /// linalg.generic op) for 
compare-select style operations like min/max. template @@ -356,7 +394,7 @@ struct CompareSelectOpToStdScalarOp result_types, ArrayRef args, OpBuilder* b) { - Type element_type = args.front().getType(); + Type element_type = getElementTypeOrSelf(args.front().getType()); if (element_type.isa()) { auto predicate = getCmpPredicate(comparison_direction); assert(predicate.hasValue() && "expected valid comparison direction"); @@ -405,7 +443,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - Type element_type = args.front().getType(); + Type element_type = getElementTypeOrSelf(args.front().getType()); if (element_type.isa()) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); @@ -415,13 +453,34 @@ inline Value MapLhloOpToStdScalarOp(Location loc, Value lhs = args[0]; auto integer_type = element_type.dyn_cast(); - auto zero_intval = + Value zero_intval = b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); + if (VectorType vec_type = args.front().getType().dyn_cast()) { + zero_intval = b->create<::mlir::SplatOp>(loc, vec_type, zero_intval); + } return b->create>(loc, zero_intval, lhs); } return nullptr; } +template <> +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { + Type element_type = getElementTypeOrSelf(args.front().getType()); + if (auto integer_type = element_type.dyn_cast()) { + // lmhlo.not(x) -> x ^ -1 + Value all_ones = + b->create<::mlir::ConstantIntOp>(loc, -1, integer_type.getWidth()); + if (VectorType vec_type = args.front().getType().dyn_cast()) { + all_ones = b->create<::mlir::SplatOp>(loc, vec_type, all_ones); + } + return b->create<::mlir::XOrOp>(loc, all_ones, args[0]); + } + return nullptr; +} + template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, @@ -444,12 +503,37 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - Type element_type = args.front().getType(); - if (element_type.isa()) { - FloatType float_type = element_type.cast(); - APFloat const_value = float_type.isF32() ? APFloat(1.0f) : APFloat(1.0); - Value one = b->create(loc, const_value, float_type); + Type element_type = getElementTypeOrSelf(args.front().getType()); + if (auto float_type = element_type.dyn_cast()) { + bool ignored; + APFloat one_apfloat(1.0f); + one_apfloat.convert(float_type.getFloatSemantics(), + APFloat::rmNearestTiesToEven, &ignored); + Value one = b->create(loc, one_apfloat, float_type); + if (VectorType vec_type = args.front().getType().dyn_cast()) { + one = b->create<::mlir::SplatOp>(loc, vec_type, one); + } return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]); + } else if (auto integer_type = element_type.dyn_cast()) { + // sign(x) = x == 0 ? 
0 : ((x s>> 31) | 1) + Value zero = + b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); + Value bitwidth_minus_one = b->create<::mlir::ConstantIntOp>( + loc, integer_type.getWidth() - 1, integer_type.getWidth()); + Value one = + b->create<::mlir::ConstantIntOp>(loc, 1, integer_type.getWidth()); + if (VectorType vec_type = args.front().getType().dyn_cast()) { + zero = b->create<::mlir::SplatOp>(loc, vec_type, zero); + bitwidth_minus_one = + b->create<::mlir::SplatOp>(loc, vec_type, bitwidth_minus_one); + one = b->create<::mlir::SplatOp>(loc, vec_type, one); + } + Value cmp = + b->create<::mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0], zero); + Value ashr = + b->create<::mlir::SignedShiftRightOp>(loc, args[0], bitwidth_minus_one); + Value or_op = b->create<::mlir::OrOp>(loc, ashr, one); + return b->create<::mlir::SelectOp>(loc, cmp, zero, or_op); } return nullptr; } @@ -518,6 +602,27 @@ struct HloOpToStdScalarOp { return impl::MapCompareOpToStdScalarOp( op.getLoc(), comparison_direction, result_types, args, b); } + + // Implementation for LHLO ops except lmhlo::CompareOp. + template ::value && + std::is_same, + std::false_type>::value>> + static Value map(Location loc, ArrayRef result_types, + ArrayRef args, OpBuilder* b, unsigned i = 0) { + return impl::MapLhloOpToStdScalarOp(loc, result_types, args, b); + } + + // Implementation for lmhlo::CompareOp. + template ::value>> + static Value map(Location loc, StringRef comparison_direction, + ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return impl::MapCompareOpToStdScalarOp( + loc, comparison_direction, result_types, args, b); + } }; } // namespace lmhlo diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td index fa3bde24df1..4348464fa74 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td @@ -15,9 +15,9 @@ limitations under the License. 
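Note on the integer sign lowering above: the identity `sign(x) = x == 0 ? 0 : ((x s>> 31) | 1)` can be sanity-checked with a small standalone C++ sketch (illustrative only, not part of the patch; assumes 32-bit two's complement and an arithmetic right shift for signed values, which mainstream compilers provide and C++20 guarantees):

#include <cassert>
#include <cstdint>

// Scalar model of the lowering: the arithmetic shift yields -1 for negative
// inputs and 0 otherwise, so OR-ing with 1 produces -1 / +1, and the compare
// handles the zero case.
int32_t SignOfInt32(int32_t x) {
  return x == 0 ? 0 : ((x >> 31) | 1);
}

int main() {
  assert(SignOfInt32(0) == 0);
  assert(SignOfInt32(42) == 1);
  assert(SignOfInt32(-7) == -1);
  return 0;
}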
include "mlir/Pass/PassBase.td" -def TestChloLegalizeToHloPass : Pass<"mhlo-test-chlo-legalize-to-hlo", "FuncOp"> { - let summary = "Test pass for applying chlo -> hlo legalization patterns."; - let constructor = "createTestChloLegalizeToHloPass()"; +def ChloLegalizeToHloPass : Pass<"chlo-legalize-to-hlo", "FuncOp"> { + let summary = "Legalize CHLO to HLO."; + let constructor = "createChloLegalizeToHloPass()"; } def HloLegalizeToLhloPass : Pass<"hlo-legalize-to-lhlo", "ModuleOp"> { @@ -30,15 +30,20 @@ def LegalizeControlFlowPass : Pass<"mhlo-legalize-control-flow", "FuncOp"> { let constructor = "createLegalizeControlFlowPass()"; } +def LegalizeControlFlowToScfPass : Pass<"mhlo-control-flow-to-scf", "FuncOp"> { + let summary = "Legalize from MHLO control flow to SCF control flow."; + let constructor = "createControlFlowToScfPass()"; +} + def LegalizeGatherToTorchIndexSelectPass : Pass<"mhlo-legalize-gather-to-torch-index-select", "FuncOp"> { let summary = "Legalizes gathers to a torch index select."; let constructor = "createLegalizeGatherToTorchIndexSelectPass()"; } -def LegalizeTanhToApproximationPass : Pass<"mhlo-legalize-tanh-to-approximation", "FuncOp"> { - let summary = "Legalize tanh from standard dialect to an approximation."; - let constructor = "createLegalizeTanhToApproximationPass()"; +def LegalizeTanhToApproximationPass : Pass<"mhlo-legalize-trigonometric-to-approximation", "FuncOp"> { + let summary = "Legalize trigonometric operations from standard dialect to an approximation."; + let constructor = "createLegalizeTrigonometricToApproximationPass()"; } @@ -83,7 +88,7 @@ def OptimizeMhloPass : Pass<"mhlo-test-optimize", "FuncOp"> { } -def SinkConstantsToControlFlowPass : Pass<"mhlo-sink-constants-to-control-flow", "FuncOp"> { +def SinkConstantsToControlFlowPass : FunctionPass<"mhlo-sink-constants-to-control-flow"> { let summary = "Sink constants implicitly captured in control flow regions. This " "is necessary to export to XLA."; let constructor = "createSinkConstantsToControlFlowPass()"; diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h index efa116f3f0d..b1933f6686b 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h @@ -30,14 +30,23 @@ template class OperationPass; class Pass; +// Transforms unranked HLO operations to ranked ones where possible. +std::unique_ptr createTransformUnrankedHloPass(); + namespace mhlo { /// Lowers HLO control flow ops to the Standard dialect. std::unique_ptr> createLegalizeControlFlowPass(); +/// Lowers MHLO control flow ops to the SCF dialect. +std::unique_ptr> createControlFlowToScfPass(); + /// Lowers from HLO dialect to Standard dialect. std::unique_ptr> createLegalizeToStdPass(); +/// Lowers from the CHLO dialect to the HLO dialect. +std::unique_ptr createChloLegalizeToHloPass(); + /// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary /// buffers if necessary. If `results_escape_functions` is set to true, /// allocated buffers for function results will be returned and escape the @@ -49,9 +58,6 @@ std::unique_ptr> createLegalizeToLhloPass( // Lowers from HLO dialect to Linalg dialect. std::unique_ptr> createLegalizeHloToLinalgPass(); -// Transforms unranked HLO operations to ranked ones where possible. 
-std::unique_ptr> createTransformUnrankedHloPass(); - // Sinks constants implicitly captured in control flow regions. This is // necessary to export to XLA. std::unique_ptr> createSinkConstantsToControlFlowPass(); @@ -59,8 +65,10 @@ std::unique_ptr> createSinkConstantsToControlFlowPass(); // fuse mhlo ops to kLoop/kInput fusion patterns std::unique_ptr> createMhloFusionPass(); -/// Lowers the standard TanhOp to an approximation that does not use intrinsics. -std::unique_ptr> createLegalizeTanhToApproximationPass(); +/// Lowers trigonometric operations from the standard dialect to approximations +/// that do not use intrinsics. +std::unique_ptr> +createLegalizeTrigonometricToApproximationPass(); std::unique_ptr createOptimizeMhloPass(); std::unique_ptr createLowerComplexPass(); @@ -92,12 +100,6 @@ std::unique_ptr createLegalizeToGpuPass(); std::unique_ptr createLhloFuseLinalgPass( bool use_parallel_loops = false, llvm::ArrayRef tile_sizes = {}); -// Removes unnecessary LHLO copies which copy from the allocated buffers to the -// block arguments. The block arguments are used instead of all uses of these -// buffers. The buffers are freed. This pass only works in regions that contain -// a single block. -std::unique_ptr createLhloCopyRemovalPass(); - // Lowers from LHLO dialect to parallel loops. std::unique_ptr> createLegalizeLhloToParallelLoopsPass(); diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/register_passes.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/register_passes.h index 8f70f64359b..e9418f0e7a0 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/register_passes.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/register_passes.h @@ -22,7 +22,6 @@ limitations under the License. namespace mlir { namespace mhlo { -std::unique_ptr createTestChloLegalizeToHloPass(); std::unique_ptr createTestInferShapedTypeMethodsPass(); std::unique_ptr createTestMaterializeBroadcastsPass(); std::unique_ptr createTestUnfuseBatchNormPass(); diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h index 725155e9403..b6706187d50 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h @@ -20,6 +20,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/Bufferize.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir { @@ -27,6 +28,12 @@ class LLVMTypeConverter; class LowerToLLVMOptions; class OwningRewritePatternList; class BufferAssignmentPlacer; + +// Populates a collection of rewrite patterns to realize element-wise operations +// on ranked tensors where possible. +void PopulateTransformUnrankedHloPatterns(MLIRContext *context, + OwningRewritePatternList *patterns); + namespace mhlo { // Collection of rewrite patterns for lowering a general dot product. @@ -49,9 +56,10 @@ void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns, MLIRContext *ctx); // Collection of rewrite patterns for lowering of HLO to LHLO dialect. 
-void populateHLOToLHLOConversionPattern( - MLIRContext *context, BufferAssignmentPlacer *bufferAssignment, - TypeConverter *converter, OwningRewritePatternList *patterns); +void populateHLOToLHLOConversionPattern(MLIRContext *context, + BufferizeTypeConverter *converter, + OwningRewritePatternList *patterns); + // Collection of rewrite patterns for lowering of HLO to Linalg dialect. void populateHLOToLinalgConversionPattern(MLIRContext *context, OwningRewritePatternList *patterns); @@ -80,10 +88,10 @@ void PopulateTransformUnrankedHloPatterns(MLIRContext *context, void PopulateUnfuseBatchNormPatterns(MLIRContext *context, OwningRewritePatternList *patterns); -// Populates a pattern that translates the standard TanhOp to an approximation -// that does not use intrinsics. -void PopulateTanhToApproximationPatterns(MLIRContext *context, - OwningRewritePatternList *patterns); +// Populates patterns that translate the trigonometric operations from the +// standard dialect to approximations that do not use intrinsics. +void PopulateTrigonometricToApproximationPatterns( + MLIRContext *context, OwningRewritePatternList *patterns); } // namespace mhlo diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/CMakeLists.txt index d7bb5057b00..7c0c11b1edd 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/CMakeLists.txt +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/CMakeLists.txt @@ -43,6 +43,7 @@ add_mlir_library(MhloInferFusibilityOpInterface add_mlir_dialect_library(MhloDialect hlo_ops.cc + hlo_ops_base_structs.cc DEPENDS MLIRhlo_opsIncGen @@ -66,6 +67,15 @@ add_mlir_dialect_library(LmhloDialect ) target_link_libraries(LmhloDialect PUBLIC MLIRIR) +add_mlir_dialect_library(LmhloGPUDialect + lhlo_gpu_ops.cc + lhlo_gpu_ops_structs.cc + + DEPENDS + MLIRlhlo_gpu_opsIncGen +) +target_link_libraries(LmhloGPUDialect PUBLIC MLIRIR) + add_mlir_dialect_library(MhloRegisterDialects init.cc @@ -73,10 +83,12 @@ DEPENDS MLIRchlo_opsIncGen MLIRhlo_opsIncGen MLIRlhlo_opsIncGen + MLIRlhlo_gpu_opsIncGen ) target_link_libraries(MhloRegisterDialects PUBLIC ChloDialect MhloDialect LmhloDialect + LmhloGPUDialect ) diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc index b5eacd686bd..99b22a75a14 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc @@ -303,9 +303,15 @@ void ConstantLikeOp::getCanonicalizationPatterns( results.insert(context); } +} // namespace chlo +} // namespace mlir + #define GET_OP_CLASSES #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc" +namespace mlir { +namespace chlo { + //===----------------------------------------------------------------------===// // chlo Dialect Constructor //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc index f5deb94e3a4..241b5938017 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" @@ -60,7 +61,9 @@ limitations under the License. namespace mlir { #include "hlo_patterns.cc.inc" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.cc.inc" +} // namespace mlir + +namespace mlir { namespace mhlo { Operation* MhloDialect::materializeConstant(OpBuilder& builder, Attribute value, @@ -165,6 +168,94 @@ static LogicalResult Verify(DotGeneralOp op) { return success(); } +//===----------------------------------------------------------------------===// +// GatherOp +//===----------------------------------------------------------------------===// + +// Converts gather ops to slice ops in case we have a single set of constant +// indices. +struct GatherSlice : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(GatherOp gather, + PatternRewriter& rewriter) const override { + DenseIntElementsAttr index; + if (!matchPattern(gather.start_indices(), m_Constant(&index))) + return failure(); + + const auto& dnums = gather.dimension_numbers(); + if (dnums.index_vector_dim().getInt() != 0 || index.getType().getRank() > 1) + return failure(); + + // TODO(tberghammer): Remove when the verifier catches this case what is + // invalid if all previous condition holds. + if (index.getNumElements() != dnums.start_index_map().getNumElements()) + return failure(); + + auto slice_end = + llvm::to_vector<8>(gather.slice_sizes().getValues()); + llvm::SmallVector slice_start(slice_end.size(), 0); + for (auto it : llvm::zip(dnums.start_index_map().getIntValues(), + index.getIntValues())) { + int64_t map_index = std::get<0>(it).getSExtValue(); + int64_t offset = std::get<1>(it).getSExtValue(); + slice_start[map_index] += offset; + slice_end[map_index] += offset; + } + + llvm::SmallVector slice_stride(slice_end.size(), 1); + llvm::SmallVector slice_shape(slice_end.size()); + for (int64_t i = 0; i < slice_end.size(); ++i) { + slice_shape[i] = slice_end[i] - slice_start[i]; + } + Type element_type = gather.getType().cast().getElementType(); + auto slice_type = RankedTensorType::get(slice_shape, element_type); + Value result = rewriter.create( + gather.getLoc(), slice_type, gather.getOperand(0), + GetI64ElementsAttr(slice_start, &rewriter), + GetI64ElementsAttr(slice_end, &rewriter), + GetI64ElementsAttr(slice_stride, &rewriter)); + + if (dnums.collapsed_slice_dims().getNumElements() > 0) { + auto collapsed_slice_dims = llvm::to_vector<8>(llvm::map_range( + dnums.collapsed_slice_dims().getIntValues(), + [](const llvm::APInt& i) { return i.getSExtValue(); })); + llvm::SmallVector reshape_shape; + for (int64_t i = 0; i < slice_shape.size(); ++i) { + if (llvm::count(collapsed_slice_dims, i) == 0) { + reshape_shape.push_back(slice_shape[i]); + } + } + auto reshape_type = RankedTensorType::get(reshape_shape, element_type); + result = + rewriter.create(gather.getLoc(), reshape_type, result); + } + + result.setType(gather.getType()); + rewriter.replaceOp(gather, result); + return success(); + } +}; + +void GatherOp::getCanonicalizationPatterns(OwningRewritePatternList& results, + MLIRContext* context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// GetDimensionSizeOp +//===----------------------------------------------------------------------===// + +/// Fold get_dimension_size 
when the said shape dimension is a constant. +OpFoldResult GetDimensionSizeOp::fold(ArrayRef attrs) { + RankedTensorType type = operand().getType().cast(); + int32_t dim = dimension(); + if (type.isDynamic(dim)) return {}; + // The result type is always is a 0-d i32 tensor. + return DenseIntElementsAttr::get( + getResult().getType().cast(), type.getDimSize(dim)); +} + //===----------------------------------------------------------------------===// // IotaOp //===----------------------------------------------------------------------===// @@ -176,7 +267,7 @@ static LogicalResult Verify(IotaOp op) { if (shape.getRank() == 0) return op.emitOpError() << "does not support scalars."; - auto iota_dimension = op.iota_dimension().getSExtValue(); + auto iota_dimension = op.iota_dimension(); if (iota_dimension >= shape.getRank() || iota_dimension < 0) return op.emitOpError() << "iota dimension cannot go beyond the output " "rank or be negative."; @@ -198,8 +289,7 @@ struct IotaBroadcast : public OpRewritePattern { auto iota_dimension = iota.iota_dimension(); auto iota_type = RankedTensorType::get( - {result_ty.getDimSize(iota_dimension.getLimitedValue())}, - result_ty.getElementType()); + {result_ty.getDimSize(iota_dimension)}, result_ty.getElementType()); auto new_iota = rewriter.create(iota.getLoc(), iota_type, rewriter.getI64IntegerAttr(0)); @@ -219,7 +309,7 @@ void IotaOp::getCanonicalizationPatterns(OwningRewritePatternList& results, } OpFoldResult IotaOp::fold(ArrayRef operands) { - auto dimension = iota_dimension().getLimitedValue(); + auto dimension = iota_dimension(); auto result_ty = getResult().getType().cast(); if (result_ty.hasRank() && result_ty.getDimSize(dimension) == 1) { Builder builder(getContext()); @@ -263,7 +353,7 @@ struct DynamicIotaBroadcast : public OpRewritePattern { } auto iota_dimension = iota.iota_dimension(); - auto iota_dimension_int = iota_dimension.getLimitedValue(); + auto iota_dimension_int = iota_dimension; auto converted_shape = rewriter.create( iota.getLoc(), @@ -462,7 +552,7 @@ static LogicalResult Verify(DequantizeOp op) { //===----------------------------------------------------------------------===// static LogicalResult Verify(GetTupleElementOp op) { - auto indexVal = op.index().getZExtValue(); + auto indexVal = op.index(); auto operandType = op.getOperand().getType().cast(); if (indexVal >= operandType.size()) { return op.emitOpError( @@ -481,7 +571,7 @@ static LogicalResult Verify(GetTupleElementOp op) { OpFoldResult GetTupleElementOp::fold(ArrayRef operands) { if (auto tupleOp = dyn_cast_or_null(getOperand().getDefiningOp())) { - return tupleOp.getOperand(index().getLimitedValue()); + return tupleOp.getOperand(index()); } return {}; @@ -551,8 +641,8 @@ static LogicalResult Verify(AllToAllOp op) { // count. 
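For the GatherSlice canonicalization above, a toy standalone sketch of how the slice bounds are derived from a constant index vector: the slice_sizes window is shifted by the start indices along the dimensions named in start_index_map. All shapes and dimension numbers below are hypothetical, chosen only for illustration.

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<int64_t> slice_sizes = {1, 4};      // gather.slice_sizes()
  std::vector<int64_t> start_index_map = {0, 1};  // dnums.start_index_map()
  std::vector<int64_t> index = {2, 3};            // constant start_indices

  // Mirror of the pattern's bound computation: start at zero, end at the
  // window size, then shift both by the constant start index.
  std::vector<int64_t> slice_end = slice_sizes;
  std::vector<int64_t> slice_start(slice_end.size(), 0);
  for (size_t i = 0; i < start_index_map.size(); ++i) {
    slice_start[start_index_map[i]] += index[i];
    slice_end[start_index_map[i]] += index[i];
  }

  assert(slice_start == (std::vector<int64_t>{2, 3}));
  assert(slice_end == (std::vector<int64_t>{3, 7}));
  return 0;
}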
auto type = op.getOperand().getType().dyn_cast(); if (!type) return success(); - auto split_dim_size = type.getDimSize(op.split_dimension().getSExtValue()); - auto split_count = op.split_count().getSExtValue(); + auto split_dim_size = type.getDimSize(op.split_dimension()); + auto split_count = op.split_count(); if (split_dim_size % split_count != 0) { return op.emitError() << "split dimension has size " << split_dim_size << ", expected to be a multiple of split_count " @@ -821,9 +911,10 @@ static LogicalResult Verify(ClampOp op) { // ComplexOp //===----------------------------------------------------------------------===// -void ComplexOp::build(OpBuilder& builder, OperationState& state, Value lhs, - Value rhs) { - auto type = lhs.getType(); +LogicalResult ComplexOp::inferReturnTypes( + MLIRContext*, Optional, ValueRange operands, DictionaryAttr, + RegionRange, SmallVectorImpl& inferredReturnTypes) { + auto type = operands[0].getType(); auto element_ty = ComplexType::get(getElementTypeOrSelf(type)); Type result_ty; if (auto ranked_type = type.dyn_cast()) { @@ -833,8 +924,8 @@ void ComplexOp::build(OpBuilder& builder, OperationState& state, Value lhs, } else { result_ty = element_ty; } - - build(builder, state, result_ty, lhs, rhs); + inferredReturnTypes.push_back(result_ty); + return success(); } OpFoldResult ComplexOp::fold(ArrayRef operands) { @@ -864,8 +955,11 @@ Type CreateRealType(Type type) { } } // namespace -void ImagOp::build(OpBuilder& builder, OperationState& state, Value val) { - build(builder, state, CreateRealType(val.getType()), val); +LogicalResult ImagOp::inferReturnTypes( + MLIRContext*, Optional, ValueRange operands, DictionaryAttr, + RegionRange, SmallVectorImpl& inferredReturnTypes) { + inferredReturnTypes.push_back(CreateRealType(operands[0].getType())); + return success(); } OpFoldResult ImagOp::fold(ArrayRef operands) { @@ -877,8 +971,11 @@ OpFoldResult ImagOp::fold(ArrayRef operands) { return {}; } -void RealOp::build(OpBuilder& builder, OperationState& state, Value val) { - build(builder, state, CreateRealType(val.getType()), val); +LogicalResult RealOp::inferReturnTypes( + MLIRContext*, Optional, ValueRange operands, DictionaryAttr, + RegionRange, SmallVectorImpl& inferredReturnTypes) { + inferredReturnTypes.push_back(CreateRealType(operands[0].getType())); + return success(); } OpFoldResult RealOp::fold(ArrayRef operands) { @@ -900,7 +997,7 @@ class ConcatenateOperandRemoval : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ConcatenateOp op, PatternRewriter& rewriter) const override { - auto axis = op.dimension().getLimitedValue(); + auto axis = op.dimension(); llvm::SmallVector new_operands; for (auto operand : op.getOperands()) { auto ty = operand.getType().cast(); @@ -941,13 +1038,41 @@ LogicalResult ConcatenateOp::inferReturnTypes( } } - // If an input is unranked the output shape is unranked. + // Find the first ranked input to determine the output rank. + for (auto type : operands.getTypes()) { + auto shaped_type = type.cast(); + if (shaped_type.hasRank()) { + first_type = shaped_type; + break; + } + } + + // If all inputs are unranked, the result must be unranked. 
if (!first_type.hasRank()) { inferredReturnTypes.push_back(UnrankedTensorType::get(out_element)); return success(); } + if (first_type.getRank() == 0) + return emitOptionalError(location, "rank-0 values cannot be concatenated"); + auto out_shape = llvm::to_vector<6>(first_type.getShape()); + + // Determine what the non-concatenate dimensions should be. + for (auto type : operands.getTypes()) { + auto shaped_ty = type.cast(); + if (!shaped_ty.hasRank()) { + continue; + } + + for (auto it : llvm::enumerate(shaped_ty.getShape())) { + // If a dimension is not dynamic, the output shape should match. + if (ShapedType::isDynamic(out_shape[it.index()])) { + out_shape[it.index()] = it.value(); + } + } + } + out_shape[dimension] = 0; for (auto operand : operands.getTypes()) { @@ -980,7 +1105,7 @@ void ConcatenateOp::getCanonicalizationPatterns( template static Attribute foldConcatenateHelper(ConcatenateOp* op, ArrayRef operands) { - auto axis = op->dimension().getLimitedValue(); + auto axis = op->dimension(); auto type = op->getType().cast(); SmallVector values; @@ -1028,7 +1153,7 @@ OpFoldResult ConcatenateOp::fold(ArrayRef operands) { ShapedType type = getResult().getType().cast(); if (!type.hasStaticShape()) return {}; - auto axis = dimension().getLimitedValue(); + auto axis = dimension(); if (auto attr = foldConcatenate(this, operands)) { return attr; } @@ -1203,6 +1328,131 @@ static LogicalResult Verify(InfeedOp op) { return success(); } +//===----------------------------------------------------------------------===// +// Logical Ops +//===----------------------------------------------------------------------===// + +OpFoldResult AndOp::fold(ArrayRef operands) { + if (lhs() == rhs()) return lhs(); + + auto rType = getType().cast(); + auto lhsVal = operands[0].dyn_cast_or_null(); + auto rhsVal = operands[1].dyn_cast_or_null(); + + if (lhsVal && lhsVal.isSplat()) { + if (lhsVal.getSplatValue() + .cast() + .getValue() + .isAllOnesValue()) { + return rhs(); + } + + if (lhsVal.getSplatValue().cast().getValue().isNullValue()) { + return lhsVal; + } + } + + if (rhsVal && rhsVal.isSplat()) { + if (rhsVal.getSplatValue() + .cast() + .getValue() + .isAllOnesValue()) { + return lhs(); + } + + if (rhsVal.getSplatValue().cast().getValue().isNullValue()) { + return rhsVal; + } + } + + if (!rhsVal || !lhsVal) return {}; + + llvm::SmallVector values; + values.reserve(rhsVal.getNumElements()); + for (auto it : llvm::zip(rhsVal.getIntValues(), lhsVal.getIntValues())) { + values.push_back(std::get<0>(it) & std::get<1>(it)); + } + + return DenseIntElementsAttr::get(rType, values); +} + +OpFoldResult OrOp::fold(ArrayRef operands) { + if (lhs() == rhs()) return lhs(); + + auto rType = getType().cast(); + auto lhsVal = operands[0].dyn_cast_or_null(); + auto rhsVal = operands[1].dyn_cast_or_null(); + + if (lhsVal && lhsVal.isSplat()) { + if (lhsVal.getSplatValue() + .cast() + .getValue() + .isAllOnesValue()) { + return lhsVal; + } + + if (lhsVal.getSplatValue().cast().getValue().isNullValue()) { + return rhs(); + } + } + + if (rhsVal && rhsVal.isSplat()) { + if (rhsVal.getSplatValue() + .cast() + .getValue() + .isAllOnesValue()) { + return rhsVal; + } + + if (rhsVal.getSplatValue().cast().getValue().isNullValue()) { + return lhs(); + } + } + + if (!rhsVal || !lhsVal) return {}; + + llvm::SmallVector values; + values.reserve(rhsVal.getNumElements()); + for (auto it : llvm::zip(rhsVal.getIntValues(), lhsVal.getIntValues())) { + values.push_back(std::get<0>(it) | std::get<1>(it)); + } + + return 
DenseIntElementsAttr::get(rType, values); +} + +OpFoldResult XorOp::fold(ArrayRef operands) { + auto rType = getType().cast(); + if (lhs() == rhs()) { + Builder builder(getContext()); + return builder.getZeroAttr(rType); + } + + auto lhsVal = operands[0].dyn_cast_or_null(); + auto rhsVal = operands[1].dyn_cast_or_null(); + + if (lhsVal && lhsVal.isSplat()) { + if (lhsVal.getSplatValue().cast().getValue().isNullValue()) { + return rhs(); + } + } + + if (rhsVal && rhsVal.isSplat()) { + if (rhsVal.getSplatValue().cast().getValue().isNullValue()) { + return lhs(); + } + } + + if (!rhsVal || !lhsVal) return {}; + + llvm::SmallVector values; + values.reserve(rhsVal.getNumElements()); + for (auto it : llvm::zip(rhsVal.getIntValues(), lhsVal.getIntValues())) { + values.push_back(std::get<0>(it) ^ std::get<1>(it)); + } + + return DenseIntElementsAttr::get(rType, values); +} + //===----------------------------------------------------------------------===// // MapOp //===----------------------------------------------------------------------===// @@ -1396,6 +1646,29 @@ static LogicalResult Verify(SelectOp op) { return success(); } +OpFoldResult SelectOp::fold(ArrayRef operands) { + if (on_true() == on_false()) { + return on_true(); + } + + auto predicate = operands[0].dyn_cast_or_null(); + if (!predicate) { + return {}; + } + + auto predicateTy = predicate.getType().cast(); + if (!predicateTy.getElementType().isInteger(1)) { + return {}; + } + + if (predicate.isSplat()) { + return predicate.getSplatValue().getBoolValue() ? on_true() + : on_false(); + } + + return {}; +} + // Makes it such that a SelectOp that is a non-root operation in a DRR infers // the return type based on operand type. LogicalResult SelectOp::inferReturnTypes( @@ -1437,6 +1710,20 @@ LogicalResult SelectOp::inferReturnTypes( return success(); } +LogicalResult SelectOp::inferReturnTypeComponents( + mlir::MLIRContext*, llvm::Optional, mlir::ValueRange, + mlir::DictionaryAttr, mlir::RegionRange, + llvm::SmallVectorImpl&) { + // TODO(b/168772852) + return failure(); +} + +LogicalResult SelectOp::reifyReturnTypeShapes( + OpBuilder& builder, SmallVectorImpl& reifiedReturnShapes) { + return deriveShapeFromFirstOperand(&builder, getOperation(), + &reifiedReturnShapes); +} + //===----------------------------------------------------------------------===// // PadOp //===----------------------------------------------------------------------===// @@ -1584,6 +1871,79 @@ static LogicalResult Verify(CaseOp op) { return success(); } +//===----------------------------------------------------------------------===// +// SqrtOp +//===----------------------------------------------------------------------===// + +OpFoldResult SqrtOp::fold(ArrayRef operands) { + auto val = operands[0].dyn_cast_or_null(); + if (!val) return {}; + + auto type = getElementTypeOrSelf(getType()); + if (!type.isF32() && !type.isF64()) return {}; + + auto shaped_type = getType().cast(); + if (!shaped_type.hasStaticShape()) return {}; + + int bit_width = type.getIntOrFloatBitWidth(); + llvm::SmallVector values; + values.reserve(val.getNumElements()); + for (auto it : val.getFloatValues()) { + double value = bit_width == 32 ? 
it.convertToFloat() : it.convertToDouble(); + if (value < 0) return {}; + value = std::sqrt(value); + if (bit_width == 32) + values.emplace_back(static_cast(value)); + else + values.emplace_back(value); + } + return DenseFPElementsAttr::get(shaped_type, values); +} + +//===----------------------------------------------------------------------===// +// UnaryOps +//===----------------------------------------------------------------------===// + +template +static Attribute UnaryFolder(Op* op, ArrayRef attrs) { + if (!attrs[0]) return {}; + + DenseElementsAttr val = attrs[0].dyn_cast(); + if (!val) return {}; + + ShapedType type = op->getType().template cast(); + if (!type.hasStaticShape()) { + return {}; + } + + Type etype = type.getElementType(); + + // Evaluate for integer values. + if (!etype.isa()) { + return {}; + } + + SmallVector values; + values.reserve(val.getNumElements()); + for (const auto v : val.getValues()) { + values.push_back(Convert()(v)); + } + + return DenseElementsAttr::get(type, values); +} + +#define UNARY_FOLDER(Op, Func) \ + OpFoldResult Op::fold(ArrayRef attrs) { \ + if (getElementTypeOrSelf(getType()).isa()) \ + return UnaryFolder>(this, attrs); \ + if (getElementTypeOrSelf(getType()).isa()) \ + return UnaryFolder>(this, attrs); \ + return {}; \ + } + +UNARY_FOLDER(NegOp, std::negate); + //===----------------------------------------------------------------------===// // BinaryOps //===----------------------------------------------------------------------===// @@ -1643,6 +2003,23 @@ struct divide { APInt operator()(const APInt& a, const APInt& b) const { return a.sdiv(b); } }; +template +struct remainder : std::modulus {}; + +template <> +struct remainder { + APInt operator()(const APInt& a, const APInt& b) const { return a.srem(b); } +}; + +template <> +struct remainder { + APFloat operator()(const APFloat& a, const APFloat& b) const { + APFloat result(a); + result.remainder(b); + return result; + } +}; + template struct max { T operator()(const T& a, const T& b) const { return std::max(a, b); } @@ -1684,6 +2061,7 @@ BINARY_FOLDER(AddOp, std::plus); BINARY_FOLDER(SubOp, std::minus); BINARY_FOLDER(MulOp, std::multiplies); BINARY_FOLDER(DivOp, divide); +BINARY_FOLDER(RemOp, remainder); BINARY_FOLDER(MaxOp, max); BINARY_FOLDER(MinOp, min); @@ -1758,11 +2136,11 @@ static Attribute FoldSlice(SliceOp* op, I values) { OpFoldResult SliceOp::fold(ArrayRef operands) { // Check if the SliceOp is a NoOp operation. 
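The And/Or/Xor folds above special-case splat all-ones and all-zeros operands; a minimal standalone C++ check of the underlying bitwise identities they rely on (illustrative only; the folders themselves operate on DenseElementsAttr splats):

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t ones = 0xFFFFFFFFu;
  const uint32_t x = 0xDEADBEEFu;
  assert((x & ones) == x);     // and(x, all-ones) folds to x
  assert((x & 0u) == 0u);      // and(x, zeros)    folds to zeros
  assert((x | ones) == ones);  // or(x, all-ones)  folds to all-ones
  assert((x | 0u) == x);       // or(x, zeros)     folds to x
  assert((x ^ x) == 0u);       // xor(x, x)        folds to zeros
  assert((x ^ 0u) == x);       // xor(x, zeros)    folds to x
  return 0;
}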
- auto operand_shape = getOperand().getType().cast().getShape(); + auto operand_type = getOperand().getType().cast(); auto result_type = getResult().getType().cast(); - auto result_shape = result_type.getShape(); - if (result_type.hasStaticShape() && (operand_shape == result_shape)) { + if (operand_type.hasStaticShape() && result_type.hasStaticShape() && + (operand_type.getShape() == result_type.getShape())) { return getOperand(); } @@ -1808,7 +2186,7 @@ struct SimplifyConcatSlice : public OpRewritePattern { return failure(); } - auto dimension = concat.dimension().getSExtValue(); + auto dimension = concat.dimension(); auto start = slice.start_indices().getIntValues(); auto limit = slice.limit_indices().getIntValues(); @@ -1933,10 +2311,7 @@ void SortOp::build(OpBuilder& builder, OperationState& state, state.addAttribute("dimension", builder.getI64IntegerAttr(dimension)); state.addAttribute("is_stable", builder.getBoolAttr(dimension)); - SmallVector element_types; - element_types.reserve(operands.size()); - for (Value operand : operands) element_types.push_back(operand.getType()); - state.addTypes(builder.getTupleType(element_types)); + for (Value operand : operands) state.addTypes(operand.getType()); state.addRegion(); } @@ -1958,7 +2333,7 @@ static LogicalResult Verify(SortOp op) { return op.emitOpError("requires all inputs to have the same dimensions"); int64_t rank = input_shape.size(); - int64_t cmp_dim = op.dimension().getSExtValue(); + int64_t cmp_dim = op.dimension(); if (cmp_dim < -rank || cmp_dim >= rank) return op.emitOpError("dimension attribute value must be in range [-") << rank << ", " << rank << "), but found " << cmp_dim; @@ -2159,9 +2534,267 @@ void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs, build(builder, result, new_type, lhs, rhs, comparison_direction); } +LogicalResult CompareOp::inferReturnTypeComponents( + mlir::MLIRContext*, llvm::Optional, mlir::ValueRange, + mlir::DictionaryAttr, mlir::RegionRange, + llvm::SmallVectorImpl&) { + // TODO(b/168772852) + return failure(); +} + +LogicalResult CompareOp::reifyReturnTypeShapes( + OpBuilder& builder, SmallVectorImpl& reifiedReturnShapes) { + return deriveShapeFromFirstOperand(&builder, getOperation(), + &reifiedReturnShapes); +} + +template +struct less : std::less {}; + +template <> +struct less { + bool operator()(const APInt& a, const APInt& b) const { return a.slt(b); } +}; + +template +struct less_equal : std::less_equal {}; + +template <> +struct less_equal { + bool operator()(const APInt& a, const APInt& b) const { return a.sle(b); } +}; + +template +struct greater : std::greater {}; + +template <> +struct greater { + bool operator()(const APInt& a, const APInt& b) const { return a.sgt(b); } +}; + +template +struct greater_equal : std::greater_equal {}; + +template <> +struct greater_equal { + bool operator()(const APInt& a, const APInt& b) const { return a.sge(b); } +}; + +template +static Attribute CompareFolder(CompareOp op, ArrayRef attrs) { + if (!attrs[0] || !attrs[1]) return {}; + + DenseElementsAttr lhs = attrs[0].dyn_cast(); + DenseElementsAttr rhs = attrs[1].dyn_cast(); + if (!lhs || !rhs) return {}; + + ShapedType operand_type = + op.getOperand(0).getType().template cast(); + if (!operand_type.hasStaticShape()) { + return {}; + } + + if (!operand_type.getElementType().isa()) { + return {}; + } + + SmallVector values; + values.reserve(lhs.getNumElements()); + for (const auto zip : + llvm::zip(lhs.getValues(), rhs.getValues())) { + values.push_back(Convert()(std::get<0>(zip), 
std::get<1>(zip))); + } + + auto result_ty = op.getType().cast(); + return DenseElementsAttr::get(result_ty, values); +} + +OpFoldResult CompareOp::fold(ArrayRef operands) { + auto result_ty = getType().cast(); + if (!result_ty.hasStaticShape()) return {}; + + auto direction = comparison_direction(); + if (lhs() == rhs()) { + if (direction == "LE" || direction == "EQ" || direction == "GE") { + return DenseIntElementsAttr::get(result_ty, {true}); + } + + return DenseIntElementsAttr::get(result_ty, {false}); + } + + if (!operands[0] || !operands[1]) { + return {}; + } + +#define COMPARE_FOLDER(Op, comparison, Func) \ + if (direction == comparison) { \ + if (auto folded = CompareFolder>( \ + *this, operands)) \ + return folded; \ + if (auto folded = CompareFolder>( \ + *this, operands)) \ + return folded; \ + } + + COMPARE_FOLDER(CompareOp, "EQ", std::equal_to); + COMPARE_FOLDER(CompareOp, "NE", std::not_equal_to); + COMPARE_FOLDER(CompareOp, "LT", less); + COMPARE_FOLDER(CompareOp, "LE", less_equal); + COMPARE_FOLDER(CompareOp, "GT", greater); + COMPARE_FOLDER(CompareOp, "GE", greater_equal); +#undef COMPARE_FOLDER + + return {}; +} + +//===----------------------------------------------------------------------===// +// ScatterOp +//===----------------------------------------------------------------------===// + +llvm::SmallVector evaluateMhloRegion(Region& region, + ArrayRef inputs) { + if (region.getNumArguments() != inputs.size()) return {}; + + llvm::DenseMap values; + values.reserve(region.getNumArguments()); + for (auto it : llvm::zip(region.getArguments(), inputs)) { + values.try_emplace(std::get<0>(it), std::get<1>(it)); + } + + for (auto& op : region.getOps()) { + llvm::SmallVector inputs; + for (auto& operand : op.getOpOperands()) { + inputs.push_back(values.lookup(operand.get())); + } + if (isa(op)) return inputs; + + llvm::SmallVector results; + if (failed(op.fold(inputs, results))) return {}; + for (auto it : llvm::zip(op.getResults(), results)) { + if (!std::get<1>(it).is()) return {}; + values.insert({std::get<0>(it), std::get<1>(it).get()}); + } + } + return {}; +} + +OpFoldResult ScatterOp::fold(ArrayRef operands) { + auto base = operands[0].dyn_cast_or_null(); + auto index = operands[1].dyn_cast_or_null(); + auto update = operands[2].dyn_cast_or_null(); + if (!base || !index || !update) return {}; + + auto base_type = base.getType().dyn_cast(); + auto index_type = index.getType().dyn_cast(); + auto update_type = update.getType().dyn_cast(); + if (!base_type || !index_type || !update_type) return {}; + + // Add the virtual trailing dimension of size 1 if index_vector_dim equals to + // index_type.rank. + const int64_t index_vector_dim = + scatter_dimension_numbers().index_vector_dim().getInt(); + if (index_vector_dim == index_type.getRank()) { + auto index_shape = index_type.getShape().vec(); + index_shape.push_back(1); + index_type = + RankedTensorType::get(index_shape, index_type.getElementType()); + index = index.reshape(index_type).cast(); + } + + // Increment the multi-dimensional index vector based on the limits for each + // dimension specified by shape and returns false if the index rolled around + // with true otherwise. 
+ auto next_index = [](llvm::SmallVector& index, + llvm::ArrayRef shape) { + for (int64_t i = index.size() - 1; i >= 0; --i) { + ++index[i]; + if (index[i] < shape[i]) return true; + index[i] = 0; + } + return false; + }; + + // Iterate over all elements of the update tensor, then find the corresponding + // value in the indices tensor to determine which location we have to update + // in the base/result tensor. + llvm::SmallVector results(base.getValues()); + llvm::SmallVector update_index(update_type.getRank(), 0); + llvm::SmallVector index_index; + index_index.reserve(index_type.getRank()); + llvm::SmallVector base_index; + base_index.reserve(base_type.getRank()); + do { + // Compute the index for the slice of the indices tensor for this update + // value. + index_index.clear(); + if (index_vector_dim == 0) index_index.push_back(0); + for (int64_t i = 0; i < update_index.size(); ++i) { + if (llvm::count(scatter_dimension_numbers().update_window_dims(), i) == 0) + index_index.push_back(update_index[i]); + if (index_index.size() == index_vector_dim) index_index.push_back(0); + } + + // Compute the index for the given update value in the base tensor. + base_index.assign(base_type.getRank(), 0); + uint64_t index_count = index_type.getShape()[index_vector_dim]; + for (uint64_t i = 0; i < index_count; ++i) { + uint64_t operand_dim = scatter_dimension_numbers() + .scatter_dims_to_operand_dims() + .getValue({i}) + .getSExtValue(); + index_index[index_vector_dim] = i; + base_index[operand_dim] += + index.getValue(index_index).getSExtValue(); + } + uint64_t update_window_dim_index = 0; + for (uint64_t i = 0; i < base_index.size(); ++i) { + if (llvm::count(scatter_dimension_numbers().inserted_window_dims(), i)) + continue; + base_index[i] += + update_index[scatter_dimension_numbers() + .update_window_dims() + .getValue({update_window_dim_index}) + .getSExtValue()]; + update_window_dim_index++; + } + + // Compute the linear index for the index into the base tensor. + int64_t linear_base_index = 0; + int64_t linear_base_index_multiplyer = 1; + for (int64_t i = base_index.size() - 1; i >= 0; --i) { + // Out of bound index have backend specific behaviour so avoid folding it. + if (base_index[i] < 0 || base_index[i] >= base_type.getShape()[i]) + return {}; + linear_base_index += base_index[i] * linear_base_index_multiplyer; + linear_base_index_multiplyer *= base_type.getShape()[i]; + } + + // Evaluate update computation and update the value with the newly computed + // attribute in the base tensor. 
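A minimal standalone sketch of the index iteration and row-major linearization used by the scatter folder above; NextIndex and Linearize are illustrative stand-ins rather than code from this change:

#include <cassert>
#include <cstdint>
#include <vector>

// Advance a multi-dimensional index in row-major order; returns false once
// the index wraps around past the end of the shape.
bool NextIndex(std::vector<int64_t>& index, const std::vector<int64_t>& shape) {
  for (int64_t i = static_cast<int64_t>(index.size()) - 1; i >= 0; --i) {
    if (++index[i] < shape[i]) return true;
    index[i] = 0;
  }
  return false;
}

// Map a multi-dimensional index to a flat row-major offset.
int64_t Linearize(const std::vector<int64_t>& index,
                  const std::vector<int64_t>& shape) {
  int64_t linear = 0, stride = 1;
  for (int64_t i = static_cast<int64_t>(index.size()) - 1; i >= 0; --i) {
    linear += index[i] * stride;
    stride *= shape[i];
  }
  return linear;
}

int main() {
  std::vector<int64_t> shape = {2, 3};
  std::vector<int64_t> index = {0, 0};
  int64_t visited = 0;
  do {
    assert(Linearize(index, shape) == visited);
    ++visited;
  } while (NextIndex(index, shape));
  assert(visited == 6);
  return 0;
}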
+ auto lhs = DenseElementsAttr::get( + RankedTensorType::get({}, base_type.getElementType()), + results[linear_base_index]); + auto rhs = DenseElementsAttr::get( + RankedTensorType::get({}, base_type.getElementType()), + update.getValue(update_index)); + auto new_value = evaluateMhloRegion(update_computation(), {lhs, rhs}); + if (new_value.size() != 1 || !new_value[0]) return {}; + results[linear_base_index] = + new_value[0].cast().getValue({}); + } while (next_index(update_index, update_type.getShape())); + + return DenseElementsAttr::get(base_type, results); +} + +} // namespace mhlo +} // namespace mlir + #define GET_OP_CLASSES #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc" +namespace mlir { +namespace mhlo { + //===----------------------------------------------------------------------===// // mhlo Dialect Interfaces //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/xla/service/gpu/ir/dialect_registration.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops_base_structs.cc similarity index 75% rename from tensorflow/compiler/xla/service/gpu/ir/dialect_registration.cc rename to tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops_base_structs.cc index 2e3461951d8..90da1251ea0 100644 --- a/tensorflow/compiler/xla/service/gpu/ir/dialect_registration.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops_base_structs.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h" -// Static initialization for GPU thunks op registration. -static mlir::DialectRegistration - xla_thunks_ops; +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.cc.inc" diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/init.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/init.cc index cf8bd257d20..ca8c6a8d150 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/init.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/init.cc @@ -15,27 +15,15 @@ limitations under the License. #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/register.h" -// Static initialization for *HLO dialects registration. - -void mlir::mhlo::registerAllDialects() { - static bool init_once = []() { - registerDialect(); - registerDialect(); - registerDialect(); - return true; - }(); - (void)init_once; - - // Dependent dialects -} - void mlir::mhlo::registerAllMhloDialects(mlir::DialectRegistry ®istry) { // clang-format off registry.insert(); + mlir::lmhlo_gpu::LmhloGpuDialect>(); // clang-format on } diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc new file mode 100644 index 00000000000..10c5c0c2f9d --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc @@ -0,0 +1,64 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the operations used in the LMHLO GPU dialect. + +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h" + +#include +#include +#include + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/FormatVariadic.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" + +namespace mlir { +namespace lmhlo_gpu { + +LmhloGpuDialect::LmhloGpuDialect(MLIRContext *context) + : Dialect(getDialectNamespace(), context, TypeID::get()) { + addOperations< +#define GET_OP_LIST +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.cc.inc" + >(); +} + +// TODO(jurahul): Add verification for operand shapes and ranks. + +} // namespace lmhlo_gpu +} // namespace mlir + +#define GET_OP_CLASSES +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.cc.inc" diff --git a/tensorflow/python/util/tf32.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc similarity index 73% rename from tensorflow/python/util/tf32.cc rename to tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc index 7dece6ccdae..cd2cfc58836 100644 --- a/tensorflow/python/util/tf32.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc @@ -13,10 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "pybind11/pybind11.h" -#include "tensorflow/core/platform/tf32_utils.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h" -PYBIND11_MODULE(_pywrap_tf32_execution, m) { - m.def("allow", &tensorflow::allow_tf32_execution); - m.def("is_allowed", &tensorflow::tf32_execution_allowed); -} +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc.inc" diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc index 81407c89204..4524cf3ec1f 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc @@ -46,7 +46,6 @@ limitations under the License. 
#include "mlir/IR/Value.h" namespace mlir { -#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.cc.inc" namespace lmhlo { LmhloDialect::LmhloDialect(MLIRContext *context) @@ -159,9 +158,15 @@ static LogicalResult Verify(ReshapeMemRefCastOp op) { return success(); } +} // namespace lmhlo +} // namespace mlir + #define GET_OP_CLASSES #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc" +namespace mlir { +namespace lmhlo { + // TODO(cheshire): Support folding, reuse code from hlo_ops.cc. void FusionOp::build(OpBuilder &builder, OperationState &result, diff --git a/tensorflow/compiler/mlir/lite/ir/dialect_registration.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops_structs.cc similarity index 70% rename from tensorflow/compiler/mlir/lite/ir/dialect_registration.cc rename to tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops_structs.cc index fae20437811..83dd4e62b47 100644 --- a/tensorflow/compiler/mlir/lite/ir/dialect_registration.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops_structs.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,7 +13,5 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" - -// Static initialization for TensorFlow Lite op registration. -static mlir::DialectRegistration tfl_ops; +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc.inc" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h" diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt index bb9f98d32d3..354913264bb 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt @@ -25,6 +25,10 @@ set(LLVM_TARGET_DEFINITIONS legalize_to_standard_patterns.td) mlir_tablegen(generated_legalize_to_standard.inc -gen-rewriters) add_public_tablegen_target(MLIRMhloLegalizeToStandardIncGen) +set(LLVM_TARGET_DEFINITIONS chlo_legalize_to_hlo_patterns.td) +mlir_tablegen(generated_chlo_legalize_to_hlo.inc -gen-rewriters) +add_public_tablegen_target(MLIRChloLegalizeToHloIncGen) + add_mlir_library(ChloPasses chlo_legalize_to_hlo.cc @@ -32,6 +36,7 @@ add_mlir_library(ChloPasses DEPENDS MLIRhlo_opsIncGen + MLIRChloLegalizeToHloIncGen LINK_COMPONENTS Core @@ -44,7 +49,7 @@ add_mlir_library(ChloPasses add_mlir_library(MhloPasses legalize_gather_to_torch_index_select.cc - legalize_tanh_to_approximation.cc + legalize_trigonometric_to_approximation.cc lower_complex.cc lower_complex_patterns.td lower_general_dot.cc @@ -93,6 +98,7 @@ add_mlir_library(MhloToLhloConversion add_mlir_library(MhloToStandard legalize_control_flow.cc legalize_to_standard.cc + mhlo_control_flow_to_scf.cc DEPENDS MLIRhlo_opsIncGen @@ -124,7 +130,6 @@ add_mlir_library(MhloLhloToLinalg ) add_mlir_library(LmhloPasses - lhlo_copy_removal.cc lhlo_fuse_linalg.cc lhlo_legalize_to_affine.cc lhlo_legalize_to_gpu.cc diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc index 
c2db4880632..42d6d70b524 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" @@ -31,6 +33,39 @@ namespace mlir { namespace chlo { namespace { +struct ConvertConstantLikeOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + ConstantLikeOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto result_ty = op.getType().cast(); + + // Unranked uses are not supported. Consider `transform-unranked-hlo`. + if (!result_ty.hasRank()) return failure(); + + // Lower to MHLO constant if statically shaped. + if (result_ty.hasStaticShape()) { + rewriter.replaceOpWithNewOp( + op, DenseElementsAttr::get(result_ty, op.value())); + return success(); + } + + // Lower to broadcasted constant. + ConstantLikeOp::Adaptor transformed(operands); + auto loc = op.getLoc(); + Type extent_tensor_type = shape::getExtentTensorType(op.getContext()); + Value constant = rewriter.create(loc, op.value()); + Value uncasted_shape = rewriter.create( + loc, extent_tensor_type, transformed.operand()); + Type shape_ty = + RankedTensorType::get({result_ty.getRank()}, rewriter.getIndexType()); + Value shape = rewriter.create(loc, shape_ty, uncasted_shape); + rewriter.replaceOpWithNewOp( + op, result_ty, constant, shape, rewriter.getI64TensorAttr({})); + return success(); + } +}; + // Converts binary ops that statically are determined to not broadcast directly // to the corresponding mhlo non-broadcasting op. 
template @@ -248,7 +283,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp auto if_op = rewriter.create( loc, result_type, IsScalarTensor(rewriter, op, lhs), true); OpBuilder if_lhs_scalar_builder = if_op.getThenBodyBuilder(); - Value reshaped_lhs = if_lhs_scalar_builder.create( + Value reshaped_lhs = if_lhs_scalar_builder.create( loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs); Value if_lhs_scalar_result = if_lhs_scalar_builder.create( loc, ArrayRef{result_type}, ArrayRef{reshaped_lhs, rhs}, @@ -265,7 +300,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp else_lhs_scalar_builder.create(loc, if_rhs_scalar_op.getResult(0)); OpBuilder if_rhs_scalar_builder = if_rhs_scalar_op.getThenBodyBuilder(); - Value reshaped_rhs = if_rhs_scalar_builder.create( + Value reshaped_rhs = if_rhs_scalar_builder.create( loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs); Value if_rhs_scalar_result = if_rhs_scalar_builder.create( loc, ArrayRef{result_type}, ArrayRef{lhs, reshaped_rhs}, @@ -338,30 +373,37 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp Value lhs_shape = if_builder.create(loc, lhs); Value rhs_shape = if_builder.create(loc, rhs); SmallVector ranked_shape(targeted_rank, 1); - auto extent_tensor_type = + auto unknown_rank_extent_tensor_type = RankedTensorType::get( + {RankedTensorType::kDynamicSize}, builder.getIndexType()); + auto known_rank_extent_tensor_type = RankedTensorType::get({targeted_rank}, builder.getIndexType()); auto reshaped_type = RankedTensorType::get( llvm::SmallVector(targeted_rank, RankedTensorType::kDynamicSize), lhs.getType().template dyn_cast().getElementType()); Value ranked_shape_val = if_builder.create( - loc, extent_tensor_type, - mlir::DenseIntElementsAttr::get(extent_tensor_type, ranked_shape)); - // TODO(tpopp): Return extent tensors when possible to signal that this is a - // guaranteed safe broadcast by construction. + loc, known_rank_extent_tensor_type, + mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type, + ranked_shape)); Value extended_lhs = if_builder.create( - loc, extent_tensor_type, lhs_shape, ranked_shape_val, nullptr); + loc, unknown_rank_extent_tensor_type, lhs_shape, ranked_shape_val, + nullptr); + Value extended_lhs_casted = if_builder.create( + loc, known_rank_extent_tensor_type, extended_lhs); Value extended_rhs = if_builder.create( - loc, extent_tensor_type, rhs_shape, ranked_shape_val, nullptr); + loc, unknown_rank_extent_tensor_type, rhs_shape, ranked_shape_val, + nullptr); + Value extended_rhs_casted = if_builder.create( + loc, known_rank_extent_tensor_type, extended_rhs); // 1. Reshape operands to the given rank (with the same number of elements) // 2. Compute the ranked-broadcasted ChloOp (which will assert that the ops // can be broadcasted and do the actual broadcasting) // 3. 
Type erase the output back to unranked Value reshaped_lhs = if_builder.create( - loc, reshaped_type, lhs, extended_lhs); + loc, reshaped_type, lhs, extended_lhs_casted); Value reshaped_rhs = if_builder.create( - loc, reshaped_type, rhs, extended_rhs); + loc, reshaped_type, rhs, extended_rhs_casted); Value result = if_builder.create( loc, ArrayRef{reshaped_type}, ArrayRef{reshaped_lhs, reshaped_rhs}, op.getAttrs()); @@ -469,10 +511,13 @@ struct HloCompareAdaptor { } }; +#include "generated_chlo_legalize_to_hlo.inc" } // namespace void PopulateLegalizeChloToHloPatterns(MLIRContext *context, OwningRewritePatternList *patterns) { + populateWithGenerated(context, *patterns); + // Instantiate conversion templates for conforming binary elementwise ops // that do not have different dtypes between operands and results and do // not have special attributes that need to be preserved. @@ -502,6 +547,9 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context, context, patterns); PopulateForBinaryOp( context, patterns); + + // Other patterns. + patterns->insert(context); } } // namespace chlo diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc index 50cd6df5c99..d2f415d91f9 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc @@ -27,16 +27,21 @@ namespace mhlo { namespace { -struct TestChloLegalizeToHloPass - : public PassWrapper { +struct ChloLegalizeToHloPass + : public PassWrapper { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnFunction() override { ConversionTarget conversionTarget(getContext()); OwningRewritePatternList conversionPatterns; - conversionTarget.addIllegalDialect(); + // Consider the mhlo dialect legal for tests. conversionTarget.addLegalDialect(); - // The conversion uses helpers from the Standard dialect. + + // The conversion uses helpers from the standard dialect. conversionTarget.addLegalDialect(); conversionTarget.addLegalDialect(); conversionTarget.addLegalDialect(); @@ -52,8 +57,8 @@ struct TestChloLegalizeToHloPass } // namespace -std::unique_ptr createTestChloLegalizeToHloPass() { - return std::make_unique(); +std::unique_ptr createChloLegalizeToHloPass() { + return std::make_unique(); } } // namespace mhlo diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td new file mode 100644 index 00000000000..a48abb6190c --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td @@ -0,0 +1,107 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// This is the legalization pattern definition file for CHLO to MHLO. + +include "mlir/IR/OpBase.td" +include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" +include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.td" + +//===----------------------------------------------------------------------===// +// Unary op patterns. +//===----------------------------------------------------------------------===// + +// Expand acos to MHLO dialect as follows: +// acos(x) = 2 * atan2(sqrt(1 - x^2), (1 + x)) if x != -1 +// = pi if x == -1 +def : Pat<(HLOClient_AcosOp $input), + (HLO_SelectOp + (HLO_CompareOp + $input, + (HLO_ConstantLike<"-1"> $input), + HLO_COMPARISON_DIRECTION_NE + ), + (HLO_MulOp + (HLO_ConstantLike<"2"> $input), + (HLO_Atan2Op + (HLO_SqrtOp + (HLO_SubOp + (HLO_ConstantLike<"1"> $input), + (HLO_MulOp $input, $input) + ) + ), + (HLO_AddOp + (HLO_ConstantLike<"1"> $input), + $input + ) + ) + ), + (HLO_ConstantLike<"M_PI"> $input) + )>; + +// Express `atan` as +// atan(x) = atan2(x, 1) +def : Pat<(HLOClient_AtanOp $input), + (HLO_Atan2Op + $input, + (HLO_ConstantLike<"1"> $input) + )>; + +// Express `sinh` as +// sinh(x) = (e^x - e^-x) / 2 if |x| < 1 +// = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise. +def : Pat<(HLOClient_SinhOp $input), + (HLO_SelectOp + (HLO_CompareOp + (HLO_AbsOp $input), + (HLO_ConstantLike<"1"> $input), + HLO_COMPARISON_DIRECTION_LT + ), + (HLO_DivOp + (HLO_SubOp + (HLO_ExpOp $input), + (HLO_ExpOp + (HLO_NegOp $input) + ) + ), + (HLO_ConstantLike<"2"> $input) + ), + (HLO_SubOp + (HLO_ExpOp + (HLO_AddOp + $input, + (HLO_LogOp + (HLO_ConstantLike<"0.5"> $input) + ) + ) + ), + (HLO_ExpOp + (HLO_SubOp + (HLO_LogOp + (HLO_ConstantLike<"0.5"> $input) + ), + $input + ) + ) + ) + )>; + +// Express tan in MHLO dialect as +// tan(x) = sin(x) / cos(x). +def : Pat<(HLOClient_TanOp $input), + (HLO_DivOp + (HLO_SinOp $input), + (HLO_CosOp $input) + )>; diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index a8c3ad17ebb..7b401d56e8c 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -20,6 +20,8 @@ limitations under the License. #include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/Shape/Transforms/Passes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" @@ -32,7 +34,7 @@ limitations under the License. 
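For reference, the expansions encoded by the new CHLO-to-MHLO patterns above amount to the following identities (restating the comments in the .td file in math notation):

\[
\operatorname{acos}(x) =
\begin{cases}
2\,\operatorname{atan2}\bigl(\sqrt{1 - x^{2}},\; 1 + x\bigr), & x \neq -1,\\
\pi, & x = -1,
\end{cases}
\qquad
\operatorname{atan}(x) = \operatorname{atan2}(x,\, 1),
\]
\[
\sinh(x) =
\begin{cases}
\dfrac{e^{x} - e^{-x}}{2}, & |x| < 1,\\[4pt]
e^{\,x + \log\frac{1}{2}} - e^{\,-x + \log\frac{1}{2}}, & \text{otherwise},
\end{cases}
\qquad
\tan(x) = \dfrac{\sin(x)}{\cos(x)}.
\]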
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Pass/Pass.h" -#include "mlir/Transforms/BufferPlacement.h" +#include "mlir/Transforms/Bufferize.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir { @@ -40,12 +42,12 @@ namespace mhlo { namespace { template -using BaseOpConversion = BufferAssignmentOpConversionPattern; +using BaseOpConversion = BufferizeOpConversionPattern; Value InsertDynamicAllocAndDealloc(Location loc, Value result, Value shape_operand, ConversionPatternRewriter* rewriter) { - auto result_type = result.getType().dyn_cast(); + auto result_type = result.getType().dyn_cast(); if (!result_type) { result.getDefiningOp()->emitOpError() << "tensor to buffer conversion expects ranked results"; @@ -53,17 +55,13 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result, auto memref_type = MemRefType::get(result_type.getShape(), result_type.getElementType()); - Operation* op = result.getDefiningOp(); - // Extract the required element out of the vector. SmallVector dynamic_operands; for (auto shape_element : llvm::enumerate(result_type.getShape())) { if (shape_element.value() != ShapedType::kDynamicSize) continue; - Value index = rewriter->create( - loc, rewriter->getIntegerAttr(rewriter->getIndexType(), - shape_element.index())); - Value alloc_operand = rewriter->create(loc, shape_operand, - ValueRange{index}); + Value index = rewriter->create(loc, shape_element.index()); + Value alloc_operand = + rewriter->create(loc, shape_operand, index); if (!alloc_operand.getType().isIndex()) { alloc_operand = rewriter->create(loc, alloc_operand, rewriter->getIndexType()); @@ -71,16 +69,12 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result, dynamic_operands.push_back(alloc_operand); } - // Insert in front of op to ensure sizes are available. 
- OpBuilder allocBuilder(op); - auto alloc = allocBuilder.create(loc, memref_type, dynamic_operands); - return alloc; + return rewriter->create(loc, memref_type, dynamic_operands); } Value InsertAlloc(Location loc, OpResult result, - BufferAssignmentPlacer* bufferAssignment, ConversionPatternRewriter* rewriter) { - auto result_type = result.getType().dyn_cast(); + auto result_type = result.getType().dyn_cast(); if (!result_type || !result_type.hasStaticShape()) { result.getDefiningOp()->emitOpError() << "tensor to buffer conversion expects statically shaped results"; @@ -88,8 +82,7 @@ Value InsertAlloc(Location loc, OpResult result, auto memref_type = MemRefType::get(result_type.getShape(), result_type.getElementType()); OpBuilder::InsertionGuard guard(*rewriter); - rewriter->restoreInsertionPoint( - bufferAssignment->computeAllocPosition(result)); + rewriter->setInsertionPoint(result.getDefiningOp()); auto alloc = rewriter->create(loc, memref_type); return alloc; } @@ -111,8 +104,52 @@ class HloToLhloOpConverter : public BaseOpConversion { return failure(); } if (resultType.hasStaticShape()) { - buffer_args.push_back(InsertAlloc(op->getLoc(), result.value(), - this->bufferAssignment, &rewriter)); + buffer_args.push_back( + InsertAlloc(op->getLoc(), result.value(), &rewriter)); + } else { + auto shape_type_op = dyn_cast(op); + if (!shape_type_op) return failure(); + + SmallVector results_shape; + auto status = + shape_type_op.reifyReturnTypeShapes(rewriter, results_shape); + if (failed(status)) return failure(); + buffer_args.push_back(InsertDynamicAllocAndDealloc( + op->getLoc(), result.value(), results_shape.front(), &rewriter)); + } + } + rewriter.create>(op->getLoc(), llvm::None, + buffer_args, op->getAttrs()); + rewriter.replaceOp( + op, llvm::makeArrayRef(buffer_args).drop_front(operands.size())); + return success(); + } +}; + +// This specialization exists so that LMHLO's Dot can be given a specific set of +// dimension numbers, when lowering from MHLO's Dot, which does not have +// dimension numbers (it uses DotGeneral for this generalized notion of dot +// products). When these two dialects are in sync with respect to the +// Dot/DotGeneral issue, this specialization should be deleted. +template <> +class HloToLhloOpConverter : public BaseOpConversion { + public: + using BaseOpConversion::BaseOpConversion; + LogicalResult matchAndRewrite( + mhlo::DotOp hloOp, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + Operation* op = hloOp.getOperation(); + const auto& original_results = op->getResults(); + SmallVector buffer_args(operands.begin(), operands.end()); + for (auto result : llvm::enumerate(original_results)) { + RankedTensorType resultType = + result.value().getType().dyn_cast(); + if (!resultType) { + return failure(); + } + if (resultType.hasStaticShape()) { + buffer_args.push_back( + InsertAlloc(op->getLoc(), result.value(), &rewriter)); } else { SmallVector results_shape; auto shape_type_op = dyn_cast(op); @@ -124,8 +161,20 @@ class HloToLhloOpConverter : public BaseOpConversion { op->getLoc(), result.value(), results_shape.front(), &rewriter)); } } - rewriter.create>(op->getLoc(), llvm::None, - buffer_args, op->getAttrs()); + + // TODO(silvasean): Move this helper to MLIR core. 
+ auto make_elements_attr = [&rewriter](ArrayRef integers) { + auto type = RankedTensorType::get({static_cast(integers.size())}, + rewriter.getIntegerType(64)); + return DenseIntElementsAttr::get(type, integers); + }; + auto dotOp = rewriter.create(op->getLoc(), llvm::None, + buffer_args, op->getAttrs()); + // MHLO's Dot uses rank-2 operands, of the form ([N, M], [M, O]) -> [N, O]. + auto dimension_numbers = mhlo::DotDimensionNumbers::get( + make_elements_attr({}), make_elements_attr({}), make_elements_attr({1}), + make_elements_attr({0}), rewriter.getContext()); + dotOp.dot_dimension_numbersAttr(dimension_numbers); rewriter.replaceOp(op, ArrayRef(buffer_args).slice(operands.size())); return success(); } @@ -241,6 +290,43 @@ struct HloToLhloDynamicReshapeConverter } }; +struct HloToLhloDotGeneralOpConverter + : public BaseOpConversion { + using BaseOpConversion::BaseOpConversion; + LogicalResult matchAndRewrite( + mhlo::DotGeneralOp dotGeneralOp, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + Operation* op = dotGeneralOp.getOperation(); + + if (op->getResults().empty()) return failure(); + OpResult result = op->getResults()[0]; + RankedTensorType resultType = result.getType().dyn_cast(); + if (!resultType) return failure(); + + // The third buffer argument will be filled with what used to be the return + // type of the DotGeneral. + if (operands.size() != 2) return failure(); + std::array bufferArgs = {operands[0], operands[1], {}}; + + if (resultType.hasStaticShape()) { + bufferArgs[2] = InsertAlloc(op->getLoc(), result, &rewriter); + } else { + SmallVector results_shape; + auto shape_type_op = dyn_cast(op); + if (failed(shape_type_op.reifyReturnTypeShapes(rewriter, results_shape))) + return failure(); + + bufferArgs[2] = InsertDynamicAllocAndDealloc( + op->getLoc(), result, results_shape.front(), &rewriter); + } + + rewriter.create(op->getLoc(), llvm::None, bufferArgs, + op->getAttrs()); + rewriter.replaceOp(op, bufferArgs[2]); + return success(); + } +}; + struct HloToLhloReduceOpConverter : public BaseOpConversion { public: using BaseOpConversion::BaseOpConversion; @@ -259,8 +345,7 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion { const auto& original_results = op.getResults(); SmallVector buffer_args(operands.begin(), operands.end()); for (auto result : original_results) { - buffer_args.push_back( - InsertAlloc(loc, result, this->bufferAssignment, &rewriter)); + buffer_args.push_back(InsertAlloc(loc, result, &rewriter)); } auto new_op = rewriter.create(loc, llvm::None, buffer_args, op.getAttrs()); @@ -290,11 +375,36 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion { } }; -// Legalize mhlo.return to a lmhlo.copy and lmhlo.terminator. This functionality -// is provided by mlir buffer assignment, so use the pattern from there. -// TODO(DFKI): Move this out of detail. -using HloToLhloReturnOpConverter = detail::BufferAssignmentReturnOpConverter< - mhlo::ReturnOp, lmhlo::TerminatorOp, lmhlo::CopyOp, false>; +// Legalize mhlo.return to a lmhlo.copy and lmhlo.terminator. 
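A note on the two dot lowerings above: mhlo.dot is the plain rank-2 product with implicit dimension semantics, so the specialization attaches explicit dimension numbers when emitting lmhlo.dot (empty batching dimensions, lhs contracting dimension 1, rhs contracting dimension 0). In math terms this is the ordinary matrix product

\[
([N, M],\, [M, O]) \to [N, O], \qquad C_{i,j} = \sum_{k=0}^{M-1} A_{i,k}\, B_{k,j},
\]

which is also the only case the affine lowering of lmhlo.dot accepts later in this change.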
+struct HloToLhloReturnOpConverter : public BaseOpConversion { + public: + using BaseOpConversion::BaseOpConversion; + + LogicalResult matchAndRewrite( + mhlo::ReturnOp op, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + auto loc = op.getLoc(); + auto& entry_block = op.getParentRegion()->front(); + auto num_arguments = entry_block.getNumArguments(); + if (operands.size() > num_arguments) { + return op.emitError( + "The number of operands that need Copy operations is more " + "than the number of target function arguments."); + } + + // The index of the first output block argument. + auto dest_arg_idx = num_arguments - operands.size(); + + // Create a lmhlo.copy for each operand of mhlo.return. + for (Value operand : operands) { + rewriter.create(loc, operand, + entry_block.getArgument(dest_arg_idx)); + ++dest_arg_idx; + } + rewriter.replaceOpWithNewOp(op); + return success(); + } +}; class HloToLhloTensorLoadOpConverter : public BaseOpConversion { @@ -388,6 +498,10 @@ class HloToLhloTensorStoreOpConverter struct HloLegalizeToLhlo : public PassWrapper> { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + public: HloLegalizeToLhlo() = default; HloLegalizeToLhlo(const HloLegalizeToLhlo& o) { @@ -410,7 +524,7 @@ struct HloLegalizeToLhlo target.addLegalOp(); target.addIllegalDialect(); - BufferAssignmentTypeConverter converter; + BufferizeTypeConverter converter; auto isMemRefType = [](Type type) { return type.isa(); }; target.addDynamicallyLegalOp([&](FuncOp op) { auto inputs = op.getType().getInputs(); @@ -427,29 +541,25 @@ struct HloLegalizeToLhlo return std::all_of(op.operand_type_begin(), op.operand_type_end(), isMemRefType); }); - - auto module = getOperation(); - WalkResult result = module.walk([&](FuncOp func) -> WalkResult { - BufferAssignmentPlacer bufferAssignment(func); - OwningRewritePatternList patterns; - populateHLOToLHLOConversionPattern(func.getContext(), &bufferAssignment, - &converter, &patterns); - if (results_escape_function) { - populateWithBufferAssignmentOpConversionPatterns< - mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp, - /*allowMemrefFunctionResults=*/true>(&context, &bufferAssignment, - &converter, &patterns); - } else { - populateWithBufferAssignmentOpConversionPatterns< - mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp, - /*allowMemrefFunctionResults=*/false>(&context, &bufferAssignment, - &converter, &patterns); - } - return applyPartialConversion(func, target, patterns); + target.addDynamicallyLegalOp([&](shape::AssumingOp op) { + return std::all_of(op.result_type_begin(), op.result_type_end(), + isMemRefType); }); - if (result.wasInterrupted()) { + + auto kind = results_escape_function + ? 
BufferizeTypeConverter::KeepAsFunctionResult + : BufferizeTypeConverter::AppendToArgumentsList; + converter.setResultConversionKind( + kind); + converter.setResultConversionKind(kind); + + populateHLOToLHLOConversionPattern(&context, &converter, &patterns); + populateWithBufferizeOpConversionPatterns( + &context, converter, patterns); + populateShapeTypeConversionPatterns(&context, converter, patterns); + if (failed(applyPartialConversion(getOperation(), target, patterns))) signalPassFailure(); - } } private: @@ -461,16 +571,18 @@ struct HloLegalizeToLhlo }; } // namespace -void populateHLOToLHLOConversionPattern( - MLIRContext* context, BufferAssignmentPlacer* bufferAssignment, - TypeConverter* converter, OwningRewritePatternList* patterns) { +void populateHLOToLHLOConversionPattern(MLIRContext* context, + BufferizeTypeConverter* converter, + OwningRewritePatternList* patterns) { // clang-format off patterns->insert< + HloToLhloDotGeneralOpConverter, HloToLhloDynamicBroadcastInDimOpConverter, HloToLhloDynamicReshapeConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, + HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, @@ -480,31 +592,38 @@ void populateHLOToLHLOConversionPattern( HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, + HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, + HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, + HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, + HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, + HloToLhloOpConverter, HloToLhloReduceOpConverter, HloToLhloReturnOpConverter, HloToLhloTensorLoadOpConverter, HloToLhloTensorStoreOpConverter - >(context, bufferAssignment, converter); + >(context, *converter); // clang-format on } diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_control_flow.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_control_flow.cc index b6e23a6b131..adf2a398a00 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_control_flow.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_control_flow.cc @@ -32,8 +32,6 @@ limitations under the License. #include "mlir/Pass/PassRegistry.h" #include "mlir/Support/LogicalResult.h" -using mlir::PassRegistration; - namespace mlir { namespace mhlo { namespace { diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_tanh_to_approximation.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_tanh_to_approximation.cc deleted file mode 100644 index 57c494f536b..00000000000 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_tanh_to_approximation.cc +++ /dev/null @@ -1,152 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file implements logic for lowering the tanh standard ops to an -// approximation. - -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/IR/Function.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { -namespace mhlo { -namespace { - -/// Emits the fast tanh approximation that is also used by XLA. -Value EmitTanhApproximation(Value input, Location loc, - PatternRewriter &rewriter) { - // For small values of x, we can approximate tanh(x)=x. For extremely small - // values of x (|x| < 1e-37), the other approximation would evaluate - // tanh(x) = 0. - constexpr float kCanUseApprox = 0.0004; - Value abs_value = rewriter.create(loc, input); - Value can_use_approx = - rewriter.create(loc, rewriter.getF32FloatAttr(kCanUseApprox)); - Value return_input = rewriter.create(loc, CmpFPredicate::OLT, - abs_value, can_use_approx); - // Clamp the input to [-c, c]. - Value max_clamp = rewriter.create( - loc, rewriter.getF32FloatAttr(7.90531110763549805f)); - Value smaller_than_max = - rewriter.create(loc, CmpFPredicate::ULE, input, max_clamp); - Value clamped_half = - rewriter.create(loc, smaller_than_max, input, max_clamp); - Value min_clamp = rewriter.create( - loc, rewriter.getF32FloatAttr(-7.90531110763549805f)); - Value larger_than_min = - rewriter.create(loc, CmpFPredicate::UGE, clamped_half, min_clamp); - Value input_clamped = - rewriter.create(loc, larger_than_min, clamped_half, min_clamp); - - static constexpr std::array numerator_coeffs{ - -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f, - 5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f, - 4.89352455891786e-03f}; - - static constexpr std::array denominator_coeffs{ - 1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f, - 4.89352518554385e-03f}; - - Value input_squared = - rewriter.create(loc, input_clamped, input_clamped); - Value numerator = rewriter.create( - loc, rewriter.getF32FloatAttr(numerator_coeffs[0])); - for (int i = 1; i < numerator_coeffs.size(); i++) { - numerator = rewriter.create( - loc, rewriter.create(loc, input_squared, numerator), - rewriter.create( - loc, rewriter.getF32FloatAttr(numerator_coeffs[i]))); - } - - numerator = rewriter.create(loc, input_clamped, numerator); - - Value denominator = rewriter.create( - loc, rewriter.getF32FloatAttr(denominator_coeffs[0])); - for (int i = 1; i < denominator_coeffs.size(); i++) { - denominator = rewriter.create( - loc, rewriter.create(loc, input_squared, denominator), - rewriter.create( - loc, rewriter.getF32FloatAttr(denominator_coeffs[i]))); - } - - Value approx = rewriter.create(loc, numerator, denominator); - - return rewriter.create(loc, return_input, input, approx); -} - -class ApproximateTanhLowering : public OpRewritePattern { - public: - explicit ApproximateTanhLowering(MLIRContext *ctx) - : OpRewritePattern(ctx, 100) {} - - LogicalResult matchAndRewrite(TanhOp tanhOp, - 
PatternRewriter &rewriter) const override { - Type operand_type = tanhOp.getType(); - - if (operand_type.isF64()) { - // Similar to XLA, do not rewrite f64 as precision might matter. - return failure(); - } - - Location loc = tanhOp.getLoc(); - Value input = tanhOp.operand(); - if (operand_type.isF16()) { - input = rewriter.create(loc, input, rewriter.getF32Type()); - } - - // If we still do not have f32, fail. - if (!input.getType().isF32()) { - return failure(); - } - - Value result = EmitTanhApproximation(input, loc, rewriter); - - // Truncate back if needed. - if (operand_type.isF16()) { - result = rewriter.create(loc, result, rewriter.getF16Type()); - } - - rewriter.replaceOp(tanhOp, {result}); - return success(); - } -}; - -struct LegalizeTanhToApproximationPass - : public PassWrapper { - /// Perform the lowering of standard dialect operations to approximations. - void runOnFunction() override { - OwningRewritePatternList patterns; - PopulateTanhToApproximationPatterns(&getContext(), &patterns); - applyPatternsAndFoldGreedily(getFunction(), patterns); - } -}; - -} // anonymous namespace - -std::unique_ptr> -createLegalizeTanhToApproximationPass() { - return std::make_unique(); -} - -void PopulateTanhToApproximationPatterns(mlir::MLIRContext *context, - OwningRewritePatternList *patterns) { - patterns->insert(context); -} - -} // namespace mhlo -} // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 033021c36ac..b64d66200cf 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "llvm/ADT/STLExtras.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" @@ -32,8 +33,10 @@ limitations under the License. 
#include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" @@ -75,69 +78,69 @@ class PointwiseToLinalgConverter : public OpConversionPattern { OpTy op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { auto loc = op.getLoc(); - auto argType = - op.getOperation()->getOperand(0).getType().template cast(); - if (!argType.hasRank()) { - emitError(loc, "lhlo to linalg conversion expects ranked args"); - return failure(); - } - auto elemTy = argType.getElementType(); - if (!elemTy.isSignlessIntOrFloat() && !elemTy.template isa()) { - return failure(); - } + ShapedType t0 = args[0].getType().template dyn_cast(); + if (!t0) return failure(); + + unsigned nloops = t0.getRank(); + auto fail = [&](ShapedType t) { + return !t || !t.hasRank() || t.getRank() != nloops || + !(t.getElementType().isSignlessIntOrFloat() || + t.getElementType().isa()); + }; + if (llvm::any_of(args, + [&](Value v) { + return fail(v.getType().dyn_cast()); + }) || + llvm::any_of(op.getOperation()->getResultTypes(), + [&](Type t) { return fail(t.dyn_cast()); })) + return emitError(loc, + "lhlo to linalg conversion expects ranked args of " + "signless int, float or complex element type with ") + << nloops << " parallel iterators: " << *(op.getOperation()); // Construct the indexing maps needed for linalg.generic ops. - SmallVector indexing_maps; SmallVector bodyArgTypes, bodyResultTypes, opResultTypes; // This doesnt account for implicit broadcast, but the working assumption - // here is that are broadcasts have been made explicit. - unsigned nloops = argType.getRank(); + // in HLO/LHLO is that are broadcasts are made explicit. if (isLHLO && !nloops) return failure(); - int operandCount = (isLHLO ? args.size() - 1 : args.size()); - auto verifyArgOrResultType = [&](Value val) -> ShapedType { - auto shapedType = val.getType().dyn_cast(); - if (!shapedType || - (!shapedType.isa() && - !shapedType.isa()) || - shapedType.getRank() != nloops) - return nullptr; - indexing_maps.emplace_back( - nloops ? rewriter.getMultiDimIdentityMap(nloops) - : AffineMap::get(nloops, 0, rewriter.getContext())); - return shapedType; - }; - for (const auto& arg : llvm::enumerate(args)) { - auto shapedType = verifyArgOrResultType(arg.value()); - if (!shapedType) return failure(); - auto& result_or_body_arg = - arg.index() < operandCount ? bodyArgTypes : bodyResultTypes; - result_or_body_arg.emplace_back(shapedType.getElementType()); - } + int numInputs = (isLHLO ? args.size() - 1 : args.size()); + + ValueRange inputs(args.take_front(numInputs)); + for (Value in : inputs) + bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType())); + + ValueRange outputBuffers(args.take_back(args.size() - numInputs)); + for (Value out : outputBuffers) + bodyResultTypes.emplace_back(getElementTypeOrSelf(out.getType())); + if (!isLHLO) { // HLO operations have return as tensor types. 
assert(bodyResultTypes.empty() && "When lowering HLO ops result can't be part of arguments"); Value result = op.getOperation()->getResult(0); - auto shapedType = verifyArgOrResultType(result); - if (!shapedType) return failure(); - bodyResultTypes.push_back(shapedType.getElementType()); - opResultTypes.push_back(shapedType); + bodyResultTypes.push_back(getElementTypeOrSelf(result)); + opResultTypes.push_back(result.getType()); } - int64_t args_count = bodyArgTypes.size(); - int64_t results_count = bodyResultTypes.size(); + AffineMap commonIndexingMap = + nloops ? rewriter.getMultiDimIdentityMap(nloops) + : AffineMap::get(nloops, 0, rewriter.getContext()); + SmallVector indexing_maps(args.size() + (isLHLO ? 0 : 1), + commonIndexingMap); + auto linalgOp = rewriter.create( - loc, opResultTypes, args, args_count, results_count, indexing_maps, + loc, opResultTypes, inputs, outputBuffers, + /*initTensors=*/ValueRange{}, indexing_maps, GetNParallelLoopsAttrs(nloops), [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { // TODO(ravishankarm) : For now use the method in lmhlo namespace. // That method needs to be moved out of there. Value opResult = lmhlo::HloOpToStdScalarOp::map( op, bodyResultTypes, - llvm::to_vector<2>(args.take_front(args_count)), &rewriter); + llvm::to_vector<2>(args.take_front(inputs.size())), &rewriter); nestedBuilder.create(loc, opResult); }); rewriter.replaceOp(op, linalgOp.getOperation()->getResults()); @@ -189,7 +192,7 @@ struct ConvToLinalgConverter : public OpConversionPattern { lmhlo::ConvOp op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { // Check validity of dimension information. - if (const lmhlo::ConvDimensionNumbers& dimensionNumbers = + if (const mhlo::ConvDimensionNumbers& dimensionNumbers = op.dimension_numbers()) { const int inputSpatialRank = llvm::size(dimensionNumbers.input_spatial_dimensions()); @@ -299,12 +302,15 @@ class DataMovementOpConverter : public OpConversionPattern { auto nloops = resultType.getRank(); auto loc = op.getLoc(); auto linalgOp = rewriter.create( - loc, isLHLO ? ArrayRef{} : resultType, args, /*argsIn=*/1, - /*argsOut=*/1, indexing_maps, GetNParallelLoopsAttrs(nloops), + loc, + /*resultTensorTypes=*/isLHLO ? ArrayRef{} : resultType, + /*inputs=*/args.front(), + /*outputBuffers=*/isLHLO ? 
ValueRange{args.back()} : ValueRange{}, + /*initTensor=*/ValueRange{}, indexing_maps, + GetNParallelLoopsAttrs(nloops), [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { nestedBuilder.create(loc, *args.begin()); }); - rewriter.replaceOp(op, linalgOp.getOperation()->getResults()); return success(); } @@ -420,8 +426,8 @@ class LhloBroadcastInDimConverter Value val = rewriter.create(loc, operand, llvm::makeArrayRef({zero})); rewriter.create( - loc, llvm::None, llvm::makeArrayRef(operand_adaptor.output()), - /*argsIn=*/0, /*argsOut=*/1, + loc, /*inputs=*/ValueRange{}, + /*outputBuffers=*/ValueRange{operand_adaptor.output()}, llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)), GetNParallelLoopsAttrs(nloops), [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { @@ -432,9 +438,8 @@ class LhloBroadcastInDimConverter auto indexing_maps = getIndexingMaps(op, broadcast_dims, result_shape, operand_type, &rewriter); rewriter.create( - loc, llvm::None, - llvm::makeArrayRef({operand, operand_adaptor.output()}), - /*argsIn=*/1, /*argsOut=*/1, indexing_maps, + loc, /*inputs=*/ValueRange{operand}, + /*outputBuffers=*/ValueRange{operand_adaptor.output()}, indexing_maps, GetNParallelLoopsAttrs(nloops), [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { nestedBuilder.create(loc, *args.begin()); @@ -627,7 +632,8 @@ class ReshapeOpConverter : public OpConversionPattern { } currDstDim++; } - if (currSrcDim != srcShape.size()) isExpandingOrCollapsing = false; + if (currSrcDim != srcShape.size() || currDstDim != dstShape.size()) + isExpandingOrCollapsing = false; if (!isExpandingOrCollapsing) { auto getIdentityExprs = [&rewriter](int n) { @@ -696,15 +702,18 @@ class IotaConverter : public OpConversionPattern { unsigned nloops = resultShapedType.getRank(); auto linalgOp = rewriter.create( - iotaOp.getLoc(), isLHLO ? ArrayRef{} : resultShapedType, args, - 0, // args_in - 1, // args_out + iotaOp.getLoc(), + /*resultTensorTypes=*/ + isLHLO ? ArrayRef{} : ArrayRef{resultShapedType}, + /*inputs=*/ValueRange{}, + /*outputBuffers=*/isLHLO ? 
ValueRange{args} : ValueRange{}, + /*initTensors=*/ValueRange{}, llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)), GetNParallelLoopsAttrs(nloops), [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange ivs, ValueRange args) { Value castOp = nestedBuilder.create( - nestedLoc, ivs[iotaOp.iota_dimension().getZExtValue()], + nestedLoc, ivs[iotaOp.iota_dimension()], nestedBuilder.getIntegerType( resultElementType.getIntOrFloatBitWidth())); if (resultElementType.template isa()) { @@ -716,7 +725,7 @@ class IotaConverter : public OpConversionPattern { if (isLHLO) rewriter.replaceOp(iotaOp, llvm::None); else - rewriter.replaceOp(iotaOp, linalgOp.output_tensors()); + rewriter.replaceOp(iotaOp, linalgOp.result_tensors()); return success(); } }; @@ -813,6 +822,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, @@ -822,12 +832,14 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, @@ -837,10 +849,12 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, ReshapeOpConverter, ReverseConverter, ScalarPointwiseToStandardConverter, - SliceConverter + SliceConverter, + TransposeConverter >(context); // clang-format on } @@ -859,13 +873,15 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, // %0 = addf %arg4, %arg5 : f32 // "linalg.yield"(%0) : (f32) -> () // }) { -// args_in = 2, -// args_out = 1, // indexing_maps = [#map0, #map0, #map0], // iterator_types = ["parallel", "parallel"], // } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () struct LhloLegalizeToLinalgPass : public PassWrapper { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + void runOnFunction() override { OwningRewritePatternList patterns; ConversionTarget target(getContext()); @@ -882,6 +898,10 @@ struct LhloLegalizeToLinalgPass struct HloLegalizeToLinalgPass : public PassWrapper { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + void runOnFunction() override { OwningRewritePatternList patterns; ConversionTarget target(getContext()); @@ -913,6 +933,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, @@ -921,12 +942,14 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, 
PointwiseToLinalgConverter, @@ -935,6 +958,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, ReshapeOpConverter, ReverseConverter, TransposeConverter>(context); diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc index cc574e008d5..84255c2810e 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc @@ -117,7 +117,7 @@ class ConvertIotaOp : public OpRewritePattern { PatternRewriter &rewriter) const override { auto output_type = op.getType().cast(); auto output_size = output_type.getNumElements(); - auto dimension = op.iota_dimension().getSExtValue(); + auto dimension = op.iota_dimension(); auto max_dim_size = output_type.getDimSize(dimension); auto element_type = output_type.getElementType(); @@ -178,6 +178,10 @@ class ConvertIotaOp : public OpRewritePattern { namespace { struct LegalizeToStandardPass : public PassWrapper { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + /// Perform the lowering to Standard dialect. void runOnFunction() override; }; @@ -189,7 +193,7 @@ std::unique_ptr> createLegalizeToStdPass() { void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns, mlir::MLIRContext *ctx) { - mlir::populateWithGenerated(ctx, patterns); + mlir::populateWithGenerated(ctx, *patterns); patterns->insert(ctx); } diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc new file mode 100644 index 00000000000..10030866d0f --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc @@ -0,0 +1,284 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements the lowering for trigonometric standard ops to +// approximations. 
+ +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace mhlo { +namespace { + +template +class ApproximateOnExtendedF32Lowering : public OpRewritePattern { + public: + explicit ApproximateOnExtendedF32Lowering(MLIRContext *ctx) + : OpRewritePattern(ctx, /*benefit=*/100) {} + + virtual Value emitApproximation(ValueRange, Location, + PatternRewriter &) const = 0; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto raw_args = op.getOperation()->getOperands(); + + // Supports only f16 and f32 for now. + if (!op.getType().isF16() && !op.getType().isF32()) return failure(); + + // Extend operands to f32 if needed and possible. + SmallVector f32_args; + f32_args.reserve(raw_args.size()); + for (Value arg : raw_args) { + // Similar to XLA, do not rewrite f64 as precision might matter. + Type arg_ty = arg.getType(); + if (arg_ty.isF64()) return failure(); + + if (arg_ty.isF16()) + arg = rewriter.create(loc, arg, rewriter.getF32Type()); + + // If we still do not have f32, fail. + if (!arg.getType().isF32()) return failure(); + + f32_args.push_back(arg); + } + + Value result = emitApproximation(f32_args, loc, rewriter); + assert(result.getType().isF32() && "Expect f32 intermediate result."); + + // Truncate back if needed. + if (op.getType().isF16()) + result = rewriter.create(loc, result, rewriter.getF16Type()); + + rewriter.replaceOp(op, {result}); + return success(); + } +}; + +class ApproximateTanhLowering + : public ApproximateOnExtendedF32Lowering { + public: + explicit ApproximateTanhLowering(MLIRContext *ctx) + : ApproximateOnExtendedF32Lowering(ctx) {} + + // Emits the fast tanh approximation that is also used by XLA. + Value emitApproximation(ValueRange args, Location loc, + PatternRewriter &rewriter) const override { + // For small values of x, we can approximate tanh(x) = x. For extremely + // small values of x (|x| < 1e-37), the other approximation would evaluate + // tanh(x) = 0. + Value input = args.front(); + assert(input.getType().isF32()); + constexpr float kCanUseApprox = 0.0004; + Value abs_value = rewriter.create(loc, input); + Value can_use_approx = rewriter.create( + loc, rewriter.getF32FloatAttr(kCanUseApprox)); + Value return_input = rewriter.create(loc, CmpFPredicate::OLT, + abs_value, can_use_approx); + // Clamp the input to [-c, c]. 
+ Value max_clamp = rewriter.create( + loc, rewriter.getF32FloatAttr(7.90531110763549805f)); + Value smaller_than_max = + rewriter.create(loc, CmpFPredicate::ULE, input, max_clamp); + Value clamped_half = + rewriter.create(loc, smaller_than_max, input, max_clamp); + Value min_clamp = rewriter.create( + loc, rewriter.getF32FloatAttr(-7.90531110763549805f)); + Value larger_than_min = rewriter.create(loc, CmpFPredicate::UGE, + clamped_half, min_clamp); + Value input_clamped = rewriter.create(loc, larger_than_min, + clamped_half, min_clamp); + + static constexpr std::array numerator_coeffs{ + -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f, + 5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f, + 4.89352455891786e-03f}; + + static constexpr std::array denominator_coeffs{ + 1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f, + 4.89352518554385e-03f}; + + Value input_squared = + rewriter.create(loc, input_clamped, input_clamped); + Value numerator = rewriter.create( + loc, rewriter.getF32FloatAttr(numerator_coeffs[0])); + for (int i = 1; i < numerator_coeffs.size(); i++) { + numerator = rewriter.create( + loc, rewriter.create(loc, input_squared, numerator), + rewriter.create( + loc, rewriter.getF32FloatAttr(numerator_coeffs[i]))); + } + + numerator = rewriter.create(loc, input_clamped, numerator); + + Value denominator = rewriter.create( + loc, rewriter.getF32FloatAttr(denominator_coeffs[0])); + for (int i = 1; i < denominator_coeffs.size(); i++) { + denominator = rewriter.create( + loc, rewriter.create(loc, input_squared, denominator), + rewriter.create( + loc, rewriter.getF32FloatAttr(denominator_coeffs[i]))); + } + + Value approx = rewriter.create(loc, numerator, denominator); + + return rewriter.create(loc, return_input, input, approx); + } +}; + +class ApproximateAtan2Lowering + : public ApproximateOnExtendedF32Lowering { + public: + explicit ApproximateAtan2Lowering(MLIRContext *ctx) + : ApproximateOnExtendedF32Lowering(ctx) {} + + // Reduces atan2 to atan in the same way XLA does it. + Value emitApproximation(ValueRange args, Location loc, + PatternRewriter &rewriter) const override { + Value y = args[0]; + Value x = args[1]; + assert(x.getType().isF32() && y.getType().isF32() && + "expect f32 arguments"); + Value ax = rewriter.create(loc, x); + Value ay = rewriter.create(loc, y); + Value le_ax_ay = rewriter.create(loc, CmpFPredicate::OLE, ax, ay); + Value min_ax_ay = rewriter.create(loc, le_ax_ay, ax, ay); + Value max_ax_ay = rewriter.create(loc, le_ax_ay, ay, ax); + Value zero_to_one = rewriter.create(loc, min_ax_ay, max_ax_ay); + Value a = emitAtanCoreApproximation(zero_to_one, loc, rewriter); + + Value pi_over_2 = + rewriter.create(loc, rewriter.getF32FloatAttr(1.57079637f)); + a = rewriter.create( + loc, le_ax_ay, rewriter.create(loc, pi_over_2, a), a); + + Value zero = rewriter.create(loc, rewriter.getF32FloatAttr(0)); + Value lt_x_0 = rewriter.create(loc, CmpFPredicate::OLT, x, zero); + Value pi = + rewriter.create(loc, rewriter.getF32FloatAttr(3.14159274f)); + a = rewriter.create(loc, lt_x_0, + rewriter.create(loc, pi, a), a); + + Value t = rewriter.create(loc, lt_x_0, pi, zero); + Value eq_y_0 = rewriter.create(loc, CmpFPredicate::OEQ, y, zero); + a = rewriter.create(loc, eq_y_0, t, a); + + // Propagate nan. 
+ Value is_nan = rewriter.create(loc, CmpFPredicate::UNO, y, x); + Value nan = rewriter.create( + loc, rewriter.getF32FloatAttr(std::numeric_limits::quiet_NaN())); + a = rewriter.create(loc, is_nan, nan, a); + + // x and y are +- inf. + Value three_pi_over_4 = + rewriter.create(loc, rewriter.getF32FloatAttr(2.3561945f)); + Value pi_over_4 = rewriter.create( + loc, rewriter.getF32FloatAttr(0.785398185f)); + t = rewriter.create(loc, lt_x_0, three_pi_over_4, + pi_over_4); + Value inf = rewriter.create( + loc, rewriter.getF32FloatAttr(std::numeric_limits::infinity())); + Value eq_x_inf = rewriter.create(loc, CmpFPredicate::OEQ, x, inf); + Value eq_y_inf = rewriter.create(loc, CmpFPredicate::OEQ, y, inf); + Value all_inf = rewriter.create(loc, eq_x_inf, eq_y_inf); + a = rewriter.create(loc, all_inf, t, a); + + return rewriter.create(loc, a, y); + } + + private: + // The core atan reduction derives from the heuristic described in + // https://arxiv.org/abs/1508.03211 and has a < 0.95 ulp error in the [-1, 1] + // range (though that assumed FMA was available, and it is not here). This is + // the same approximation that is also used by XLA. + Value emitAtanCoreApproximation(Value x, Location loc, + PatternRewriter &rewriter) const { + auto constant = [&](float c) { + return rewriter.create(loc, rewriter.getF32FloatAttr(c)); + }; + + // Computes ab + c. + auto mul_add = [&](Value a, Value b, Value c) { + Value prod = rewriter.create(loc, a, b); + return rewriter.create(loc, prod, c); + }; + + Value s = rewriter.create(loc, x, x); + Value r = constant(0.0027856871f); + r = mul_add(r, s, constant(-0.0158660002f)); + r = mul_add(r, s, constant(0.042472221f)); + r = mul_add(r, s, constant(-0.0749753043f)); + r = mul_add(r, s, constant(0.106448799f)); + r = mul_add(r, s, constant(-0.142070308f)); + r = mul_add(r, s, constant(0.199934542f)); + r = mul_add(r, s, constant(-0.333331466f)); + r = rewriter.create(loc, r, s); + return mul_add(r, x, x); + } +}; + +class ApproximateAtanLowering + : public ApproximateOnExtendedF32Lowering { + public: + explicit ApproximateAtanLowering(MLIRContext *ctx) + : ApproximateOnExtendedF32Lowering(ctx) {} + + // Reduce atan(x) to atan2(x, 1) to subsequently rely on an atan approximation + // for the argument range [-1, 1]. + Value emitApproximation(ValueRange args, Location loc, + PatternRewriter &rewriter) const override { + Value x = args.front(); + assert(x.getType().isF32()); + Value one = rewriter.create(loc, rewriter.getF32FloatAttr(1)); + return rewriter.create(loc, x, one); + } +}; + +struct LegalizeTrigonometricToApproximationPass + : public PassWrapper { + /// Perform the lowering of standard dialect operations to approximations. 
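A rough summary of the approximations defined above (thresholds and coefficient arrays as in the code): tanh is a rational approximation in the clamped argument, atan is reduced to atan2, and atan2 is reduced to a core atan on [0, 1] plus quadrant and special-case fix-ups:

\[
\tanh(x) \approx
\begin{cases}
x, & |x| < 4\cdot 10^{-4},\\[2pt]
\dfrac{\tilde{x}\, P(\tilde{x}^{2})}{Q(\tilde{x}^{2})}, & \text{otherwise},
\end{cases}
\qquad \tilde{x} = \operatorname{clamp}\bigl(x,\, -7.9053111\ldots,\, 7.9053111\ldots\bigr),
\]

with \(P\) of degree 6 and \(Q\) of degree 3 in \(\tilde{x}^{2}\), evaluated in Horner form with the coefficients listed above, and

\[
\operatorname{atan}(x) = \operatorname{atan2}(x,\, 1), \qquad
a = \operatorname{atan}_{\mathrm{core}}\!\Bigl(\tfrac{\min(|x|,\,|y|)}{\max(|x|,\,|y|)}\Bigr),
\]

followed by \(a \mapsto \pi/2 - a\) when \(|x| \le |y|\), \(a \mapsto \pi - a\) when \(x < 0\), special cases for \(y = 0\), NaN, and \(\pm\infty\), and finally copying the sign of \(y\) onto the result.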
+ void runOnFunction() override { + OwningRewritePatternList patterns; + PopulateTrigonometricToApproximationPatterns(&getContext(), &patterns); + applyPatternsAndFoldGreedily(getFunction(), patterns); + } +}; + +} // anonymous namespace + +std::unique_ptr> +createLegalizeTrigonometricToApproximationPass() { + return std::make_unique(); +} + +void PopulateTrigonometricToApproximationPatterns( + mlir::MLIRContext *context, OwningRewritePatternList *patterns) { + // clang-format off + patterns->insert< + ApproximateAtanLowering, + ApproximateAtan2Lowering, + ApproximateTanhLowering>(context); + // clang-format on +} + +} // namespace mhlo +} // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_copy_removal.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_copy_removal.cc deleted file mode 100644 index 7a4418466b5..00000000000 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_copy_removal.cc +++ /dev/null @@ -1,102 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file implements a pass to remove redundant LHLO copy operations. - -#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/IR/Operation.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { -namespace lmhlo { -namespace { - -// Removes LHLO copy operations that copy from allocated buffers to block -// arguments. All uses of each buffer are replaced with the corresponding block -// argument and the buffer is freed. Note that this pass only works in regions -// with a single block. -struct LhloCopyRemovalPass - : mlir::PassWrapper> { - void runOnOperation() override { - llvm::SmallVector eraseList; - auto operation = getOperation(); - operation->walk([&](mlir::lmhlo::CopyOp copyOp) { - // If this region contains more than one block, then ignore this copy - // operation. - if (copyOp.getParentRegion()->getBlocks().size() > 1) { - return; - } - - mlir::Value fromOperand = copyOp.operand(); - mlir::Value toOperand = copyOp.output(); - - // If the fromOperand value is a block argument or the toOperand - // value is not a block argument, then ignore this copy operation. - if (!fromOperand.getDefiningOp() || toOperand.getDefiningOp()) { - return; - } - - // The copy operation removal is illegal if there is at least a single use - // of toOperand value that lies between the first use of fromOperand value - // and the copy operation. 
- auto fromOperandUsers = fromOperand.getUsers(); - auto firstUser = *fromOperandUsers.begin(); - for (auto op : fromOperandUsers) { - if (op->isBeforeInBlock(firstUser)) firstUser = op; - } - for (auto op : toOperand.getUsers()) { - if (op->isBeforeInBlock(copyOp) && firstUser->isBeforeInBlock(op)) { - return; - } - } - - // TODO(DFKI): Use live variable analysis to solve aliasing issues among - // block arguments. - - // Remove the associated alloc operation. - auto allocOp = fromOperand.getDefiningOp(); - eraseList.push_back(allocOp); - - // Iterate over all uses of the fromOperand to find the associated - // deallocOp (if any). - for (auto op : fromOperandUsers) { - if (isa(op)) { - eraseList.push_back(op); - break; - } - } - - // Replace all uses of the fromOperand with the toOperand. This rewires - // all references pointing to the original alloc operation to the new - // target operation in order to safely remove the copy op. - fromOperand.replaceAllUsesWith(toOperand); - copyOp.erase(); - }); - for (auto op : eraseList) { - op->erase(); - } - }; -}; - -} // namespace - -std::unique_ptr createLhloCopyRemovalPass() { - return std::make_unique(); -} - -} // namespace lmhlo -} // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc index 1467f015dc9..8f50ad0667f 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc @@ -19,9 +19,12 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Interfaces/ViewLikeInterface.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/FoldUtils.h" @@ -33,6 +36,10 @@ using linalg::LinalgOp; class LhloFuseLinalgPass : public PassWrapper { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + public: LhloFuseLinalgPass() = default; LhloFuseLinalgPass(const LhloFuseLinalgPass&) {} @@ -67,6 +74,24 @@ class LhloFuseLinalgPass result_buffers.insert(operand); } } + // Resolve aliasing operations (like casts) on the result to identify + // results. This only handles escaping results. + // TODO(herhut): Use BufferizeAliasAnalysis for this. 
+ llvm::SmallVector worklist(result_buffers.begin(), + result_buffers.end()); + while (!worklist.empty()) { + Value result = worklist.pop_back_val(); + auto definingOp = result.getDefiningOp(); + if (!definingOp) { + continue; + } + if (auto viewLike = dyn_cast(definingOp)) { + auto alias = viewLike.getViewSource(); + if (result_buffers.insert(alias).second) { + worklist.push_back(alias); + } + } + } MLIRContext* ctx = func.getContext(); OpBuilder b(func); OperationFolder folder(ctx); diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc index 07891327775..2041d22c62b 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc @@ -59,6 +59,20 @@ struct DotOpConverter : public OpRewritePattern { return failure(); } + // We don't currently support batching dimensions, or multiple contraction + // dimensions. + mhlo::DotDimensionNumbers dot_dimension_numbers = + op.dot_dimension_numbers(); + if (dot_dimension_numbers.lhs_batching_dimensions().size() > 0 || + dot_dimension_numbers.rhs_batching_dimensions().size() > 0) + return failure(); + if (dot_dimension_numbers.lhs_contracting_dimensions().size() != 1 || + *dot_dimension_numbers.lhs_contracting_dimensions().begin() != 1 || + dot_dimension_numbers.rhs_contracting_dimensions().size() != 1 || + *dot_dimension_numbers.rhs_contracting_dimensions().begin() != 0) { + return failure(); + } + LogicalResult map_status = success(); auto body_builder = [&](OpBuilder& builder, Location loc, ValueRange ivs) { SmallVector lhs_indices{ivs[0], ivs[2]}, @@ -139,6 +153,9 @@ void populateLHLOToAffineConversionPattern(MLIRContext* context, struct LhloLegalizeToAffinePass : public PassWrapper { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } void runOnFunction() override { OwningRewritePatternList patterns; auto func = getFunction(); diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc index cffb58b37de..fbade8f7387 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc @@ -20,8 +20,10 @@ limitations under the License. 
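The dimension-number guard added to DotOpConverter above means the affine lowering now only fires on a plain matrix-matrix contraction. A hedged sketch of the accepted form (attribute spelling and shapes are illustrative, not taken from the diff):

// No batching dimensions; lhs contracts dimension 1 against rhs dimension 0.
"lmhlo.dot"(%lhs, %rhs, %out) {
    dot_dimension_numbers = {
      lhs_batching_dimensions = dense<> : tensor<0xi64>,
      rhs_batching_dimensions = dense<> : tensor<0xi64>,
      lhs_contracting_dimensions = dense<1> : tensor<1xi64>,
      rhs_contracting_dimensions = dense<0> : tensor<1xi64>}
  } : (memref<4x8xf32>, memref<8x16xf32>, memref<4x16xf32>) -> ()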
#include "llvm/ADT/ArrayRef.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" @@ -169,6 +171,11 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern { struct LhloLegalizeToGpuPass : public PassWrapper { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + void runOnFunction() override { OwningRewritePatternList patterns; ConversionTarget target(getContext()); diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc index 42b71543543..57ea947c473 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc @@ -45,7 +45,7 @@ struct StaticMemRefCastOpConverter return failure(); // Create descriptor. auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); - Type llvmTargetElementTy = desc.getElementType(); + Type llvmTargetElementTy = desc.getElementPtrType(); // Set allocated ptr. Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); allocated = @@ -96,7 +96,7 @@ struct DynamicMemRefCastOpConverter return failure(); // Create descriptor. auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); - Type llvmTargetElementTy = desc.getElementType(); + Type llvmTargetElementTy = desc.getElementPtrType(); // Set allocated ptr. 
Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); allocated = diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc index 8493a1feb5d..3d49027bb50 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc @@ -29,6 +29,10 @@ namespace { class TestLhloToLLVMPass : public ::mlir::PassWrapper> { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + public: void runOnOperation() override { ModuleOp m = getOperation(); diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc index 19f47d08c0d..d9a2d993496 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc @@ -691,6 +691,10 @@ class SelectAndScatterOpConverter struct LhloLegalizeToParallelLoopsPass : public PassWrapper { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + void runOnFunction() override { auto func = getFunction(); diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_complex.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_complex.cc index 9f7c946577d..491f1c01cf7 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_complex.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_complex.cc @@ -37,7 +37,6 @@ limitations under the License. using mlir::FunctionPass; using mlir::OwningRewritePatternList; -using mlir::PassRegistration; using mlir::PassWrapper; namespace { @@ -60,7 +59,7 @@ namespace { void PopulateComplexLoweringPatterns(MLIRContext* context, OwningRewritePatternList* patterns) { - populateWithGenerated(context, patterns); + populateWithGenerated(context, *patterns); } } // end namespace mhlo } // end namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_general_dot.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_general_dot.cc index 2bbd4691f95..ada30a289a4 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_general_dot.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_general_dot.cc @@ -38,7 +38,6 @@ using mlir::LogicalResult; using mlir::MLIRContext; using mlir::OpRewritePattern; using mlir::OwningRewritePatternList; -using mlir::PassRegistration; using mlir::PassWrapper; using mlir::PatternRewriter; using mlir::RankedTensorType; @@ -155,9 +154,16 @@ struct GeneralDotConvert : public OpRewritePattern { dot_numbers.rhs_contracting_dimensions(), /*outer_dims_first=*/false, &rewriter); + // Accept only static shaped types. + auto lhs_shape_type = lhs.getType().dyn_cast_or_null(); + auto rhs_shape_type = rhs.getType().dyn_cast_or_null(); + if (!lhs_shape_type || !rhs_shape_type) return failure(); + if (!lhs_shape_type.hasStaticShape() || !rhs_shape_type.hasStaticShape()) + return failure(); + // Dot resulting shape. 
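The static-shape check added above (the hunk continues below) restricts GeneralDotConvert to operands whose post-transpose/reshape types are ranked and fully static; roughly:

// Accepted:  tensor<4x8xf32> x tensor<8x16xf32> -> tensor<4x16xf32>
// Rejected (the pattern returns failure()): any unranked or dynamically
// shaped operand, e.g. tensor<?x8xf32> or tensor<*xf32>.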
- auto lhs_shape = lhs.getType().cast().getShape(); - auto rhs_shape = rhs.getType().cast().getShape(); + auto lhs_shape = lhs_shape_type.getShape(); + auto rhs_shape = rhs_shape_type.getShape(); auto new_dot_type = RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type); diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc new file mode 100644 index 00000000000..dba3cab6956 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc @@ -0,0 +1,199 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "llvm/Support/Casting.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" + +#define DEBUG_TYPE "mhlo-control-flow-to-scf" + +namespace mlir { +namespace mhlo { + +namespace { + +/// Convert MHLO While to SCF. +void MatchAndRewrite(WhileOp whileOp); + +/// Pass that converts MHLO control flow to SCF. +class ControlFlowToScfPass + : public mlir::PassWrapper { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + void runOnFunction() override { + getFunction().walk([&](WhileOp whileOp) { MatchAndRewrite(whileOp); }); + } +}; + +// TODO(jpienaar): Look into reformulating as a pattern. +void MatchAndRewrite(WhileOp whileOp) { + // Handle pattern: + // x = start + // step = ... + // limit = ... + // while (x < limit) { ... x += step; } + + // Only handling multi value while loops at the moment. + auto tupleOp = whileOp.getOperand().getDefiningOp(); + if (!tupleOp) return; + auto bodyReturn = whileOp.body() + .front() + .getTerminator() + ->getOperand(0) + .getDefiningOp(); + // Note: due to the shape restrictions on While, if the operand to While is a + // tuple, then so is the return type of the body. But the verifier isn't + // checking that at the moment, so just bail out here if this doesn't hold. + if (!bodyReturn) return; + + Value result = whileOp.cond().front().getTerminator()->getOperand(0); + // TODO(jpienaar): Expand to handle more than simple case with LT compare and + // constant step. + auto cmp = result.getDefiningOp(); + if (!cmp || cmp.comparison_direction() != "LT") return; + + const int kConstant = -1; + auto getValueAndIndex = [&](Value val) -> std::pair { + if (matchPattern(val, m_Constant())) return {val, kConstant}; + // If it is defined by a tuple, then the tuple has to have been fed in and + // the external value is captured. 
+ if (auto gte = val.getDefiningOp()) { + if (!gte.getOperand().isa()) return {nullptr, 0}; + int index = gte.index(); + return {tupleOp.getOperand(index), index}; + } + return {nullptr, 0}; + }; + + using ValueIndex = std::pair; + ValueIndex loopIndVar = getValueAndIndex(cmp.lhs()); + ValueIndex max = getValueAndIndex(cmp.rhs()); + if (!loopIndVar.first || !max.first) return; + auto add = + bodyReturn.getOperand(loopIndVar.second).getDefiningOp(); + if (!add) return; + ValueIndex step = getValueAndIndex(add.rhs()); + if (step.second != kConstant || !step.first) return; + + // Only handle case where tuple isn't propagated as is for now. + // TODO(jpienaar): Remove this when a tuple is also created inside the loop + // to propagate. + for (auto* use : whileOp.body().front().getArgument(0).getUsers()) + if (!isa(use)) return; + + LLVM_DEBUG(llvm::dbgs() << "Found for (" << whileOp.getLoc() << "):\n"; + llvm::dbgs() << " loopIndVar = " << loopIndVar.second << " max = " + << max.second << " step = " << step.second << "\n"; + llvm::dbgs() << " loopIndVar = " << loopIndVar.first << " max = " + << max.first << " step = " << step.first << "\n";); + OpBuilder b(whileOp); + // Inputs to new for loop. + llvm::SmallVector input; + input.reserve(tupleOp.getNumOperands()); + for (auto r : tupleOp.getOperands().take_front(loopIndVar.second)) + input.push_back(r); + for (auto r : tupleOp.getOperands().drop_front(loopIndVar.second + 1)) + input.push_back(r); + + auto tensorIndexType = RankedTensorType::get({}, b.getIndexType()); + auto getAsIndex = [&](Value val) { + auto loc = whileOp.getLoc(); + return b.create( + loc, b.create(loc, tensorIndexType, val), ValueRange()); + }; + + // SCF for uses index type, so converted these. + auto forloopIndVar = getAsIndex(loopIndVar.first); + auto forMax = getAsIndex(max.first); + auto forStep = getAsIndex(step.first); + auto forOp = b.create(whileOp.getLoc(), forloopIndVar, + forMax, forStep, input); + // Transfer the body without the block arguments. + forOp.getLoopBody().front().getOperations().splice( + forOp.getLoopBody().front().getOperations().end(), + whileOp.body().front().getOperations()); + + b.setInsertionPointToStart(&forOp.getLoopBody().front()); + auto loopIndVarElType = + loopIndVar.first.getType().cast().getElementType(); + Value indVar = b.create( + whileOp.getLoc(), RankedTensorType::get({}, loopIndVarElType), + b.create(whileOp.getLoc(), loopIndVarElType, + forOp.getInductionVar())); + // Update all block argument users to the SCF For args. + for (auto* use : + llvm::make_early_inc_range(whileOp.body().getArgument(0).getUsers())) { + // TODO(jpienaar): Expand here too when we allow using the tuple in the + // loop. + auto gte = cast(use); + // If the loop induction var, then refer to the loop induction variable as + // this operand is not updated. + if (gte.index() == loopIndVar.second) { + use->getResult(0).replaceAllUsesWith(indVar); + use->erase(); + continue; + } + int index = gte.index(); + // If after the loop induction variable, then decrement as we don't include + // the loop induction variable in the for iter operands. + if (index > loopIndVar.second) --index; + use->getResult(0).replaceAllUsesWith(forOp.getIterOperands()[index]); + use->erase(); + } + + // Create new yield op without induction var update. 
+ SmallVector newYieldOps; + newYieldOps.reserve(bodyReturn.getNumOperands() - 1); + for (auto r : bodyReturn.getOperands().take_front(loopIndVar.second)) + newYieldOps.push_back(r); + for (auto r : bodyReturn.getOperands().drop_front(loopIndVar.second + 1)) + newYieldOps.push_back(r); + // Delete return & tuple op. + forOp.getLoopBody().front().back().erase(); + forOp.getLoopBody().front().back().erase(); + b.setInsertionPointToEnd(&forOp.getLoopBody().front()); + b.create(whileOp.getLoc(), newYieldOps); + + // Recombine output tuple with max value of induction variable. + llvm::SmallVector loopOut; + loopOut.reserve(forOp.getNumResults() + 1); + for (auto r : forOp.getResults().take_front(loopIndVar.second)) + loopOut.push_back(r); + loopOut.push_back(max.first); + for (auto r : forOp.getResults().drop_front(loopIndVar.second)) + loopOut.push_back(r); + b.setInsertionPoint(whileOp); + auto newRes = b.create(whileOp.getLoc(), loopOut); + whileOp.replaceAllUsesWith(newRes.getOperation()); + whileOp.erase(); +} + +} // anonymous namespace + +std::unique_ptr> createControlFlowToScfPass() { + return std::make_unique(); +} + +} // namespace mhlo +} // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc index 32a846e79ef..febd4423bf2 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc @@ -24,7 +24,6 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" using mlir::FunctionPass; -using mlir::PassRegistration; using mlir::PassWrapper; namespace { diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc index 8d677f45c19..d863d825bcb 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc @@ -16,6 +16,7 @@ limitations under the License. #include "llvm/ADT/DenseMap.h" #include "llvm/Support/Casting.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Operation.h" #include "mlir/Pass/Pass.h" @@ -39,7 +40,8 @@ namespace { // those within internally. Note that doing so is the only option in case of // values defined outside that are BlockArguments of any of the parent region. class SinkConstantsToControlFlowPass - : public mlir::PassWrapper { + : public SinkConstantsToControlFlowPassBase< + SinkConstantsToControlFlowPass> { void runOnFunction() override { getFunction().walk([](Operation* op) { if (auto while_op = llvm::dyn_cast(op)) { diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc index 7c985ea7535..7c01fa22372 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc @@ -14,6 +14,7 @@ limitations under the License. 
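In schematic form, the loop shape matched by the mhlo_control_flow_to_scf.cc pass added above, and the loop it produces, look roughly as follows (a hedged sketch with region bodies elided and illustrative names):

// Matched input: an mhlo.while over a tuple whose condition compares tuple
// element 0 against a limit with "LT" and whose body adds a constant step
// to that element.
//   %init  = "mhlo.tuple"(%start, %accum) : (...) -> tuple<...>
//   %while = "mhlo.while"(%init) ({ /* cond: get_tuple_element 0 < %limit */ },
//                                 { /* body: element 0 += %step */ }) : ...
// Produced output: an scf.for over index values; start, limit and step are
// converted through index_cast and extract_element, the remaining tuple
// elements become iter_args, and get_tuple_element uses are rewired.
//   %res = scf.for %iv = %lb to %ub step %st iter_args(%acc = %accum) {
//     ...
//     scf.yield %new_acc
//   }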
==============================================================================*/ +#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir/Dialect/Shape/IR/Shape.h" @@ -27,7 +28,6 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" namespace mlir { -namespace mhlo { namespace { // TODO(herhut): Generate these out of op definitions. @@ -46,106 +46,81 @@ namespace { sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp) \ sep fn(ShiftRightLogicalOp) sep fn(SubOp) +// TODO(herhut): Generate these out of op definitions. +#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \ + fn(AcosOp) sep fn(AtanOp) sep fn(SinhOp) sep fn(TanOp) + template inline void AddLegalOpOnRankedTensor(ConversionTarget *target) { target->addDynamicallyLegalOp([](OpTy op) { - return llvm::all_of((op.getOperation())->getOperandTypes(), + return llvm::all_of(op.getOperation()->getOperandTypes(), [&](Type t) { return t.isa(); }); }); } -/// Unary element-wise operations on unranked tensors can be applied to the -/// flattened tensor with the same effect. -/// This pattern rewrites every such operation to +/// Element-wise operations on unranked tensors can be applied to the flattened +/// tensor operands with the same effect. This pattern rewrites every such +/// operation to /// (i) flatten the input tensor, -/// (ii) apply the unary operation, and +/// (ii) apply the operation, and /// (iii) restore the original shape. template -struct UnaryElementwiseOpConversion : public OpRewritePattern { - explicit UnaryElementwiseOpConversion(MLIRContext *context) +struct ElementwiseOpConversion : public OpRewritePattern { + explicit ElementwiseOpConversion(MLIRContext *context) : OpRewritePattern(context) {} LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { - // Don't apply conversion to ops with statically shaped operands. - Value operand = op.getOperand(); - auto operandTy = operand.getType().dyn_cast(); - if (operandTy.hasRank()) return failure(); - - // Generate IR to flatten the operand. - auto loc = op.getLoc(); - Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext()); - Value shape = - rewriter.create(loc, extentTensorTy, operand); - Type indexTy = rewriter.getIndexType(); - Value numElements = - rewriter.create(loc, indexTy, shape); - Value flatShape = rewriter.create(loc, numElements); - auto flatTensorTy = RankedTensorType::get({ShapedType::kDynamicSize}, - operandTy.getElementType()); - Value flatOperand = rewriter.create( - loc, flatTensorTy, operand, flatShape); - - // Generate IR for the actual operation. - Value flatResult = rewriter.create(loc, flatTensorTy, flatOperand); - - // Generate IR to restore the original shape. - rewriter.replaceOpWithNewOp(op, operandTy, - flatResult, shape); - - return success(); - } -}; - -/// Binary element-wise operation on unranked tensors can be applied to the -/// flattened operand tensors with the same effect. -/// This pattern rewrites every such operation to -/// (i) flatten the operand tensors, -/// (ii) apply the binary operation, and -// (iii) restore the original shape. -template -struct BinaryElementwiseOpConversion : public OpRewritePattern { - explicit BinaryElementwiseOpConversion(MLIRContext *context) - : OpRewritePattern(context) {} - - LogicalResult matchAndRewrite(OpTy op, - PatternRewriter &rewriter) const override { - // Don't apply conversion unless both operands are unranked. 
- if (op.lhs().getType().template isa() || - op.rhs().getType().template isa()) { + // Don't apply conversion unless all operands are unranked. + if (!llvm::all_of(op.getOperation()->getOperands(), [&](Value operand) { + return operand.getType().isa(); + })) { return failure(); } - // Flatten operands. + // Get operands' shape. auto loc = op.getLoc(); Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext()); - Value shapeLhs = - rewriter.create(loc, extentTensorTy, op.lhs()); - Value shapeRhs = - rewriter.create(loc, extentTensorTy, op.rhs()); - Value shape = rewriter.create(loc, extentTensorTy, - ValueRange{shapeLhs, shapeRhs}); + SmallVector operandShapes; + for (Value operand : op.getOperation()->getOperands()) { + Value shape = + rewriter.create(loc, extentTensorTy, operand); + operandShapes.push_back(shape); + } + Value shape = + operandShapes.size() == 1 + ? operandShapes.front() + : rewriter.create(loc, extentTensorTy, operandShapes); + + // Derive flat shape. Type indexTy = rewriter.getIndexType(); Value numElements = rewriter.create(loc, indexTy, shape); Value flatShape = rewriter.create(loc, numElements); - TensorType lhsTy = op.lhs().getType().template cast(); - Type flatLhsTy = RankedTensorType::get({ShapedType::kDynamicSize}, - lhsTy.getElementType()); - Value flatLhs = - rewriter.create(loc, flatLhsTy, op.lhs(), flatShape); - TensorType rhsTy = op.rhs().getType().template cast(); - Type flatRhsTy = RankedTensorType::get({ShapedType::kDynamicSize}, - rhsTy.getElementType()); - Value flatRhs = - rewriter.create(loc, flatRhsTy, op.rhs(), flatShape); - // Apply actual operation to flattened operands. - Value flatResult = rewriter.create(loc, flatLhs, flatRhs); + // Flatten operands. + SmallVector flatOperands; + for (Value operand : op.getOperation()->getOperands()) { + Type operandElementTy = + operand.getType().template cast().getElementType(); + Type flatTy = + RankedTensorType::get({ShapedType::kDynamicSize}, operandElementTy); + Value flat = rewriter.create(loc, flatTy, operand, + flatShape); + flatOperands.push_back(flat); + } + + // Apply operation to flattened operands. + Type resultElementTy = + op.getType().template cast().getElementType(); + Type flatResultTy = + RankedTensorType::get({ShapedType::kDynamicSize}, resultElementTy); + Value flatResult = + rewriter.create(loc, flatResultTy, flatOperands, op.getAttrs()); // Restore original shape. - rewriter.replaceOpWithNewOp(op, op.getType(), flatResult, - shape); + rewriter.replaceOpWithNewOp(op, op.getType(), + flatResult, shape); return success(); } @@ -153,17 +128,26 @@ struct BinaryElementwiseOpConversion : public OpRewritePattern { struct TransformUnrankedHloPass : public PassWrapper { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnFunction() override { // Setup conversion target. 
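As an illustration of the generalized ElementwiseOpConversion above, an unranked unary CHLO op is flattened, applied, and reshaped back, roughly as in this sketch (op name and element type are assumptions):

// (i) Flatten the unranked operand to a rank-1 tensor.
%shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
%n = shape.num_elements %shape : tensor<?xindex> -> index
%flat_shape = tensor_from_elements %n : tensor<1xindex>
%flat = "mhlo.dynamic_reshape"(%arg, %flat_shape)
    : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// (ii) Apply the element-wise op on the flattened operand.
%flat_res = "chlo.tan"(%flat) : (tensor<?xf32>) -> tensor<?xf32>
// (iii) Restore the original shape.
%res = "mhlo.dynamic_reshape"(%flat_res, %shape)
    : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>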
MLIRContext &ctx = getContext(); ConversionTarget target(ctx); - target.addLegalDialect(); target.addLegalOp(); -#define ADD_LEGAL(op) AddLegalOpOnRankedTensor(&target) - MAP_XLA_OPERATION_CWISE_UNARY(ADD_LEGAL, ;); - MAP_XLA_OPERATION_CWISE_BINARY(ADD_LEGAL, ;); -#undef ADD_LEGAL +#define ADD_LEGAL_MHLO(op) AddLegalOpOnRankedTensor(&target) +#define ADD_LEGAL_CHLO(op) AddLegalOpOnRankedTensor(&target) + MAP_XLA_OPERATION_CWISE_UNARY(ADD_LEGAL_MHLO, ;); + MAP_XLA_OPERATION_CWISE_BINARY(ADD_LEGAL_MHLO, ;); + MAP_CHLO_OPERATION_CWISE_UNARY(ADD_LEGAL_CHLO, ;); +#undef ADD_LEGAL_MHLO +#undef ADD_LEGAL_CHLO + AddLegalOpOnRankedTensor(&target); + AddLegalOpOnRankedTensor(&target); // Populate rewrite patterns. OwningRewritePatternList patterns; @@ -179,24 +163,26 @@ struct TransformUnrankedHloPass void PopulateTransformUnrankedHloPatterns(MLIRContext *context, OwningRewritePatternList *patterns) { - // TODO(frgossen): Populate all unary and binary operations. - // clang-format off -#define MAP_UNARY(op) UnaryElementwiseOpConversion -#define MAP_BINARY(op) BinaryElementwiseOpConversion +#define MAP_UNARY(op) ElementwiseOpConversion +#define MAP_BINARY(op) ElementwiseOpConversion +#define MAP_CHLO_UNARY(op) ElementwiseOpConversion #define COMMA , + // clang-format off patterns->insert< MAP_XLA_OPERATION_CWISE_UNARY(MAP_UNARY, COMMA), - MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA) - >(context); + MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA), + MAP_CHLO_OPERATION_CWISE_UNARY(MAP_CHLO_UNARY, COMMA), + ElementwiseOpConversion, + ElementwiseOpConversion>(context); + // clang-format on #undef MAP_UNARY #undef MAP_BINARY +#undef MAP_CHLO_UNARY #undef COMMA - // clang-format on } -std::unique_ptr<::mlir::Pass> createTransformUnrankedHloPass() { +std::unique_ptr createTransformUnrankedHloPass() { return std::make_unique(); } -} // namespace mhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc index 1458e5f3d63..9d072488389 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc @@ -122,7 +122,7 @@ class UnfuseBatchNormInferencePattern if (!fp_type) { return failure(); } - int64_t feature_dim = bn_op.feature_index().getSExtValue(); + int64_t feature_dim = bn_op.feature_index(); // Add epsilon to the variance and sqrt to get stddev: // stddev = sqrt(variance + epsilon) diff --git a/tensorflow/compiler/mlir/hlo/tests/BUILD b/tensorflow/compiler/mlir/hlo/tests/BUILD index 2c3150a217a..df74de64d7f 100644 --- a/tensorflow/compiler/mlir/hlo/tests/BUILD +++ b/tensorflow/compiler/mlir/hlo/tests/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow:tensorflow.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package(licenses = ["notice"]) diff --git a/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir b/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir index 0d20c3f517b..4effdc14ed6 100644 --- a/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir @@ -63,6 +63,24 @@ func @divide_fold_float() -> tensor<4xf64> { return %2 : tensor<4xf64> } +// CHECK-LABEL: remainder_fold_int +func @remainder_fold_int() -> tensor<4xi32> { + %0 = mhlo.constant dense<[5, 66, 5, 1]> : tensor<4xi32> + %1 = mhlo.constant dense<[3, 5, 1, 2]> : tensor<4xi32> + // CHECK: mhlo.constant 
dense<[2, 1, 0, 1]> + %2 = "mhlo.remainder"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> (tensor<4xi32>) + return %2 : tensor<4xi32> +} + +// CHECK-LABEL: remainder_fold_float +func @remainder_fold_float() -> tensor<4xf32> { + %0 = mhlo.constant dense<[7.0, 66.5, 5.0, 3.1]> : tensor<4xf32> + %1 = mhlo.constant dense<[3.0, 5.0, 1.0, 2.6]> : tensor<4xf32> + // CHECK: mhlo.constant dense<[1.000000e+00, 1.500000e+00, 0.000000e+00, 5.000000e-01]> + %2 = "mhlo.remainder"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) + return %2 : tensor<4xf32> +} + // CHECK-LABEL: max_scalar_fold func @max_scalar_fold() -> tensor<4xi64> { %0 = mhlo.constant dense<7> : tensor<4xi64> @@ -301,6 +319,13 @@ func @slice_2D_fold_vertical() -> tensor<4x1xi64> { return %1 : tensor<4x1xi64> } +// CHECK-LABEL: slice_unknown_shape +func @slice_unknown_shape(%arg0: tensor<*xf32>) -> tensor<*xf32> { + // CHECK: "mhlo.slice"(%arg0) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<*xf32>) -> tensor<*xf32> + %0 = "mhlo.slice"(%arg0) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + // CHECK-LABEL: slice_concat_fold_first func @slice_concat_fold_first(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> { %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> @@ -576,6 +601,262 @@ func @dce_while_without_side_effect(%arg0: tensor) -> tensor { return %arg0 : tensor } +// CHECK-LABEL: fold_compare_same_eq +func @fold_compare_same_eq(%arg0: tensor) -> tensor { + // CHECK: %0 = mhlo.constant dense : tensor + %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor, tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: fold_compare_same_le +func @fold_compare_same_le(%arg0: tensor) -> tensor { + // CHECK: %0 = mhlo.constant dense : tensor + %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor, tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: fold_compare_same_ge +func @fold_compare_same_ge(%arg0: tensor) -> tensor { + // CHECK: %0 = mhlo.constant dense : tensor + %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor, tensor) -> tensor + return %0 : tensor +} +// CHECK-LABEL: fold_compare_same_ne +func @fold_compare_same_ne(%arg0: tensor) -> tensor { + // CHECK: %0 = mhlo.constant dense : tensor + %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor, tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: fold_compare_same_lt +func @fold_compare_same_lt(%arg0: tensor) -> tensor { + // CHECK: %0 = mhlo.constant dense : tensor + %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: fold_compare_same_gt +func @fold_compare_same_gt(%arg0: tensor) -> tensor { + // CHECK: %0 = mhlo.constant dense : tensor + %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: fold_compare_false_eq +func @fold_compare_false_eq() -> tensor { + %0 = mhlo.constant dense<0> : tensor + %1 = mhlo.constant dense<1> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "EQ"} : (tensor, tensor) -> tensor + return %2 : tensor +} +// 
CHECK-LABEL: fold_compare_true_eq +func @fold_compare_true_eq() -> tensor { + %0 = mhlo.constant dense<1> : tensor + %1 = mhlo.constant dense<1> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "EQ"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_false_eq_float +func @fold_compare_false_eq_float() -> tensor { + %0 = mhlo.constant dense<0.> : tensor + %1 = mhlo.constant dense<1.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "EQ"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_true_eq_float +func @fold_compare_true_eq_float() -> tensor { + %0 = mhlo.constant dense<1.> : tensor + %1 = mhlo.constant dense<1.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "EQ"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_false_ne +func @fold_compare_false_ne() -> tensor { + %0 = mhlo.constant dense<1> : tensor + %1 = mhlo.constant dense<1> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "NE"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_true_ne +func @fold_compare_true_ne() -> tensor { + %0 = mhlo.constant dense<1> : tensor + %1 = mhlo.constant dense<0> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "NE"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_false_ne_float +func @fold_compare_false_ne_float() -> tensor { + %0 = mhlo.constant dense<1.> : tensor + %1 = mhlo.constant dense<1.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "NE"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_true_ne_float +func @fold_compare_true_ne_float() -> tensor { + %0 = mhlo.constant dense<0.> : tensor + %1 = mhlo.constant dense<1.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "NE"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_false_lt +func @fold_compare_false_lt() -> tensor { + %0 = mhlo.constant dense<1> : tensor + %1 = mhlo.constant dense<1> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_true_lt +func @fold_compare_true_lt() -> tensor { + %0 = mhlo.constant dense<0> : tensor + %1 = mhlo.constant dense<1> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_false_lt_float +func @fold_compare_false_lt_float() -> tensor { + %0 = mhlo.constant dense<1.> : tensor + %1 = mhlo.constant dense<1.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_true_lt_float +func @fold_compare_true_lt_float() -> tensor { + %0 = mhlo.constant dense<0.> : tensor + %1 = mhlo.constant dense<1.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "LT"} : (tensor, 
tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_false_le +func @fold_compare_false_le() -> tensor { + %0 = mhlo.constant dense<1> : tensor + %1 = mhlo.constant dense<0> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "LE"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_true_le +func @fold_compare_true_le() -> tensor { + %0 = mhlo.constant dense<1> : tensor + %1 = mhlo.constant dense<1> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "LE"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_false_le_float +func @fold_compare_false_le_float() -> tensor { + %0 = mhlo.constant dense<1.> : tensor + %1 = mhlo.constant dense<0.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "LE"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_true_le_float +func @fold_compare_true_le_float() -> tensor { + %0 = mhlo.constant dense<1.> : tensor + %1 = mhlo.constant dense<1.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "LE"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_false_gt +func @fold_compare_false_gt() -> tensor { + %0 = mhlo.constant dense<0> : tensor + %1 = mhlo.constant dense<0> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_true_gt +func @fold_compare_true_gt() -> tensor { + %0 = mhlo.constant dense<1> : tensor + %1 = mhlo.constant dense<0> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_false_gt_float +func @fold_compare_false_gt_float() -> tensor { + %0 = mhlo.constant dense<0.> : tensor + %1 = mhlo.constant dense<0.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_true_gt_float +func @fold_compare_true_gt_float() -> tensor { + %0 = mhlo.constant dense<1.> : tensor + %1 = mhlo.constant dense<0.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_false_ge +func @fold_compare_false_ge() -> tensor { + %0 = mhlo.constant dense<0> : tensor + %1 = mhlo.constant dense<1> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "GE"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_true_ge +func @fold_compare_true_ge() -> tensor { + %0 = mhlo.constant dense<0> : tensor + %1 = mhlo.constant dense<0> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "GE"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_false_ge_float +func @fold_compare_false_ge_float() -> tensor { + %0 = mhlo.constant dense<0.> : tensor + %1 = mhlo.constant dense<1.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) 
{comparison_direction = "GE"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_true_ge_float +func @fold_compare_true_ge_float() -> tensor { + %0 = mhlo.constant dense<0.> : tensor + %1 = mhlo.constant dense<0.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "GE"} : (tensor, tensor) -> tensor + return %2 : tensor +} + // CHECK-LABEL: unpack_repack_same_tuple // CHECK-SAME: ([[ARG0:%.*]]: tuple, !mhlo.token, tensor>) func @unpack_repack_same_tuple(%arg0: tuple, !mhlo.token, tensor>) -> tuple, !mhlo.token, tensor> { @@ -618,3 +899,533 @@ func @erase_dead_lhlo_constant_negative(%M : memref<4xf32>) -> memref<256x1024xf "lmhlo.constant"(%N) {value = dense<0.0> : tensor} : (memref<256x1024xf32>) -> () return %N : memref<256x1024xf32> } + +// CHECK-LABEL: func @fold_get_dimension_size +func @fold_get_dimension_size(%I : tensor<1x128x512xf32>) -> tensor { + %size = "mhlo.get_dimension_size"(%I) {dimension = 2 : i32} : (tensor<1x128x512xf32>) -> tensor + return %size : tensor + // CHECK-NEXT: %[[C:.*]] = mhlo.constant dense<512> : tensor + // CHECK-NEXT: return %[[C]] +} + +// CHECK-LABEL: func @fold_select_same +func @fold_select_same(%arg0 : tensor, %arg1 : tensor) -> tensor { + %1 = "mhlo.select"(%arg1, %arg0, %arg0) : (tensor, tensor, tensor) -> tensor + // CHECK: return %arg0 + return %1 : tensor +} + +// CHECK-LABEL: func @fold_select_first +func @fold_select_first(%arg0 : tensor, %arg1 : tensor) -> tensor { + %0 = mhlo.constant dense<1> : tensor + %1 = "mhlo.select"(%0, %arg0, %arg1) : (tensor, tensor, tensor) -> tensor + // CHECK: return %arg0 + return %1 : tensor +} + +// CHECK-LABEL: func @fold_select_second +func @fold_select_second(%arg0 : tensor, %arg1 : tensor) -> tensor { + %0 = mhlo.constant dense<0> : tensor + %1 = "mhlo.select"(%0, %arg0, %arg1) : (tensor, tensor, tensor) -> tensor + // CHECK: return %arg1 + return %1 : tensor +} + +// CHECK-LABEL: func @fold_select_vector +func @fold_select_vector(%arg0 : tensor<4xf32>, %arg1 : tensor<4xf32>) -> tensor<4xf32> { + %0 = mhlo.constant dense<1> : tensor<4xi1> + %1 = "mhlo.select"(%0, %arg0, %arg1) : (tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK: return %arg0 + return %1 : tensor<4xf32> +} + +// CHECK-LABEL: gather_to_slice +func @gather_to_slice(%arg0: tensor<5x6x7xf32>) -> tensor<3x6x5xf32> { + %0 = constant dense<[1, 2]> : tensor<2xi32> + %1 = "mhlo.gather"(%arg0, %0) { + dimension_numbers = {collapsed_slice_dims = dense<> : tensor<0xi64>, + index_vector_dim = 0 : i64, + offset_dims = dense<[0, 1, 2]> : tensor<3xi64>, + start_index_map = dense<[0, 2]> : tensor<2xi64>}, + indices_are_sorted = false, + slice_sizes = dense<[3, 6, 5]> : tensor<3xi64>} : (tensor<5x6x7xf32>, tensor<2xi32>) -> tensor<3x6x5xf32> + return %1 : tensor<3x6x5xf32> + // CHECK: %[[RET:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<[4, 6, 7]> : tensor<3xi64>, start_indices = dense<[1, 0, 2]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<5x6x7xf32>) -> tensor<3x6x5xf32> + // CHECK: return %[[RET]] : tensor<3x6x5xf32> +} + +// CHECK-LABEL: gather_scalar_index_to_slice +func @gather_scalar_index_to_slice(%arg0: tensor<5x6x7xf32>) -> tensor<5x6x4xf32> { + %0 = constant dense<1> : tensor + %1 = "mhlo.gather"(%arg0, %0) { + dimension_numbers = {collapsed_slice_dims = dense<> : tensor<0xi64>, + index_vector_dim = 0 : i64, + offset_dims = dense<[0, 1, 2]> : tensor<3xi64>, + start_index_map = dense<[2]> : tensor<1xi64>}, + 
indices_are_sorted = false, + slice_sizes = dense<[5, 6, 4]> : tensor<3xi64>} : (tensor<5x6x7xf32>, tensor) -> tensor<5x6x4xf32> + return %1 : tensor<5x6x4xf32> + // CHECK: %[[RET:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<[5, 6, 5]> : tensor<3xi64>, start_indices = dense<[0, 0, 1]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<5x6x7xf32>) -> tensor<5x6x4xf32> + // CHECK: return %[[RET]] : tensor<5x6x4xf32> +} + +// CHECK-LABEL: gather_to_slice_reshape +func @gather_to_slice_reshape(%arg0: tensor<5x6x7xf32>) -> tensor<3x6xf32> { + %0 = constant dense<[1, 2]> : tensor<2xi32> + %1 = "mhlo.gather"(%arg0, %0) { + dimension_numbers = {collapsed_slice_dims = dense<[2]> : tensor<1xi64>, + index_vector_dim = 0 : i64, + offset_dims = dense<[0, 1, 2]> : tensor<3xi64>, + start_index_map = dense<[0, 2]> : tensor<2xi64>}, + indices_are_sorted = false, + slice_sizes = dense<[3, 6, 1]> : tensor<3xi64>} : (tensor<5x6x7xf32>, tensor<2xi32>) -> tensor<3x6xf32> + return %1 : tensor<3x6xf32> + // CHECK: %[[V0:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<[4, 6, 3]> : tensor<3xi64>, start_indices = dense<[1, 0, 2]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<5x6x7xf32>) -> tensor<3x6x1xf32> + // CHECK: %[[V1:.*]] = "mhlo.reshape"(%[[V0]]) : (tensor<3x6x1xf32>) -> tensor<3x6xf32> + // CHECK: return %[[V1]] : tensor<3x6xf32> +} + +// CHECK-LABEL: func @fold_and_same +func @fold_and_same(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = "mhlo.and"(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: return %arg0 + return %0 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_and_ones +func @fold_and_ones(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.constant dense<-1> : tensor<4xi32> + %1 = "mhlo.and"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: return %arg0 + return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_and_zeros +func @fold_and_zeros(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.constant dense<0> : tensor<4xi32> + %1 = "mhlo.and"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: return %0 + return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_and_constant +func @fold_and_constant(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.constant dense<7> : tensor<4xi32> + // CHECK: mhlo.and + %1 = "mhlo.and"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_and_constants +func @fold_and_constants() -> tensor<4xi32> { + %0 = mhlo.constant dense<[0, 1, 6, 3]> : tensor<4xi32> + %1 = mhlo.constant dense<[7, 3, 7, 2]> : tensor<4xi32> + %2 = "mhlo.and"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: %0 = mhlo.constant dense<[0, 1, 6, 2]> : tensor<4xi32> + // CHECK: return %0 + return %2 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_or_same +func @fold_or_same(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = "mhlo.or"(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: return %arg0 + return %0 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_or_ones +func @fold_or_ones(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.constant dense<-1> : tensor<4xi32> + %1 = "mhlo.or"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: return %0 + return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_or_zeros +func @fold_or_zeros(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.constant dense<0> : tensor<4xi32> + %1 = "mhlo.or"(%0, 
%arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: return %arg0 + return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_or_constant +func @fold_or_constant(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.constant dense<7> : tensor<4xi32> + // CHECK: mhlo.or + %1 = "mhlo.or"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_or_zeros_right +func @fold_or_zeros_right(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.constant dense<0> : tensor<4xi32> + %1 = "mhlo.or"(%arg0, %0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: return %arg0 + return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_or_zeros_constants +func @fold_or_zeros_constants() -> tensor<4xi32> { + %0 = mhlo.constant dense<[0, 1, 6, 3]> : tensor<4xi32> + %1 = mhlo.constant dense<[7, 3, 7, 2]> : tensor<4xi32> + %2 = "mhlo.or"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: %0 = mhlo.constant dense<[7, 3, 7, 3]> : tensor<4xi32> + // CHECK: return %0 + return %2 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_xor_same +func @fold_xor_same(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = "mhlo.xor"(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: %0 = mhlo.constant dense<0> : tensor<4xi32> + // CHECK: return %0 + return %0 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_xor_ones_left +func @fold_xor_ones_left(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.constant dense<-1> : tensor<4xi32> + // CHECK: mhlo.xor + %1 = "mhlo.xor"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_xor_ones_right +func @fold_xor_ones_right(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.constant dense<-1> : tensor<4xi32> + // CHECK: mhlo.xor + %1 = "mhlo.xor"(%arg0, %0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_xor_zeros_left +func @fold_xor_zeros_left(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.constant dense<0> : tensor<4xi32> + %1 = "mhlo.xor"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: return %arg0 + return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_xor_zeros_right +func @fold_xor_zeros_right(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.constant dense<0> : tensor<4xi32> + %1 = "mhlo.xor"(%arg0, %0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: return %arg0 + return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_xor_zeros_constants +func @fold_xor_zeros_constants() -> tensor<4xi32> { + %0 = mhlo.constant dense<[0, 1, 6, 3]> : tensor<4xi32> + %1 = mhlo.constant dense<[7, 3, 7, 2]> : tensor<4xi32> + %2 = "mhlo.xor"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: %0 = mhlo.constant dense<[7, 2, 1, 1]> : tensor<4xi32> + // CHECK: return %0 + return %2 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_negate_int +func @fold_negate_int() -> tensor<4xi32> { + %0 = mhlo.constant dense<[0, 1, 6, -3]> : tensor<4xi32> + // CHECK: mhlo.constant dense<[0, -1, -6, 3]> + %1 = "mhlo.negate"(%0) : (tensor<4xi32>) -> tensor<4xi32> + return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_negate_float +func @fold_negate_float() -> tensor<4xf32> { + %0 = mhlo.constant dense<[0., 1., 6., -3.]> : tensor<4xf32> + // CHECK: mhlo.constant dense<[-0.000000e+00, -1.000000e+00, -6.000000e+00, 3.000000e+00]> + %1 = "mhlo.negate"(%0) : (tensor<4xf32>) -> 
tensor<4xf32> + return %1 : tensor<4xf32> +} + +// CHECK-LABEL: func @fold_sqrt_f32_constants +func @fold_sqrt_f32_constants() -> tensor<4xf32> { + %0 = mhlo.constant dense<1.0> : tensor<4xf32> + %1 = "mhlo.sqrt"(%0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: mhlo.constant dense<1.000000e+00> : tensor<4xf32> + // CHECK-NOT: mhlo.sqrt + return %1 : tensor<4xf32> +} + +// CHECK-LABEL: func @fold_sqrt_f64_constants +func @fold_sqrt_f64_constants() -> tensor<4xf64> { + %0 = mhlo.constant dense<[1.0, 4.0, 9.0, 16.0]> : tensor<4xf64> + %1 = "mhlo.sqrt"(%0) : (tensor<4xf64>) -> tensor<4xf64> + // CHECK: mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf64> + // CHECK-NOT: mhlo.sqrt + return %1 : tensor<4xf64> +} + +// CHECK-LABEL: func @not_fold_sqrt_neg_constants +func @not_fold_sqrt_neg_constants() -> tensor<4xf32> { + %0 = mhlo.constant dense<-1.0> : tensor<4xf32> + %1 = "mhlo.sqrt"(%0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: mhlo.constant dense<-1.000000e+00> : tensor<4xf32> + // CHECK: mhlo.sqrt + return %1 : tensor<4xf32> +} + +// CHECK-LABEL: @tensor_flow_scatter_v1_update +func @tensor_flow_scatter_v1_update() -> tensor<3x3xi32> { + %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> + %1 = constant dense<[0, 2]> : tensor<2xi32> + %2 = constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + "mhlo.return"(%arg1) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 1 : i64, + inserted_window_dims = dense<0> : tensor<1xi64>, + scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, + update_window_dims = dense<[1]> : tensor<1xi64> + }, + unique_indices = false + } : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32> + return %3 : tensor<3x3xi32> + // CHECK: mhlo.constant dense<[ + // CHECK-SAME: [10, 20, 30], [4, 5, 6], [70, 80, 90] + // CHECK-SAME: ]> : tensor<3x3xi32> +} + +// CHECK-LABEL: @tensor_flow_scatter_v2_update +func @tensor_flow_scatter_v2_update() -> tensor<3x3xi32> { + %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> + %1 = constant dense<[0, 2]> : tensor<2xi32> + %2 = constant dense<[[10, 30], [40, 60], [70, 90]]> : tensor<3x2xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + "mhlo.return"(%arg1) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 1 : i64, + inserted_window_dims = dense<1> : tensor<1xi64>, + scatter_dims_to_operand_dims = dense<1> : tensor<1xi64>, + update_window_dims = dense<[0]> : tensor<1xi64> + }, + unique_indices = false + } : (tensor<3x3xi32>, tensor<2xi32>, tensor<3x2xi32>) -> tensor<3x3xi32> + return %3 : tensor<3x3xi32> + // CHECK: mhlo.constant dense<[ + // CHECK-SAME: [10, 2, 30], [40, 5, 60], [70, 8, 90] + // CHECK-SAME: ]> : tensor<3x3xi32> +} + +// CHECK-LABEL: @tensor_flow_scatter_add +func @tensor_flow_scatter_add() -> tensor<3x3xi32> { + %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> + %1 = constant dense<[0, 2]> : tensor<2xi32> + %2 = constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + %4 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> (tensor) + "mhlo.return"(%4) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 1 : i64, + inserted_window_dims = 
dense<0> : tensor<1xi64>, + scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, + update_window_dims = dense<[1]> : tensor<1xi64> + }, + unique_indices = false + } : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32> + return %3 : tensor<3x3xi32> + // CHECK: mhlo.constant dense<[ + // CHECK-SAME: [11, 22, 33], [4, 5, 6], [77, 88, 99] + // CHECK-SAME: ]> : tensor<3x3xi32> +} + +// CHECK-LABEL: @tensor_flow_scatter_repeated +func @tensor_flow_scatter_repeated() -> tensor<3x3xi32> { + %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> + %1 = constant dense<[1, 1]> : tensor<2xi32> + %2 = constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + %4 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> (tensor) + "mhlo.return"(%4) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 1 : i64, + inserted_window_dims = dense<0> : tensor<1xi64>, + scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, + update_window_dims = dense<[1]> : tensor<1xi64> + }, + unique_indices = false + } : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32> + return %3 : tensor<3x3xi32> + // CHECK: mhlo.constant dense<[ + // CHECK-SAME: [1, 2, 3], [84, 105, 126], [7, 8, 9] + // CHECK-SAME: ]> : tensor<3x3xi32> +} + +// CHECK-LABEL: @tensor_flow_scatter_multiple_batch +func @tensor_flow_scatter_multiple_batch() -> tensor<3x3xi32> { + %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> + %1 = constant dense<[[0, 2], [2, 1]]> : tensor<2x2xi32> + %2 = constant dense<[[[10, 30], [40, 60], [70, 90]], [[5, 5], [5, 5], [5, 5]]]> : tensor<2x3x2xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + %4 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> (tensor) + "mhlo.return"(%4) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 2 : i64, + inserted_window_dims = dense<1> : tensor<1xi64>, + scatter_dims_to_operand_dims = dense<1> : tensor<1xi64>, + update_window_dims = dense<[1]> : tensor<1xi64> + }, + unique_indices = false + } : (tensor<3x3xi32>, tensor<2x2xi32>, tensor<2x3x2xi32>) -> tensor<3x3xi32> + return %3 : tensor<3x3xi32> + // CHECK: mhlo.constant dense<[ + // CHECK-SAME: [11, 7, 38], [44, 10, 71], [77, 13, 104] + // CHECK-SAME: ]> : tensor<3x3xi32> +} + +// CHECK-LABEL: @tensor_flow_scatter_nd +func @tensor_flow_scatter_nd() -> tensor<3x3x2xi32> { + %0 = constant dense<[[[-1, 1], [-2, 2], [-3, 3]], [[-4, 4], [-5, 5], [-6, 6]], [[-7, 7], [-8, 8], [-9, 9]]]> : tensor<3x3x2xi32> + %1 = constant dense<[[0, 0], [1, 0]]> : tensor<2x2xi32> + %2 = constant dense<[[-10, 10], [-40, 40]]> : tensor<2x2xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + "mhlo.return"(%arg1) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 1 : i64, + inserted_window_dims = dense<[0, 1]> : tensor<2xi64>, + scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>, + update_window_dims = dense<[1]> : tensor<1xi64> + }, + unique_indices = false + } : (tensor<3x3x2xi32>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<3x3x2xi32> + return %3 : tensor<3x3x2xi32> + // CHECK: mhlo.constant dense<[ + // CHECK-SAME: [-10, 10], [-2, 2], [-3, 3] + // CHECK-SAME: [-40, 40], [-5, 5], [-6, 6] + // CHECK-SAME: [-7, 7], [-8, 8], [-9, 9] + // CHECK-SAME: ]> : tensor<3x3x2xi32> +} + +// 
CHECK-LABEL: @tensor_flow_scatter_nd_index_vector +func @tensor_flow_scatter_nd_index_vector() -> tensor<3x3x2xi32> { + %0 = constant dense<[[[-1, 1], [-2, 2], [-3, 3]], [[-4, 4], [-5, 5], [-6, 6]], [[-7, 7], [-8, 8], [-9, 9]]]> : tensor<3x3x2xi32> + %1 = constant dense<[[0, 0], [1, 0]]> : tensor<2x2xi32> + %2 = constant dense<[[-10, 10], [-20, 20]]> : tensor<2x2xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + "mhlo.return"(%arg1) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 0 : i64, + inserted_window_dims = dense<[0, 1]> : tensor<2xi64>, + scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>, + update_window_dims = dense<[1]> : tensor<1xi64> + }, + unique_indices = false + } : (tensor<3x3x2xi32>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<3x3x2xi32> + return %3 : tensor<3x3x2xi32> + // CHECK: mhlo.constant dense<[ + // CHECK-SAME: [-20, 20], [-10, 10], [-3, 3] + // CHECK-SAME: [-4, 4], [-5, 5], [-6, 6] + // CHECK-SAME: [-7, 7], [-8, 8], [-9, 9] + // CHECK-SAME: ]> : tensor<3x3x2xi32> +} + +// CHECK-LABEL: @scatter_batch_dus +func @scatter_batch_dus() -> tensor<3x3xi32> { + %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> + %1 = constant dense<[[2, 1], [1, 1]]> : tensor<2x2xi32> + %2 = constant dense<[[[10]], [[20]]]> : tensor<2x1x1xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + "mhlo.return"(%arg1) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 0 : i64, + inserted_window_dims = dense<> : tensor<0xi64>, + scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>, + update_window_dims = dense<[1, 2]> : tensor<2xi64> + }, + unique_indices = false + } : (tensor<3x3xi32>, tensor<2x2xi32>, tensor<2x1x1xi32>) -> tensor<3x3xi32> + return %3 : tensor<3x3xi32> + // CHECK: mhlo.constant dense<[ + // CHECK-SAME: [1, 2, 3], [4, 20, 6], [7, 10, 9] + // CHECK-SAME: ]> : tensor<3x3xi32> +} + +// CHECK-LABEL: @scatter_no_update_window_dim +func @scatter_no_update_window_dim() -> tensor<3xi32> { + %0 = constant dense<[0, 1, 2]> : tensor<3xi32> + %1 = constant dense<[[[0], [1]], [[2], [1]]]> : tensor<2x2x1xi32> + %2 = constant dense<[[10, 20], [30, 40]]> : tensor<2x2xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + %4 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> (tensor) + "mhlo.return"(%4) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 2 : i64, + inserted_window_dims = dense<0> : tensor<1xi64>, + scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, + update_window_dims = dense<> : tensor<0xi64> + }, + unique_indices = false + } : (tensor<3xi32>, tensor<2x2x1xi32>, tensor<2x2xi32>) -> tensor<3xi32> + return %3 : tensor<3xi32> + // CHECK: mhlo.constant dense<[10, 61, 32]> : tensor<3xi32> +} + +// CHECK-LABEL: @scatter_negative_index +func @scatter_negative_index() -> tensor<3x3xi32> { + %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> + %1 = constant dense<[0, -1]> : tensor<2xi32> + %2 = constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + "mhlo.return"(%arg1) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 1 : i64, + inserted_window_dims = dense<0> : tensor<1xi64>, + scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, + 
update_window_dims = dense<[1]> : tensor<1xi64> + }, + unique_indices = false + } : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32> + return %3 : tensor<3x3xi32> + // CHECK: constant dense<[ + // CHECK-SAME: [1, 2, 3], [4, 5, 6], [7, 8, 9] + // CHECK-SAME: ]> : tensor<3x3xi32> + // CHECK: "mhlo.scatter" +} + +// CHECK-LABEL: @scatter_out_of_bound +func @scatter_out_of_bound() -> tensor<3x3xi32> { + %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> + %1 = constant dense<[1, 5]> : tensor<2xi32> + %2 = constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + "mhlo.return"(%arg1) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 1 : i64, + inserted_window_dims = dense<0> : tensor<1xi64>, + scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, + update_window_dims = dense<[1]> : tensor<1xi64> + }, + unique_indices = false + } : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32> + return %3 : tensor<3x3xi32> + // CHECK: constant dense<[ + // CHECK-SAME: [1, 2, 3], [4, 5, 6], [7, 8, 9] + // CHECK-SAME: ]> : tensor<3x3xi32> + // CHECK: "mhlo.scatter" +} + diff --git a/tensorflow/compiler/mlir/hlo/tests/chlo_infer_shape_type_methods.mlir b/tensorflow/compiler/mlir/hlo/tests/chlo_infer_shape_type_methods.mlir index d226c92858a..0738459f8b6 100644 --- a/tensorflow/compiler/mlir/hlo/tests/chlo_infer_shape_type_methods.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/chlo_infer_shape_type_methods.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt -mhlo-test-infer-shaped-type-methods -allow-unregistered-dialect -split-input-file -verify-diagnostics %s -o - | FileCheck %s +// RUN: mlir-hlo-opt --mhlo-test-infer-shaped-type-methods --allow-unregistered-dialect --split-input-file %s | FileCheck %s // CHECK-LABEL: @broadcast_add // Note that all broadcast_ops are expanded from the same template, so diff --git a/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_hlo_broadcasts.mlir b/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_hlo_broadcasts.mlir index 9670372a864..60ec26f48a1 100644 --- a/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_hlo_broadcasts.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_hlo_broadcasts.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt -mhlo-test-chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s +// RUN: mlir-hlo-opt -chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s // Check the non-broadcast case for each registered op, then just check a // representative op for detailed broadcast semantics. @@ -253,7 +253,7 @@ func @addScalarUnranked(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf3 // to a 1D tensor. // CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<*xf32> // CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]] : tensor -> index -// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements(%[[NUM_ELEMENTS]]) : tensor<1xindex> +// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex> // CHECK: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_1]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor // The assuming region is part of the second stage of lowering // with ranked broadcasting logic. @@ -288,7 +288,7 @@ func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<*xf3 // to a 1D tensor. 
// CHECK: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32> // CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]] : tensor -> index -// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements(%[[NUM_ELEMENTS]]) : tensor<1xindex> +// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex> // CHECK: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_0]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor // The assuming region is part of the second stage of lowering // with ranked broadcasting logic. @@ -325,7 +325,7 @@ func @addUnrankedUnranked( // CHECK: %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_LHS]], %[[C0]] : index // Handle scalar LHS case // CHECK: %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) { -// CHECK: %[[SCALAR_LHS:.*]] = "mhlo.reshape"(%[[LHS]]) : (tensor<*xf32>) -> tensor +// CHECK: %[[SCALAR_LHS:.*]] = tensor_cast %[[LHS]] : tensor<*xf32> to tensor // CHECK: %[[VAL_10:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RHS]] : (tensor, tensor<*xf32>) -> tensor<*xf32> // CHECK: scf.yield %[[VAL_10]] : tensor<*xf32> // CHECK: } else { @@ -334,7 +334,7 @@ func @addUnrankedUnranked( // CHECK: %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_RHS]], %[[C0]] : index // Handle scalar RHS case // CHECK: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) { -// CHECK: %[[SCALAR_RHS:.*]] = "mhlo.reshape"(%[[RHS]]) : (tensor<*xf32>) -> tensor +// CHECK: %[[SCALAR_RHS:.*]] = tensor_cast %[[RHS]] : tensor<*xf32> to tensor // CHECK: %[[VAL_16:.*]] = chlo.broadcast_add %[[LHS]], %[[SCALAR_RHS]] : (tensor<*xf32>, tensor) -> tensor<*xf32> // CHECK: scf.yield %[[VAL_16]] : tensor<*xf32> // CHECK: } else { @@ -353,10 +353,12 @@ func @addUnrankedUnranked( // Handle rank 2 specialization // CHECK: %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) { // CHECK: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1] -// CHECK: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor, tensor<2xindex> -> tensor<2xindex> -// CHECK: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor, tensor<2xindex> -> tensor<2xindex> -// CHECK: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[BROADCASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor -// CHECK: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[BROADCASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor +// CHECK: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor, tensor<2xindex> -> tensor +// CHECK: %[[CASTED_LHS_2:.*]] = tensor_cast %[[BROADCASTED_LHS_2]] : tensor to tensor<2xindex> +// CHECK: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor, tensor<2xindex> -> tensor +// CHECK: %[[CASTED_RHS_2:.*]] = tensor_cast %[[BROADCASTED_RHS_2]] : tensor to tensor<2xindex> +// CHECK: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor +// CHECK: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor // CHECK: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor, tensor) -> tensor // CHECK: %[[RESULT_2:.*]] = tensor_cast %[[RESULT_RANK_2]] : tensor to tensor<*xf32> // CHECK: scf.yield %[[RESULT_2]] : tensor<*xf32> @@ -366,10 +368,12 @@ func @addUnrankedUnranked( // Handle rank 3 specialization // CHECK: %[[VAL_34:.*]] = scf.if 
%[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) { // CHECK: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1] -// CHECK: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor, tensor<3xindex> -> tensor<3xindex> -// CHECK: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor, tensor<3xindex> -> tensor<3xindex> -// CHECK: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[BROADCASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor -// CHECK: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[BROADCASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor +// CHECK: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor, tensor<3xindex> -> tensor +// CHECK: %[[CASTED_LHS_3:.*]] = tensor_cast %[[BROADCASTED_LHS_3]] : tensor to tensor<3xindex> +// CHECK: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor, tensor<3xindex> -> tensor +// CHECK: %[[CASTED_RHS_3:.*]] = tensor_cast %[[BROADCASTED_RHS_3]] : tensor to tensor<3xindex> +// CHECK: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor +// CHECK: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor // CHECK: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor, tensor) -> tensor // CHECK: %[[RESULT_3:.*]] = tensor_cast %[[RESULT_RANK_3]] : tensor to tensor<*xf32> // CHECK: scf.yield %[[RESULT_3]] : tensor<*xf32> @@ -379,10 +383,12 @@ func @addUnrankedUnranked( // Handle rank 4 specialization // CHECK: %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) { // CHECK: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1] -// CHECK: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor, tensor<4xindex> -> tensor<4xindex> -// CHECK: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor, tensor<4xindex> -> tensor<4xindex> -// CHECK: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[BROADCASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor -// CHECK: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[BROADCASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor +// CHECK: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor, tensor<4xindex> -> tensor +// CHECK: %[[CASTED_LHS_4:.*]] = tensor_cast %[[BROADCASTED_LHS_4]] : tensor to tensor<4xindex> +// CHECK: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor, tensor<4xindex> -> tensor +// CHECK: %[[CASTED_RHS_4:.*]] = tensor_cast %[[BROADCASTED_RHS_4]] : tensor to tensor<4xindex> +// CHECK: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor +// CHECK: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor // CHECK: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor, tensor) -> tensor // CHECK: %[[RESULT_4:.*]] = tensor_cast %[[RESULT_RANK_4]] : tensor to tensor<*xf32> // CHECK: scf.yield %[[RESULT_4]] : tensor<*xf32> @@ -392,10 +398,12 @@ func @addUnrankedUnranked( // Handle rank 5 specialization // CHECK: %[[VAL_50:.*]] = scf.if %[[GREATEST_RANK_IS_5]] -> (tensor<*xf32>) { // CHECK: 
%[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1] -// CHECK: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor, tensor<5xindex> -> tensor<5xindex> -// CHECK: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor, tensor<5xindex> -> tensor<5xindex> -// CHECK: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[BROADCASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor -// CHECK: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[BROADCASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor +// CHECK: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor, tensor<5xindex> -> tensor +// CHECK: %[[CASTED_LHS_5:.*]] = tensor_cast %[[BROADCASTED_LHS_5]] : tensor to tensor<5xindex> +// CHECK: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor, tensor<5xindex> -> tensor +// CHECK: %[[CASTED_RHS_5:.*]] = tensor_cast %[[BROADCASTED_RHS_5]] : tensor to tensor<5xindex> +// CHECK: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor +// CHECK: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor // CHECK: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor, tensor) -> tensor // CHECK: %[[RESULT_5:.*]] = tensor_cast %[[RESULT_RANK_5]] : tensor to tensor<*xf32> // CHECK: scf.yield %[[RESULT_5]] : tensor<*xf32> @@ -405,10 +413,12 @@ func @addUnrankedUnranked( // Handle rank 6 specialization // CHECK: %[[VAL_58:.*]] = scf.if %[[GREATEST_RANK_IS_6]] -> (tensor<*xf32>) { // CHECK: %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1] -// CHECK: %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor, tensor<6xindex> -> tensor<6xindex> -// CHECK: %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor, tensor<6xindex> -> tensor<6xindex> -// CHECK: %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[BROADCASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor -// CHECK: %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[BROADCASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor +// CHECK: %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor, tensor<6xindex> -> tensor +// CHECK: %[[CASTED_LHS_6:.*]] = tensor_cast %[[BROADCASTED_LHS_6]] : tensor to tensor<6xindex> +// CHECK: %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor, tensor<6xindex> -> tensor +// CHECK: %[[CASTED_RHS_6:.*]] = tensor_cast %[[BROADCASTED_RHS_6]] : tensor to tensor<6xindex> +// CHECK: %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor +// CHECK: %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor // CHECK: %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor, tensor) -> tensor // CHECK: %[[RESULT_6:.*]] = tensor_cast %[[RESULT_RANK_6]] : tensor to tensor<*xf32> // CHECK: scf.yield %[[RESULT_6]] : tensor<*xf32> diff --git a/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_mhlo.mlir b/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_mhlo.mlir new file mode 100644 index 00000000000..2bec91203f9 --- /dev/null +++ 
b/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_mhlo.mlir @@ -0,0 +1,26 @@ +// RUN: mlir-hlo-opt --chlo-legalize-to-hlo --split-input-file %s | FileCheck %s + +// Lower statically shaped `constant_like` to constant. +// CHECK-LABEL: @constant_like_static_shape +func @constant_like_static_shape(%arg : tensor<1x2xi64>) -> tensor<1x2xf32> { + // CHECK: %[[RESULT:.*]] = mhlo.constant dense<3.200000e+00> : tensor<1x2xf32> + // CHECK: return %[[RESULT]] + %result = "chlo.constant_like"(%arg) { value = 3.2 : f32 } + : (tensor<1x2xi64>) -> tensor<1x2xf32> + return %result : tensor<1x2xf32> +} + +// Lower dynamically shaped `constant_like` to broadcasted constant. +// CHECK-LABEL: constant_like_dynamic_shape +// CHECK-SAME: (%[[ARG:.*]]: tensor) +func @constant_like_dynamic_shape(%arg : tensor) -> tensor { + // CHECK: %[[CONSTANT:.*]] = mhlo.constant dense<3.200000e+00> : tensor + // CHECK: %[[UNCASTED_SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor -> tensor + // CHECK: %[[SHAPE:.*]] = tensor_cast %[[UNCASTED_SHAPE]] : tensor to tensor<2xindex> + // CHECK: %[[BROADCASTED_CONSTANT:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[CONSTANT]], %[[SHAPE]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<2xindex>) -> tensor + // CHECK: return %[[BROADCASTED_CONSTANT]] : tensor + %result = "chlo.constant_like"(%arg) { value = 3.2 : f32 } + : (tensor) -> tensor + return %result : tensor +} + diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir index 018711e33cb..f6fdc4439bb 100644 --- a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir @@ -170,7 +170,7 @@ func @dyn_broadcast(%operand: memref) { // BOTH-SAME: (%[[OPERAND:.*]]: memref) %tensor_operand = tensor_load %operand : memref %c1 = constant 1 : i64 - %shape = tensor_from_elements(%c1, %c1, %c1) : tensor<3xi64> + %shape = tensor_from_elements %c1, %c1, %c1 : tensor<3xi64> %tensor_result = "mhlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) { broadcast_dimensions = dense<[1, 2]> : tensor<2xi64> } : (tensor, tensor<3xi64>) -> tensor @@ -236,6 +236,21 @@ func @complex(%real: memref<2x2xf32>, // ----- +// BOTH-LABEL: func @complex_dyn +func @complex_dyn(%real: memref, + %imag: memref, + %result: memref>) { + %tensor_real = tensor_load %real : memref + %tensor_imag = tensor_load %imag : memref + %tensor_result = "mhlo.complex"(%tensor_real, %tensor_imag) + : (tensor, tensor) -> tensor> + // BOTH: "lmhlo.complex"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref> + return +} + +// ----- + // BOTH-LABEL: func @real func @real(%operand: memref<2x2xcomplex>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xcomplex> @@ -248,6 +263,18 @@ func @real(%operand: memref<2x2xcomplex>, %result: memref<2x2xf32>) { // ----- +// BOTH-LABEL: func @real_dyn +func @real_dyn(%operand: memref>, %result: memref) { + %tensor_operand = tensor_load %operand : memref> + %tensor_result = "mhlo.real"(%tensor_operand) + : (tensor>) -> tensor + // BOTH: "lmhlo.real"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref + return +} + +// ----- + // BOTH-LABEL: func @imag func @imag(%operand: memref<2x2xcomplex>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xcomplex> @@ -260,6 +287,18 @@ func @imag(%operand: memref<2x2xcomplex>, %result: memref<2x2xf32>) { // ----- +// BOTH-LABEL: func @imag_dyn +func 
@imag_dyn(%operand: memref>, %result: memref) { + %tensor_operand = tensor_load %operand : memref> + %tensor_result = "mhlo.imag"(%tensor_operand) + : (tensor>) -> tensor + // BOTH: "lmhlo.imag"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref + return +} + +// ----- + // BOTH-LABEL: func @iota func @iota(%result: memref<10xi32>) { %tensor_result = "mhlo.iota"() @@ -320,6 +359,18 @@ func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { // ----- +// BOTH-LABEL: func @floor +func @floor(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { + %tensor_operand = tensor_load %operand : memref<2x2xf32> + %tensor_result = "mhlo.floor"(%tensor_operand) + : (tensor<2x2xf32>) -> tensor<2x2xf32> + // BOTH: "lmhlo.floor"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref<2x2xf32> + return +} + +// ----- + // BOTH-LABEL: func @neg func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> @@ -332,6 +383,18 @@ func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { // ----- +// BOTH-LABEL: func @not +func @not(%operand: memref<2x2xi32>, %result: memref<2x2xi32>) { + %tensor_operand = tensor_load %operand : memref<2x2xi32> + %tensor_result = "mhlo.not"(%tensor_operand) + : (tensor<2x2xi32>) -> tensor<2x2xi32> + // BOTH: "lmhlo.not"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref<2x2xi32> + return +} + +// ----- + // BOTH-LABEL: func @rsqrt func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> @@ -404,7 +467,7 @@ func @add_dyn(%lhs: tensor, %rhs: tensor) { // BOTH: %[[C1:.*]] = constant 1 : index // BOTH: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref // BOTH: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64 - // BOTH: %[[SHAPE:.*]] = tensor_from_elements(%[[IC0]], %[[IC1]]) : tensor<2xi64> + // BOTH: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64> // BOTH: %[[C0_:.*]] = constant 0 : index // BOTH: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64> // BOTH: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index @@ -429,7 +492,7 @@ func @tanh_dyn(%arg0: tensor) { // BOTH: %[[C1:.*]] = constant 1 : index // BOTH: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref // BOTH: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64 - // BOTH: %[[SHAPE:.*]] = tensor_from_elements(%[[IC0]], %[[IC1]]) : tensor<2xi64> + // BOTH: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64> // BOTH: %[[C0_:.*]] = constant 0 : index // BOTH: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64> // BOTH: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index @@ -448,7 +511,13 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> { // PRE-SAME: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]]) // ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]] // BOTH-NEXT: %[[ALLOC:.*]] = alloc -// BOTH: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> () +// BOTH: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) { +// dot_dimension_numbers = { +// lhs_batching_dimensions = dense<> : tensor<0xi64>, +// lhs_contracting_dimensions = dense<1> : tensor<1xi64>, +// rhs_batching_dimensions = dense<> : tensor<0xi64>, +// rhs_contracting_dimensions = dense<0> : tensor<1xi64>}} +// : ([[TYPE]], [[TYPE]], [[TYPE]]) -> () %dot = "mhlo.dot"(%arg0, %arg0) : (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32> // PRE: "lmhlo.copy"(%[[ALLOC]], 
%[[RESULT]]) @@ -510,3 +579,63 @@ func @reduce(%arg0: tensor<1x8xf32>, %arg1: tensor) -> tensor<1xf32> { : (tensor<1x8xf32>, tensor) -> tensor<1xf32> return %0 : tensor<1xf32> } + +// ----- + +// BOTH-LABEL: func @transpose +func @transpose(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { + %tensor_operand = tensor_load %operand : memref<2x2xf32> + %tensor_result = "mhlo.transpose"(%tensor_operand) {permutation = dense<[1, 0]> : tensor<2xi64>} + : (tensor<2x2xf32>) -> tensor<2x2xf32> + // BOTH: "lmhlo.transpose"(%{{.*}}, %{{.*}}) {permutation = dense<[1, 0]> : tensor<2xi64>} + // BOTH-NOT: tensor_store + tensor_store %tensor_result, %result : memref<2x2xf32> + return +} + +// ----- + +// BOTH-LABEL: func @custom_call +// BOTH-SAME:([[ARG0:%.*]]: memref<2x2xf32>, [[ARG1:%.*]]: memref<2x3xf32>, [[RESULT:%.*]]: memref<4x4xf16>) +func @custom_call(%arg0: memref<2x2xf32>, %arg1: memref<2x3xf32>, %result: memref<4x4xf16>) { + %arg0_tensor = tensor_load %arg0 : memref<2x2xf32> + %arg1_tensor = tensor_load %arg1 : memref<2x3xf32> + // BOTH: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false} + %result_tensor = "mhlo.custom_call"(%arg0_tensor, %arg1_tensor) + {backend_config = "", call_target_name = "foo", has_side_effect = false} + : (tensor<2x2xf32>, tensor<2x3xf32>) -> tensor<4x4xf16> + tensor_store %result_tensor, %result: memref<4x4xf16> + return +} + +// ---- + +// BOTH-LABEL: func @isfinite +func @isfinite(%arg0: memref<2x2xf32>, %result: memref<2x2xi1>) { + %arg0_tensor = tensor_load %arg0 : memref<2x2xf32> + // BOTH: "lmhlo.is_finite"(%{{.*}}, %{{.*}}) + %result_tensor = "mhlo.is_finite"(%arg0_tensor) : (tensor<2x2xf32>) -> tensor<2x2xi1> + tensor_store %result_tensor, %result: memref<2x2xi1> + return +} + +// ----- + +// Test that assuming ops propagate memref types. 
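+// The shape.assuming region below yields a tensor produced by two dynamic
+// broadcasts and an mhlo.maximum; after bufferization the region and its
+// assuming_yield are expected to carry memref types, and the maximum should
+// lower to lmhlo.maximum taking its result as an output buffer operand.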
+// BOTH-LABEL: func @shape_assuming_memref +func @shape_assuming_memref(%arg0: tensor) -> tensor { + %0 = mhlo.constant dense<0.000000e+00> : tensor + %1 = shape.const_witness true + // BOTH: shape.assuming %{{.*}} -> (memref) + %2 = shape.assuming %1 -> (tensor) { + %3 = shape.shape_of %arg0 : tensor -> tensor + %4 = tensor_cast %3 : tensor to tensor<1xindex> + %5 = "mhlo.dynamic_broadcast_in_dim"(%0, %4) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor + %6 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %4) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor, tensor<1xindex>) -> tensor + // BOTH: "lmhlo.maximum"(%6, %9, %20) : (memref, memref, memref) -> () + %7 = mhlo.maximum %5, %6 : tensor + // BOTH: shape.assuming_yield %{{.*}} : memref + shape.assuming_yield %7 : tensor + } + return %2 : tensor +} diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir index aecf612962a..91490b43f95 100644 --- a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir @@ -152,6 +152,16 @@ func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- +// CHECK-LABEL: func @floor +func @floor(%input: tensor<2x2xf32>) -> tensor<2x2xf32> { + // CHECK: linalg.generic + // CHECK: floorf + %0 = "mhlo.floor"(%input) : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + // CHECK-LABEL: func @float_neg func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic @@ -242,6 +252,20 @@ func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> { // ----- +// CHECK-LABEL: func @is_finte +func @is_finte(%input: tensor<2x2xf32>) -> tensor<2x2xi1> { + %0 = "mhlo.is_finite"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi1> + return %0 : tensor<2x2xi1> +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32 +// CHECK-NEXT: %[[POS_INF:.+]] = constant 0x7F800000 : f32 +// CHECK-NEXT: %[[ABS_X:.+]] = absf %[[OPERAND_IN]] : f32 +// CHECK-NEXT: %[[RESULT:.+]] = cmpf "one", %[[ABS_X]], %[[POS_INF]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i1 + +// ----- + // CHECK-LABEL: func @select func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { @@ -385,6 +409,28 @@ func @reshape_3D_4D(%arg0: tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32> { // ----- +// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func @reshape1_4D_4D +func @reshape1_4D_4D(%arg0: tensor<4x512x1x1xi32>) -> tensor<1x4x1x512xi32> { + %0 = "mhlo.reshape"(%arg0) : (tensor<4x512x1x1xi32>) -> tensor<1x4x1x512xi32> + return %0 : tensor<1x4x1x512xi32> +} +// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP]]] +// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP]]] + +// ----- + +// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func @reshape2_4D_4D +func @reshape2_4D_4D(%arg0: tensor<4x1x1x1024xi32>) -> tensor<4x1024x1x1xi32> { + %0 = "mhlo.reshape"(%arg0) : (tensor<4x1x1x1024xi32>) -> tensor<4x1024x1x1xi32> + return %0 : tensor<4x1024x1x1xi32> +} +// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP]]] +// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP]]] + +// ----- + // CHECK-LABEL: func @minf func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { %0 = "mhlo.minimum"(%lhs, %rhs) diff --git 
a/tensorflow/compiler/mlir/hlo/tests/mhlo-transform-unranked.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-transform-unranked.mlir similarity index 74% rename from tensorflow/compiler/mlir/hlo/tests/mhlo-transform-unranked.mlir rename to tensorflow/compiler/mlir/hlo/tests/hlo-transform-unranked.mlir index 01ef250efd0..ae61fc8477e 100644 --- a/tensorflow/compiler/mlir/hlo/tests/mhlo-transform-unranked.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/hlo-transform-unranked.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt -transform-unranked-hlo -split-input-file %s | FileCheck %s +// RUN: mlir-hlo-opt --transform-unranked-hlo --split-input-file %s | FileCheck %s // Check the validity of expected IR. // CHECK-LABEL: @sqr_transform_result @@ -7,7 +7,7 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> { // Flatten operand shape. %shape = shape.shape_of %a : tensor<*xf32> -> tensor %num_elements = shape.num_elements %shape : tensor -> index - %flat_shape = tensor_from_elements(%num_elements) : tensor<1xindex> + %flat_shape = tensor_from_elements %num_elements : tensor<1xindex> %flat_a = "mhlo.dynamic_reshape"(%a, %flat_shape) : (tensor<*xf32>, tensor<1xindex>) -> tensor @@ -29,7 +29,7 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> { func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> { // CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32> -> tensor // CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]] - // CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS]]) : tensor<1xindex> + // CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex> // CHECK-NEXT: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor // CHECK-NEXT: %[[FLAT_B:.*]] = "mhlo.sqrt"(%[[FLAT_A]]) : (tensor) -> tensor // CHECK-NEXT: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE]]) : (tensor, tensor) -> tensor<*xf32> @@ -71,7 +71,7 @@ func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> { // CHECK: %[[SHAPE_B:.*]] = shape.shape_of %[[B]] // CHECK: %[[SHAPE:.*]] = shape.any %[[SHAPE_A]], %[[SHAPE_B]] // CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]] - // CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS]]) : tensor<1xindex> + // CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex> // CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor // CHECK: %[[FLAT_B:.*]] = "mhlo.dynamic_reshape"(%[[B]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor // CHECK: %[[FLAT_RESULT:.*]] = mhlo.add %[[FLAT_A]], %[[FLAT_B]] : tensor @@ -80,3 +80,19 @@ func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> { %result = mhlo.add %a, %b : tensor<*xf32> return %result : tensor<*xf32> } + +// ----- + +// CHECK-LABEL: @tan +// CHECK-SAME: (%[[A:.*]]: tensor<*xf32>) -> tensor<*xf32> +func @tan(%a : tensor<*xf32>) -> tensor<*xf32> { + // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32> -> tensor + // CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]] + // CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex> + // CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor + // CHECK: %[[FLAT_B:.*]] = chlo.tan %[[FLAT_A]] : tensor + // CHECK: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE]]) : (tensor, 
tensor) -> tensor<*xf32> + // CHECK: return %[[B]] : tensor<*xf32> + %result = chlo.tan %a : tensor<*xf32> + return %result : tensor<*xf32> +} diff --git a/tensorflow/compiler/mlir/hlo/tests/legalize-to-std.mlir b/tensorflow/compiler/mlir/hlo/tests/legalize-to-std.mlir index abe4e872b73..404be85e05e 100644 --- a/tensorflow/compiler/mlir/hlo/tests/legalize-to-std.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/legalize-to-std.mlir @@ -51,38 +51,38 @@ func @unary_ops_float(%arg0: tensor<4xf32>) -> tensor<4xf32> { return %0 : tensor<4xf32> } -// CHECK-LABEL: func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) { -func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) { - // CHECK-NEXT: %0 = cmpi "eq", %arg0, %arg0 : tensor<4xi32> - %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> - // CHECK-NEXT: %1 = cmpi "ne", %arg0, %arg0 : tensor<4xi32> - %1 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> - // CHECK-NEXT: %2 = cmpi "slt", %arg0, %arg0 : tensor<4xi32> - %2 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> - // CHECK-NEXT: %3 = cmpi "sle", %arg0, %arg0 : tensor<4xi32> - %3 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> - // CHECK-NEXT: %4 = cmpi "sgt", %arg0, %arg0 : tensor<4xi32> - %4 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> - // CHECK-NEXT: %5 = cmpi "sge", %arg0, %arg0 : tensor<4xi32> - %5 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> +// CHECK-LABEL: func @compare_int +func @compare_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) { + // CHECK-NEXT: %0 = cmpi "eq", %arg0, %arg1 : tensor<4xi32> + %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + // CHECK-NEXT: %1 = cmpi "ne", %arg0, %arg1 : tensor<4xi32> + %1 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + // CHECK-NEXT: %2 = cmpi "slt", %arg0, %arg1 : tensor<4xi32> + %2 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + // CHECK-NEXT: %3 = cmpi "sle", %arg0, %arg1 : tensor<4xi32> + %3 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + // CHECK-NEXT: %4 = cmpi "sgt", %arg0, %arg1 : tensor<4xi32> + %4 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + // CHECK-NEXT: %5 = cmpi "sge", %arg0, %arg1 : tensor<4xi32> + %5 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> // CHECK-NEXT: return %0, %1, %2, %3, %4, %5 : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1> return %0, %1, %2, %3, %4, %5 : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1> } // CHECK-LABEL: func @compare_float -func @compare_float(%arg0: tensor<4xf32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) { - // 
CHECK-NEXT: %0 = cmpf "oeq", %arg0, %arg0 : tensor<4xf32> - %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> - // CHECK-NEXT: %1 = cmpf "une", %arg0, %arg0 : tensor<4xf32> - %1 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> - // CHECK-NEXT: %2 = cmpf "olt", %arg0, %arg0 : tensor<4xf32> - %2 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> - // CHECK-NEXT: %3 = cmpf "ole", %arg0, %arg0 : tensor<4xf32> - %3 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> - // CHECK-NEXT: %4 = cmpf "ogt", %arg0, %arg0 : tensor<4xf32> - %4 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> - // CHECK-NEXT: %5 = cmpf "oge", %arg0, %arg0 : tensor<4xf32> - %5 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> +func @compare_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) { + // CHECK-NEXT: %0 = cmpf "oeq", %arg0, %arg1 : tensor<4xf32> + %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + // CHECK-NEXT: %1 = cmpf "une", %arg0, %arg1 : tensor<4xf32> + %1 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + // CHECK-NEXT: %2 = cmpf "olt", %arg0, %arg1 : tensor<4xf32> + %2 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + // CHECK-NEXT: %3 = cmpf "ole", %arg0, %arg1 : tensor<4xf32> + %3 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + // CHECK-NEXT: %4 = cmpf "ogt", %arg0, %arg1 : tensor<4xf32> + %4 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + // CHECK-NEXT: %5 = cmpf "oge", %arg0, %arg1 : tensor<4xf32> + %5 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> return %0, %1, %2, %3, %4, %5: tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1> } diff --git a/tensorflow/compiler/mlir/hlo/tests/legalize-trigonometric-to-approximation.mlir b/tensorflow/compiler/mlir/hlo/tests/legalize-trigonometric-to-approximation.mlir new file mode 100644 index 00000000000..c25545ca2bd --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/tests/legalize-trigonometric-to-approximation.mlir @@ -0,0 +1,380 @@ +// RUN: mlir-hlo-opt --mhlo-legalize-trigonometric-to-approximation --split-input-file %s | FileCheck %s + +func @tanh_f64(%arg0 : f64) -> f64 { + %res = tanh %arg0 : f64 + return %res : f64 +} + +// CHECK-LABEL: @tanh_f64 +// CHECK: tanh + +// ----- + +func @tanh_f32(%arg0 : f32) -> f32 { + %res = tanh %arg0 : f32 + return %res : f32 +} + +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// CHECK-LABEL: func @tanh_f32 +// CHECK-SAME: (%[[VAL_0:.*]]: f32) -> f32 +// CHECK: %[[VAL_1:.*]] = constant 4.000000e-04 : f32 +// CHECK: %[[VAL_2:.*]] = constant 7.90531111 : f32 +// CHECK: %[[VAL_3:.*]] = constant -7.90531111 : f32 +// CHECK: %[[VAL_4:.*]] = constant -2.76076837E-16 : f32 +// CHECK: %[[VAL_5:.*]] = constant 2.00018794E-13 : f32 +// CHECK: %[[VAL_6:.*]] = constant 
-8.60467184E-11 : f32 +// CHECK: %[[VAL_7:.*]] = constant 5.12229725E-8 : f32 +// CHECK: %[[VAL_8:.*]] = constant 1.48572235E-5 : f32 +// CHECK: %[[VAL_9:.*]] = constant 6.37261954E-4 : f32 +// CHECK: %[[VAL_10:.*]] = constant 0.00489352457 : f32 +// CHECK: %[[VAL_11:.*]] = constant 1.19825836E-6 : f32 +// CHECK: %[[VAL_12:.*]] = constant 1.18534706E-4 : f32 +// CHECK: %[[VAL_13:.*]] = constant 0.00226843474 : f32 +// CHECK: %[[VAL_14:.*]] = constant 0.00489352504 : f32 +// CHECK: %[[VAL_15:.*]] = absf %[[VAL_0]] : f32 +// CHECK: %[[VAL_16:.*]] = cmpf "olt", %[[VAL_15]], %[[VAL_1]] : f32 +// CHECK: %[[VAL_17:.*]] = cmpf "ule", %[[VAL_0]], %[[VAL_2]] : f32 +// CHECK: %[[VAL_18:.*]] = select %[[VAL_17]], %[[VAL_0]], %[[VAL_2]] : f32 +// CHECK: %[[VAL_19:.*]] = cmpf "uge", %[[VAL_18]], %[[VAL_3]] : f32 +// CHECK: %[[VAL_20:.*]] = select %[[VAL_19]], %[[VAL_18]], %[[VAL_3]] : f32 +// CHECK: %[[VAL_21:.*]] = mulf %[[VAL_20]], %[[VAL_20]] : f32 +// CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_4]] : f32 +// CHECK: %[[VAL_23:.*]] = addf %[[VAL_22]], %[[VAL_5]] : f32 +// CHECK: %[[VAL_24:.*]] = mulf %[[VAL_21]], %[[VAL_23]] : f32 +// CHECK: %[[VAL_25:.*]] = addf %[[VAL_24]], %[[VAL_6]] : f32 +// CHECK: %[[VAL_26:.*]] = mulf %[[VAL_21]], %[[VAL_25]] : f32 +// CHECK: %[[VAL_27:.*]] = addf %[[VAL_26]], %[[VAL_7]] : f32 +// CHECK: %[[VAL_28:.*]] = mulf %[[VAL_21]], %[[VAL_27]] : f32 +// CHECK: %[[VAL_29:.*]] = addf %[[VAL_28]], %[[VAL_8]] : f32 +// CHECK: %[[VAL_30:.*]] = mulf %[[VAL_21]], %[[VAL_29]] : f32 +// CHECK: %[[VAL_31:.*]] = addf %[[VAL_30]], %[[VAL_9]] : f32 +// CHECK: %[[VAL_32:.*]] = mulf %[[VAL_21]], %[[VAL_31]] : f32 +// CHECK: %[[VAL_33:.*]] = addf %[[VAL_32]], %[[VAL_10]] : f32 +// CHECK: %[[VAL_34:.*]] = mulf %[[VAL_20]], %[[VAL_33]] : f32 +// CHECK: %[[VAL_35:.*]] = mulf %[[VAL_21]], %[[VAL_11]] : f32 +// CHECK: %[[VAL_36:.*]] = addf %[[VAL_35]], %[[VAL_12]] : f32 +// CHECK: %[[VAL_37:.*]] = mulf %[[VAL_21]], %[[VAL_36]] : f32 +// CHECK: %[[VAL_38:.*]] = addf %[[VAL_37]], %[[VAL_13]] : f32 +// CHECK: %[[VAL_39:.*]] = mulf %[[VAL_21]], %[[VAL_38]] : f32 +// CHECK: %[[VAL_40:.*]] = addf %[[VAL_39]], %[[VAL_14]] : f32 +// CHECK: %[[VAL_41:.*]] = divf %[[VAL_34]], %[[VAL_40]] : f32 +// CHECK: %[[VAL_42:.*]] = select %[[VAL_16]], %[[VAL_0]], %[[VAL_41]] : f32 +// CHECK: return %[[VAL_42]] : f32 + +// ----- + +func @tanh_f16(%arg0 : f16) -> f16 { + %res = tanh %arg0 : f16 + return %res : f16 +} + +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// CHECK-LABEL: func @tanh_f16 +// CHECK-SAME: (%[[VAL_0:.*]]: f16) -> f16 +// CHECK: %[[VAL_1:.*]] = constant 4.000000e-04 : f32 +// CHECK: %[[VAL_2:.*]] = constant 7.90531111 : f32 +// CHECK: %[[VAL_3:.*]] = constant -7.90531111 : f32 +// CHECK: %[[VAL_4:.*]] = constant -2.76076837E-16 : f32 +// CHECK: %[[VAL_5:.*]] = constant 2.00018794E-13 : f32 +// CHECK: %[[VAL_6:.*]] = constant -8.60467184E-11 : f32 +// CHECK: %[[VAL_7:.*]] = constant 5.12229725E-8 : f32 +// CHECK: %[[VAL_8:.*]] = constant 1.48572235E-5 : f32 +// CHECK: %[[VAL_9:.*]] = constant 6.37261954E-4 : f32 +// CHECK: %[[VAL_10:.*]] = constant 0.00489352457 : f32 +// CHECK: %[[VAL_11:.*]] = constant 1.19825836E-6 : f32 +// CHECK: %[[VAL_12:.*]] = constant 1.18534706E-4 : f32 +// CHECK: %[[VAL_13:.*]] = constant 0.00226843474 : f32 +// CHECK: %[[VAL_14:.*]] = constant 0.00489352504 : f32 +// CHECK: %[[VAL_15:.*]] = fpext %[[VAL_0]] : f16 to f32 +// CHECK: %[[VAL_16:.*]] = absf %[[VAL_15]] : f32 +// CHECK: %[[VAL_17:.*]] = cmpf "olt", %[[VAL_16]], 
%[[VAL_1]] : f32 +// CHECK: %[[VAL_18:.*]] = cmpf "ule", %[[VAL_15]], %[[VAL_2]] : f32 +// CHECK: %[[VAL_19:.*]] = select %[[VAL_18]], %[[VAL_15]], %[[VAL_2]] : f32 +// CHECK: %[[VAL_20:.*]] = cmpf "uge", %[[VAL_19]], %[[VAL_3]] : f32 +// CHECK: %[[VAL_21:.*]] = select %[[VAL_20]], %[[VAL_19]], %[[VAL_3]] : f32 +// CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_21]] : f32 +// CHECK: %[[VAL_23:.*]] = mulf %[[VAL_22]], %[[VAL_4]] : f32 +// CHECK: %[[VAL_24:.*]] = addf %[[VAL_23]], %[[VAL_5]] : f32 +// CHECK: %[[VAL_25:.*]] = mulf %[[VAL_22]], %[[VAL_24]] : f32 +// CHECK: %[[VAL_26:.*]] = addf %[[VAL_25]], %[[VAL_6]] : f32 +// CHECK: %[[VAL_27:.*]] = mulf %[[VAL_22]], %[[VAL_26]] : f32 +// CHECK: %[[VAL_28:.*]] = addf %[[VAL_27]], %[[VAL_7]] : f32 +// CHECK: %[[VAL_29:.*]] = mulf %[[VAL_22]], %[[VAL_28]] : f32 +// CHECK: %[[VAL_30:.*]] = addf %[[VAL_29]], %[[VAL_8]] : f32 +// CHECK: %[[VAL_31:.*]] = mulf %[[VAL_22]], %[[VAL_30]] : f32 +// CHECK: %[[VAL_32:.*]] = addf %[[VAL_31]], %[[VAL_9]] : f32 +// CHECK: %[[VAL_33:.*]] = mulf %[[VAL_22]], %[[VAL_32]] : f32 +// CHECK: %[[VAL_34:.*]] = addf %[[VAL_33]], %[[VAL_10]] : f32 +// CHECK: %[[VAL_35:.*]] = mulf %[[VAL_21]], %[[VAL_34]] : f32 +// CHECK: %[[VAL_36:.*]] = mulf %[[VAL_22]], %[[VAL_11]] : f32 +// CHECK: %[[VAL_37:.*]] = addf %[[VAL_36]], %[[VAL_12]] : f32 +// CHECK: %[[VAL_38:.*]] = mulf %[[VAL_22]], %[[VAL_37]] : f32 +// CHECK: %[[VAL_39:.*]] = addf %[[VAL_38]], %[[VAL_13]] : f32 +// CHECK: %[[VAL_40:.*]] = mulf %[[VAL_22]], %[[VAL_39]] : f32 +// CHECK: %[[VAL_41:.*]] = addf %[[VAL_40]], %[[VAL_14]] : f32 +// CHECK: %[[VAL_42:.*]] = divf %[[VAL_35]], %[[VAL_41]] : f32 +// CHECK: %[[VAL_43:.*]] = select %[[VAL_17]], %[[VAL_15]], %[[VAL_42]] : f32 +// CHECK: %[[VAL_44:.*]] = fptrunc %[[VAL_43]] : f32 to f16 +// CHECK: return %[[VAL_44]] : f16 + +// ----- + +// CHECK-LABEL: @atan2_f64 +func @atan2_f64(%arg0 : f64, %arg1 : f64) -> f64 { + // CHECK: atan2 + %res = atan2 %arg0, %arg1 : f64 + return %res : f64 +} + +// ----- + +// CHECK-LABEL: func @atan2_f32 +// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32) -> f32 +func @atan2_f32(%arg0 : f32, %arg1 : f32) -> f32 { + // CHECK: %[[CST:.*]] = constant 0.0027856871 : f32 + // CHECK: %[[CST_0:.*]] = constant -1.586600e-02 : f32 + // CHECK: %[[CST_1:.*]] = constant 0.042472221 : f32 + // CHECK: %[[CST_2:.*]] = constant -0.0749753043 : f32 + // CHECK: %[[CST_3:.*]] = constant 0.106448799 : f32 + // CHECK: %[[CST_4:.*]] = constant -0.142070308 : f32 + // CHECK: %[[CST_5:.*]] = constant 0.199934542 : f32 + // CHECK: %[[CST_6:.*]] = constant -0.333331466 : f32 + // CHECK: %[[CST_7:.*]] = constant 1.57079637 : f32 + // CHECK: %[[CST_8:.*]] = constant 0.000000e+00 : f32 + // CHECK: %[[CST_9:.*]] = constant 3.14159274 : f32 + // CHECK: %[[CST_10:.*]] = constant 0x7FC00000 : f32 + // CHECK: %[[CST_11:.*]] = constant 2.3561945 : f32 + // CHECK: %[[CST_12:.*]] = constant 0.785398185 : f32 + // CHECK: %[[CST_13:.*]] = constant 0x7F800000 : f32 + // CHECK: %[[VAL_0:.*]] = absf %[[ARG1]] : f32 + // CHECK: %[[VAL_1:.*]] = absf %[[ARG0]] : f32 + // CHECK: %[[VAL_2:.*]] = cmpf "ole", %[[VAL_0]], %[[VAL_1]] : f32 + // CHECK: %[[VAL_3:.*]] = select %[[VAL_2]], %[[VAL_0]], %[[VAL_1]] : f32 + // CHECK: %[[VAL_4:.*]] = select %[[VAL_2]], %[[VAL_1]], %[[VAL_0]] : f32 + // CHECK: %[[VAL_5:.*]] = divf %[[VAL_3]], %[[VAL_4]] : f32 + // CHECK: %[[VAL_6:.*]] = mulf %[[VAL_5]], %[[VAL_5]] : f32 + // CHECK: %[[VAL_7:.*]] = mulf %[[CST]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_8:.*]] = addf %[[VAL_7]], %[[CST_0]] : f32 + 
// CHECK: %[[VAL_9:.*]] = mulf %[[VAL_8]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_10:.*]] = addf %[[VAL_9]], %[[CST_1]] : f32 + // CHECK: %[[VAL_11:.*]] = mulf %[[VAL_10]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_12:.*]] = addf %[[VAL_11]], %[[CST_2]] : f32 + // CHECK: %[[VAL_13:.*]] = mulf %[[VAL_12]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_14:.*]] = addf %[[VAL_13]], %[[CST_3]] : f32 + // CHECK: %[[VAL_15:.*]] = mulf %[[VAL_14]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_16:.*]] = addf %[[VAL_15]], %[[CST_4]] : f32 + // CHECK: %[[VAL_17:.*]] = mulf %[[VAL_16]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_18:.*]] = addf %[[VAL_17]], %[[CST_5]] : f32 + // CHECK: %[[VAL_19:.*]] = mulf %[[VAL_18]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_20:.*]] = addf %[[VAL_19]], %[[CST_6]] : f32 + // CHECK: %[[VAL_21:.*]] = mulf %[[VAL_20]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_5]] : f32 + // CHECK: %[[VAL_23:.*]] = addf %[[VAL_22]], %[[VAL_5]] : f32 + // CHECK: %[[VAL_24:.*]] = subf %[[CST_7]], %[[VAL_23]] : f32 + // CHECK: %[[VAL_25:.*]] = select %[[VAL_2]], %[[VAL_24]], %[[VAL_23]] : f32 + // CHECK: %[[VAL_26:.*]] = cmpf "olt", %[[ARG1]], %[[CST_8]] : f32 + // CHECK: %[[VAL_27:.*]] = subf %[[CST_9]], %[[VAL_25]] : f32 + // CHECK: %[[VAL_28:.*]] = select %[[VAL_26]], %[[VAL_27]], %[[VAL_25]] : f32 + // CHECK: %[[VAL_29:.*]] = select %[[VAL_26]], %[[CST_9]], %[[CST_8]] : f32 + // CHECK: %[[VAL_30:.*]] = cmpf "oeq", %[[ARG0]], %[[CST_8]] : f32 + // CHECK: %[[VAL_31:.*]] = select %[[VAL_30]], %[[VAL_29]], %[[VAL_28]] : f32 + // CHECK: %[[VAL_32:.*]] = cmpf "uno", %[[ARG0]], %[[ARG1]] : f32 + // CHECK: %[[VAL_35:.*]] = select %[[VAL_32]], %[[CST_10]], %[[VAL_31]] : f32 + // CHECK: %[[VAL_36:.*]] = select %[[VAL_26]], %[[CST_11]], %[[CST_12]] : f32 + // CHECK: %[[VAL_37:.*]] = cmpf "oeq", %[[ARG1]], %[[CST_13]] : f32 + // CHECK: %[[VAL_38:.*]] = cmpf "oeq", %[[ARG0]], %[[CST_13]] : f32 + // CHECK: %[[VAL_39:.*]] = and %[[VAL_37]], %[[VAL_38]] : i1 + // CHECK: %[[VAL_40:.*]] = select %[[VAL_39]], %[[VAL_36]], %[[VAL_35]] : f32 + // CHECK: %[[VAL_41:.*]] = copysign %[[VAL_40]], %[[ARG0]] : f32 + // CHECK: return %[[VAL_41]] : f32 + %res = atan2 %arg0, %arg1 : f32 + return %res : f32 +} + +// ----- + +// CHECK-LABEL: @atan2_f16 +// CHECK-SAME: (%[[ARG0:.*]]: f16, %[[ARG1:.*]]: f16) -> f16 +func @atan2_f16(%arg0 : f16, %arg1 : f16) -> f16 { + // CHECK: %[[CST:.*]] = constant 0.0027856871 : f32 + // CHECK: %[[CST_0:.*]] = constant -1.586600e-02 : f32 + // CHECK: %[[CST_1:.*]] = constant 0.042472221 : f32 + // CHECK: %[[CST_2:.*]] = constant -0.0749753043 : f32 + // CHECK: %[[CST_3:.*]] = constant 0.106448799 : f32 + // CHECK: %[[CST_4:.*]] = constant -0.142070308 : f32 + // CHECK: %[[CST_5:.*]] = constant 0.199934542 : f32 + // CHECK: %[[CST_6:.*]] = constant -0.333331466 : f32 + // CHECK: %[[CST_7:.*]] = constant 1.57079637 : f32 + // CHECK: %[[CST_8:.*]] = constant 0.000000e+00 : f32 + // CHECK: %[[CST_9:.*]] = constant 3.14159274 : f32 + // CHECK: %[[CST_10:.*]] = constant 0x7FC00000 : f32 + // CHECK: %[[CST_11:.*]] = constant 2.3561945 : f32 + // CHECK: %[[CST_12:.*]] = constant 0.785398185 : f32 + // CHECK: %[[CST_13:.*]] = constant 0x7F800000 : f32 + // CHECK: %[[VAL_0:.*]] = fpext %[[ARG0]] : f16 to f32 + // CHECK: %[[VAL_1:.*]] = fpext %[[ARG1]] : f16 to f32 + // CHECK: %[[VAL_2:.*]] = absf %[[VAL_1]] : f32 + // CHECK: %[[VAL_3:.*]] = absf %[[VAL_0]] : f32 + // CHECK: %[[VAL_4:.*]] = cmpf "ole", %[[VAL_2]], %[[VAL_3]] : f32 + // CHECK: %[[VAL_5:.*]] = select %[[VAL_4]], %[[VAL_2]], %[[VAL_3]] : f32 
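+  // The select above picks min(|x|, |y|) and the one below picks max(|x|, |y|),
+  // so the divf that follows forms t = min/max in [0, 1]. The CHECK lines after
+  // that trace atan(t) ~= t + t^3 * q(t^2), with the first eight constants above
+  // as the Horner coefficients of q, followed by the pi/2 and pi quadrant
+  // corrections and the final copysign with the original y operand.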
+ // CHECK: %[[VAL_6:.*]] = select %[[VAL_4]], %[[VAL_3]], %[[VAL_2]] : f32 + // CHECK: %[[VAL_7:.*]] = divf %[[VAL_5]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_8:.*]] = mulf %[[VAL_7]], %[[VAL_7]] : f32 + // CHECK: %[[VAL_9:.*]] = mulf %[[CST]], %[[VAL_8]] : f32 + // CHECK: %[[VAL_10:.*]] = addf %[[VAL_9]], %[[CST_0]] : f32 + // CHECK: %[[VAL_11:.*]] = mulf %[[VAL_10]], %[[VAL_8]] : f32 + // CHECK: %[[VAL_12:.*]] = addf %[[VAL_11]], %[[CST_1]] : f32 + // CHECK: %[[VAL_13:.*]] = mulf %[[VAL_12]], %[[VAL_8]] : f32 + // CHECK: %[[VAL_14:.*]] = addf %[[VAL_13]], %[[CST_2]] : f32 + // CHECK: %[[VAL_15:.*]] = mulf %[[VAL_14]], %[[VAL_8]] : f32 + // CHECK: %[[VAL_16:.*]] = addf %[[VAL_15]], %[[CST_3]] : f32 + // CHECK: %[[VAL_17:.*]] = mulf %[[VAL_16]], %[[VAL_8]] : f32 + // CHECK: %[[VAL_18:.*]] = addf %[[VAL_17]], %[[CST_4]] : f32 + // CHECK: %[[VAL_19:.*]] = mulf %[[VAL_18]], %[[VAL_8]] : f32 + // CHECK: %[[VAL_20:.*]] = addf %[[VAL_19]], %[[CST_5]] : f32 + // CHECK: %[[VAL_21:.*]] = mulf %[[VAL_20]], %[[VAL_8]] : f32 + // CHECK: %[[VAL_22:.*]] = addf %[[VAL_21]], %[[CST_6]] : f32 + // CHECK: %[[VAL_23:.*]] = mulf %[[VAL_22]], %[[VAL_8]] : f32 + // CHECK: %[[VAL_24:.*]] = mulf %[[VAL_23]], %[[VAL_7]] : f32 + // CHECK: %[[VAL_25:.*]] = addf %[[VAL_24]], %[[VAL_7]] : f32 + // CHECK: %[[VAL_26:.*]] = subf %[[CST_7]], %[[VAL_25]] : f32 + // CHECK: %[[VAL_27:.*]] = select %[[VAL_4]], %[[VAL_26]], %[[VAL_25]] : f32 + // CHECK: %[[VAL_28:.*]] = cmpf "olt", %[[VAL_1]], %[[CST_8]] : f32 + // CHECK: %[[VAL_29:.*]] = subf %[[CST_9]], %[[VAL_27]] : f32 + // CHECK: %[[VAL_30:.*]] = select %[[VAL_28]], %[[VAL_29]], %[[VAL_27]] : f32 + // CHECK: %[[VAL_31:.*]] = select %[[VAL_28]], %[[CST_9]], %[[CST_8]] : f32 + // CHECK: %[[VAL_32:.*]] = cmpf "oeq", %[[VAL_0]], %[[CST_8]] : f32 + // CHECK: %[[VAL_33:.*]] = select %[[VAL_32]], %[[VAL_31]], %[[VAL_30]] : f32 + // CHECK: %[[VAL_34:.*]] = cmpf "uno", %[[VAL_0]], %[[VAL_1]] : f32 + // CHECK: %[[VAL_37:.*]] = select %[[VAL_34]], %[[CST_10]], %[[VAL_33]] : f32 + // CHECK: %[[VAL_38:.*]] = select %[[VAL_28]], %[[CST_11]], %[[CST_12]] : f32 + // CHECK: %[[VAL_39:.*]] = cmpf "oeq", %[[VAL_1]], %[[CST_13]] : f32 + // CHECK: %[[VAL_40:.*]] = cmpf "oeq", %[[VAL_0]], %[[CST_13]] : f32 + // CHECK: %[[VAL_41:.*]] = and %[[VAL_39]], %[[VAL_40]] : i1 + // CHECK: %[[VAL_42:.*]] = select %[[VAL_41]], %[[VAL_38]], %[[VAL_37]] : f32 + // CHECK: %[[VAL_43:.*]] = copysign %[[VAL_42]], %[[VAL_0]] : f32 + // CHECK: %[[VAL_44:.*]] = fptrunc %[[VAL_43]] : f32 to f16 + // CHECK: return %[[VAL_44]] : f16 + %res = atan2 %arg0, %arg1 : f16 + return %res : f16 +} + +// ----- + +// CHECK-LABEL: @atan_f64 +func @atan_f64(%arg : f64) -> f64 { + // CHECK: atan + %res = atan %arg : f64 + return %res : f64 +} + +// ----- + +// CHECK-LABEL: func @atan_f32 +// CHECK-SAME: (%[[ARG:.*]]: f32) -> f32 +func @atan_f32(%arg : f32) -> f32 { + // CHECK: %[[CST:.*]] = constant 1.000000e+00 : f32 + // CHECK: %[[CST_0:.*]] = constant 0.0027856871 : f32 + // CHECK: %[[CST_1:.*]] = constant -1.586600e-02 : f32 + // CHECK: %[[CST_2:.*]] = constant 0.042472221 : f32 + // CHECK: %[[CST_3:.*]] = constant -0.0749753043 : f32 + // CHECK: %[[CST_4:.*]] = constant 0.106448799 : f32 + // CHECK: %[[CST_5:.*]] = constant -0.142070308 : f32 + // CHECK: %[[CST_6:.*]] = constant 0.199934542 : f32 + // CHECK: %[[CST_7:.*]] = constant -0.333331466 : f32 + // CHECK: %[[CST_8:.*]] = constant 1.57079637 : f32 + // CHECK: %[[CST_9:.*]] = constant 0.000000e+00 : f32 + // CHECK: %[[CST_10:.*]] = constant 0x7FC00000 : f32 + // CHECK: 
%[[VAL_0:.*]] = absf %[[CST]] : f32 + // CHECK: %[[VAL_1:.*]] = absf %arg0 : f32 + // CHECK: %[[VAL_2:.*]] = cmpf "ole", %[[VAL_0]], %[[VAL_1]] : f32 + // CHECK: %[[VAL_3:.*]] = select %[[VAL_2]], %[[VAL_0]], %[[VAL_1]] : f32 + // CHECK: %[[VAL_4:.*]] = select %[[VAL_2]], %[[VAL_1]], %[[VAL_0]] : f32 + // CHECK: %[[VAL_5:.*]] = divf %[[VAL_3]], %[[VAL_4]] : f32 + // CHECK: %[[VAL_6:.*]] = mulf %[[VAL_5]], %[[VAL_5]] : f32 + // CHECK: %[[VAL_7:.*]] = mulf %[[CST_0]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_8:.*]] = addf %[[VAL_7]], %[[CST_1]] : f32 + // CHECK: %[[VAL_9:.*]] = mulf %[[VAL_8]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_10:.*]] = addf %[[VAL_9]], %[[CST_2]] : f32 + // CHECK: %[[VAL_11:.*]] = mulf %[[VAL_10]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_12:.*]] = addf %[[VAL_11]], %[[CST_3]] : f32 + // CHECK: %[[VAL_13:.*]] = mulf %[[VAL_12]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_14:.*]] = addf %[[VAL_13]], %[[CST_4]] : f32 + // CHECK: %[[VAL_15:.*]] = mulf %[[VAL_14]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_16:.*]] = addf %[[VAL_15]], %[[CST_5]] : f32 + // CHECK: %[[VAL_17:.*]] = mulf %[[VAL_16]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_18:.*]] = addf %[[VAL_17]], %[[CST_6]] : f32 + // CHECK: %[[VAL_19:.*]] = mulf %[[VAL_18]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_20:.*]] = addf %[[VAL_19]], %[[CST_7]] : f32 + // CHECK: %[[VAL_21:.*]] = mulf %[[VAL_20]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_5]] : f32 + // CHECK: %[[VAL_23:.*]] = addf %[[VAL_22]], %[[VAL_5]] : f32 + // CHECK: %[[VAL_24:.*]] = subf %[[CST_8]], %[[VAL_23]] : f32 + // CHECK: %[[VAL_25:.*]] = select %[[VAL_2]], %[[VAL_24]], %[[VAL_23]] : f32 + // CHECK: %[[VAL_26:.*]] = cmpf "oeq", %arg0, %[[CST_9]] : f32 + // CHECK: %[[VAL_27:.*]] = select %[[VAL_26]], %[[CST_9]], %[[VAL_25]] : f32 + // CHECK: %[[VAL_28:.*]] = cmpf "uno", %arg0, %[[CST]] : f32 + // CHECK: %[[VAL_29:.*]] = select %[[VAL_28]], %[[CST_10]], %[[VAL_27]] : f32 + // CHECK: %[[VAL_30:.*]] = copysign %[[VAL_29]], %arg0 : f32 + // CHECK: return %[[VAL_30]] : f32 + %res = atan %arg : f32 + return %res : f32 +} + +// ----- + +// CHECK-LABEL: @atan_f16 +// CHECK-SAME: (%[[ARG:.*]]: f16) -> f16 +func @atan_f16(%arg : f16) -> f16 { + // CHECK: %[[CST:.*]] = constant 1.000000e+00 : f32 + // CHECK: %[[CST_0:.*]] = constant 0.0027856871 : f32 + // CHECK: %[[CST_1:.*]] = constant -1.586600e-02 : f32 + // CHECK: %[[CST_2:.*]] = constant 0.042472221 : f32 + // CHECK: %[[CST_3:.*]] = constant -0.0749753043 : f32 + // CHECK: %[[CST_4:.*]] = constant 0.106448799 : f32 + // CHECK: %[[CST_5:.*]] = constant -0.142070308 : f32 + // CHECK: %[[CST_6:.*]] = constant 0.199934542 : f32 + // CHECK: %[[CST_7:.*]] = constant -0.333331466 : f32 + // CHECK: %[[CST_8:.*]] = constant 1.57079637 : f32 + // CHECK: %[[CST_9:.*]] = constant 0.000000e+00 : f32 + // CHECK: %[[CST_10:.*]] = constant 0x7FC00000 : f32 + // CHECK: %[[VAL_0:.*]] = fpext %arg0 : f16 to f32 + // CHECK: %[[VAL_1:.*]] = absf %[[CST]] : f32 + // CHECK: %[[VAL_2:.*]] = absf %[[VAL_0]] : f32 + // CHECK: %[[VAL_3:.*]] = cmpf "ole", %[[VAL_1]], %[[VAL_2]] : f32 + // CHECK: %[[VAL_4:.*]] = select %[[VAL_3]], %[[VAL_1]], %[[VAL_2]] : f32 + // CHECK: %[[VAL_5:.*]] = select %[[VAL_3]], %[[VAL_2]], %[[VAL_1]] : f32 + // CHECK: %[[VAL_6:.*]] = divf %[[VAL_4]], %[[VAL_5]] : f32 + // CHECK: %[[VAL_7:.*]] = mulf %[[VAL_6]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_8:.*]] = mulf %[[CST_0]], %[[VAL_7]] : f32 + // CHECK: %[[VAL_9:.*]] = addf %[[VAL_8]], %[[CST_1]] : f32 + // CHECK: %[[VAL_10:.*]] = mulf %[[VAL_9]], %[[VAL_7]] : f32 
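+  // Horner evaluation of the same odd atan approximation as above: the
+  // accumulator is repeatedly multiplied by t^2 (%[[VAL_7]]) and the next
+  // coefficient is added, before the final multiply by t (%[[VAL_6]]) and the
+  // pi/2 fold-over applied when 1 <= |x|.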
+ // CHECK: %[[VAL_11:.*]] = addf %[[VAL_10]], %[[CST_2]] : f32 + // CHECK: %[[VAL_12:.*]] = mulf %[[VAL_11]], %[[VAL_7]] : f32 + // CHECK: %[[VAL_13:.*]] = addf %[[VAL_12]], %[[CST_3]] : f32 + // CHECK: %[[VAL_14:.*]] = mulf %[[VAL_13]], %[[VAL_7]] : f32 + // CHECK: %[[VAL_15:.*]] = addf %[[VAL_14]], %[[CST_4]] : f32 + // CHECK: %[[VAL_16:.*]] = mulf %[[VAL_15]], %[[VAL_7]] : f32 + // CHECK: %[[VAL_17:.*]] = addf %[[VAL_16]], %[[CST_5]] : f32 + // CHECK: %[[VAL_18:.*]] = mulf %[[VAL_17]], %[[VAL_7]] : f32 + // CHECK: %[[VAL_19:.*]] = addf %[[VAL_18]], %[[CST_6]] : f32 + // CHECK: %[[VAL_20:.*]] = mulf %[[VAL_19]], %[[VAL_7]] : f32 + // CHECK: %[[VAL_21:.*]] = addf %[[VAL_20]], %[[CST_7]] : f32 + // CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_7]] : f32 + // CHECK: %[[VAL_23:.*]] = mulf %[[VAL_22]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_24:.*]] = addf %[[VAL_23]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_25:.*]] = subf %[[CST_8]], %[[VAL_24]] : f32 + // CHECK: %[[VAL_26:.*]] = select %[[VAL_3]], %[[VAL_25]], %[[VAL_24]] : f32 + // CHECK: %[[VAL_27:.*]] = cmpf "oeq", %[[VAL_0]], %[[CST_9]] : f32 + // CHECK: %[[VAL_28:.*]] = select %[[VAL_27]], %[[CST_9]], %[[VAL_26]] : f32 + // CHECK: %[[VAL_29:.*]] = cmpf "uno", %[[VAL_0]], %[[CST]] : f32 + // CHECK: %[[VAL_30:.*]] = select %[[VAL_29]], %[[CST_10]], %[[VAL_28]] : f32 + // CHECK: %[[VAL_31:.*]] = copysign %[[VAL_30]], %[[VAL_0]] : f32 + // CHECK: %[[VAL_32:.*]] = fptrunc %[[VAL_31]] : f32 to f16 + // CHECK: return %[[VAL_32]] : f16 + %res = atan %arg : f16 + return %res : f16 +} diff --git a/tensorflow/compiler/mlir/hlo/tests/legalize_tanh_to_approximation.mlir b/tensorflow/compiler/mlir/hlo/tests/legalize_tanh_to_approximation.mlir deleted file mode 100644 index aa834d36ac4..00000000000 --- a/tensorflow/compiler/mlir/hlo/tests/legalize_tanh_to_approximation.mlir +++ /dev/null @@ -1,125 +0,0 @@ -// RUN: mlir-hlo-opt -mhlo-legalize-tanh-to-approximation -split-input-file %s | FileCheck %s - -func @tanh_f64(%arg0 : f64) -> f64 { - %res = tanh %arg0 : f64 - return %res : f64 -} - -// CHECK-LABEL: @tanh_f64 -// CHECK: tanh - -// ----- - -func @tanh_f32(%arg0 : f32) -> f32 { - %res = tanh %arg0 : f32 - return %res : f32 -} - -// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py - -// CHECK-LABEL: func @tanh_f32( -// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 { -// CHECK: %[[VAL_1:.*]] = constant 4.000000e-04 : f32 -// CHECK: %[[VAL_2:.*]] = constant 7.90531111 : f32 -// CHECK: %[[VAL_3:.*]] = constant -7.90531111 : f32 -// CHECK: %[[VAL_4:.*]] = constant -2.76076837E-16 : f32 -// CHECK: %[[VAL_5:.*]] = constant 2.00018794E-13 : f32 -// CHECK: %[[VAL_6:.*]] = constant -8.60467184E-11 : f32 -// CHECK: %[[VAL_7:.*]] = constant 5.12229725E-8 : f32 -// CHECK: %[[VAL_8:.*]] = constant 1.48572235E-5 : f32 -// CHECK: %[[VAL_9:.*]] = constant 6.37261954E-4 : f32 -// CHECK: %[[VAL_10:.*]] = constant 0.00489352457 : f32 -// CHECK: %[[VAL_11:.*]] = constant 1.19825836E-6 : f32 -// CHECK: %[[VAL_12:.*]] = constant 1.18534706E-4 : f32 -// CHECK: %[[VAL_13:.*]] = constant 0.00226843474 : f32 -// CHECK: %[[VAL_14:.*]] = constant 0.00489352504 : f32 -// CHECK: %[[VAL_15:.*]] = absf %[[VAL_0]] : f32 -// CHECK: %[[VAL_16:.*]] = cmpf "olt", %[[VAL_15]], %[[VAL_1]] : f32 -// CHECK: %[[VAL_17:.*]] = cmpf "ule", %[[VAL_0]], %[[VAL_2]] : f32 -// CHECK: %[[VAL_18:.*]] = select %[[VAL_17]], %[[VAL_0]], %[[VAL_2]] : f32 -// CHECK: %[[VAL_19:.*]] = cmpf "uge", %[[VAL_18]], %[[VAL_3]] : f32 -// CHECK: %[[VAL_20:.*]] = select %[[VAL_19]], %[[VAL_18]], 
%[[VAL_3]] : f32 -// CHECK: %[[VAL_21:.*]] = mulf %[[VAL_20]], %[[VAL_20]] : f32 -// CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_4]] : f32 -// CHECK: %[[VAL_23:.*]] = addf %[[VAL_22]], %[[VAL_5]] : f32 -// CHECK: %[[VAL_24:.*]] = mulf %[[VAL_21]], %[[VAL_23]] : f32 -// CHECK: %[[VAL_25:.*]] = addf %[[VAL_24]], %[[VAL_6]] : f32 -// CHECK: %[[VAL_26:.*]] = mulf %[[VAL_21]], %[[VAL_25]] : f32 -// CHECK: %[[VAL_27:.*]] = addf %[[VAL_26]], %[[VAL_7]] : f32 -// CHECK: %[[VAL_28:.*]] = mulf %[[VAL_21]], %[[VAL_27]] : f32 -// CHECK: %[[VAL_29:.*]] = addf %[[VAL_28]], %[[VAL_8]] : f32 -// CHECK: %[[VAL_30:.*]] = mulf %[[VAL_21]], %[[VAL_29]] : f32 -// CHECK: %[[VAL_31:.*]] = addf %[[VAL_30]], %[[VAL_9]] : f32 -// CHECK: %[[VAL_32:.*]] = mulf %[[VAL_21]], %[[VAL_31]] : f32 -// CHECK: %[[VAL_33:.*]] = addf %[[VAL_32]], %[[VAL_10]] : f32 -// CHECK: %[[VAL_34:.*]] = mulf %[[VAL_20]], %[[VAL_33]] : f32 -// CHECK: %[[VAL_35:.*]] = mulf %[[VAL_21]], %[[VAL_11]] : f32 -// CHECK: %[[VAL_36:.*]] = addf %[[VAL_35]], %[[VAL_12]] : f32 -// CHECK: %[[VAL_37:.*]] = mulf %[[VAL_21]], %[[VAL_36]] : f32 -// CHECK: %[[VAL_38:.*]] = addf %[[VAL_37]], %[[VAL_13]] : f32 -// CHECK: %[[VAL_39:.*]] = mulf %[[VAL_21]], %[[VAL_38]] : f32 -// CHECK: %[[VAL_40:.*]] = addf %[[VAL_39]], %[[VAL_14]] : f32 -// CHECK: %[[VAL_41:.*]] = divf %[[VAL_34]], %[[VAL_40]] : f32 -// CHECK: %[[VAL_42:.*]] = select %[[VAL_16]], %[[VAL_0]], %[[VAL_41]] : f32 -// CHECK: return %[[VAL_42]] : f32 -// CHECK: } - -// ----- - -func @tanh_f16(%arg0 : f16) -> f16 { - %res = tanh %arg0 : f16 - return %res : f16 -} - -// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py - -// CHECK-LABEL: func @tanh_f16( -// CHECK-SAME: %[[VAL_0:.*]]: f16) -> f16 { -// CHECK: %[[VAL_1:.*]] = constant 4.000000e-04 : f32 -// CHECK: %[[VAL_2:.*]] = constant 7.90531111 : f32 -// CHECK: %[[VAL_3:.*]] = constant -7.90531111 : f32 -// CHECK: %[[VAL_4:.*]] = constant -2.76076837E-16 : f32 -// CHECK: %[[VAL_5:.*]] = constant 2.00018794E-13 : f32 -// CHECK: %[[VAL_6:.*]] = constant -8.60467184E-11 : f32 -// CHECK: %[[VAL_7:.*]] = constant 5.12229725E-8 : f32 -// CHECK: %[[VAL_8:.*]] = constant 1.48572235E-5 : f32 -// CHECK: %[[VAL_9:.*]] = constant 6.37261954E-4 : f32 -// CHECK: %[[VAL_10:.*]] = constant 0.00489352457 : f32 -// CHECK: %[[VAL_11:.*]] = constant 1.19825836E-6 : f32 -// CHECK: %[[VAL_12:.*]] = constant 1.18534706E-4 : f32 -// CHECK: %[[VAL_13:.*]] = constant 0.00226843474 : f32 -// CHECK: %[[VAL_14:.*]] = constant 0.00489352504 : f32 -// CHECK: %[[VAL_15:.*]] = fpext %[[VAL_0]] : f16 to f32 -// CHECK: %[[VAL_16:.*]] = absf %[[VAL_15]] : f32 -// CHECK: %[[VAL_17:.*]] = cmpf "olt", %[[VAL_16]], %[[VAL_1]] : f32 -// CHECK: %[[VAL_18:.*]] = cmpf "ule", %[[VAL_15]], %[[VAL_2]] : f32 -// CHECK: %[[VAL_19:.*]] = select %[[VAL_18]], %[[VAL_15]], %[[VAL_2]] : f32 -// CHECK: %[[VAL_20:.*]] = cmpf "uge", %[[VAL_19]], %[[VAL_3]] : f32 -// CHECK: %[[VAL_21:.*]] = select %[[VAL_20]], %[[VAL_19]], %[[VAL_3]] : f32 -// CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_21]] : f32 -// CHECK: %[[VAL_23:.*]] = mulf %[[VAL_22]], %[[VAL_4]] : f32 -// CHECK: %[[VAL_24:.*]] = addf %[[VAL_23]], %[[VAL_5]] : f32 -// CHECK: %[[VAL_25:.*]] = mulf %[[VAL_22]], %[[VAL_24]] : f32 -// CHECK: %[[VAL_26:.*]] = addf %[[VAL_25]], %[[VAL_6]] : f32 -// CHECK: %[[VAL_27:.*]] = mulf %[[VAL_22]], %[[VAL_26]] : f32 -// CHECK: %[[VAL_28:.*]] = addf %[[VAL_27]], %[[VAL_7]] : f32 -// CHECK: %[[VAL_29:.*]] = mulf %[[VAL_22]], %[[VAL_28]] : f32 -// CHECK: %[[VAL_30:.*]] = addf 
%[[VAL_29]], %[[VAL_8]] : f32 -// CHECK: %[[VAL_31:.*]] = mulf %[[VAL_22]], %[[VAL_30]] : f32 -// CHECK: %[[VAL_32:.*]] = addf %[[VAL_31]], %[[VAL_9]] : f32 -// CHECK: %[[VAL_33:.*]] = mulf %[[VAL_22]], %[[VAL_32]] : f32 -// CHECK: %[[VAL_34:.*]] = addf %[[VAL_33]], %[[VAL_10]] : f32 -// CHECK: %[[VAL_35:.*]] = mulf %[[VAL_21]], %[[VAL_34]] : f32 -// CHECK: %[[VAL_36:.*]] = mulf %[[VAL_22]], %[[VAL_11]] : f32 -// CHECK: %[[VAL_37:.*]] = addf %[[VAL_36]], %[[VAL_12]] : f32 -// CHECK: %[[VAL_38:.*]] = mulf %[[VAL_22]], %[[VAL_37]] : f32 -// CHECK: %[[VAL_39:.*]] = addf %[[VAL_38]], %[[VAL_13]] : f32 -// CHECK: %[[VAL_40:.*]] = mulf %[[VAL_22]], %[[VAL_39]] : f32 -// CHECK: %[[VAL_41:.*]] = addf %[[VAL_40]], %[[VAL_14]] : f32 -// CHECK: %[[VAL_42:.*]] = divf %[[VAL_35]], %[[VAL_41]] : f32 -// CHECK: %[[VAL_43:.*]] = select %[[VAL_17]], %[[VAL_15]], %[[VAL_42]] : f32 -// CHECK: %[[VAL_44:.*]] = fptrunc %[[VAL_43]] : f32 to f16 -// CHECK: return %[[VAL_44]] : f16 -// CHECK: } - - diff --git a/tensorflow/compiler/mlir/hlo/tests/legalize_to_scf.mlir b/tensorflow/compiler/mlir/hlo/tests/legalize_to_scf.mlir new file mode 100644 index 00000000000..9c887a73a0f --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/tests/legalize_to_scf.mlir @@ -0,0 +1,38 @@ +// RUN: mlir-hlo-opt --mhlo-control-flow-to-scf %s | FileCheck %s + +func @lt_loop(%arg0: tensor<4xf32>, %arg1: tensor, %arg2: tensor, %arg3: tensor<4xf32>, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor) -> (tuple, tensor, tensor>) { + %cst = constant dense<-1> : tensor + %cst_0 = constant dense<1> : tensor + %cst_1 = constant dense<0> : tensor + %cst_2 = constant dense<1000> : tensor + %0 = "mhlo.tuple"(%cst_1, %cst, %cst_2) : (tensor, tensor, tensor) -> tuple, tensor, tensor> + %1 = "mhlo.while"(%0) ( { + ^bb0(%arg9: tuple, tensor, tensor>): // no predecessors + %2 = "mhlo.get_tuple_element"(%arg9) {index = 0 : i32} : (tuple, tensor, tensor>) -> tensor + %3 = "mhlo.get_tuple_element"(%arg9) {index = 2 : i32} : (tuple, tensor, tensor>) -> tensor + %4 = "mhlo.compare"(%2, %3) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + "mhlo.return"(%4) : (tensor) -> () + }, { + ^bb0(%arg9: tuple, tensor, tensor>): // no predecessors + %2 = "mhlo.get_tuple_element"(%arg9) {index = 0 : i32} : (tuple, tensor, tensor>) -> tensor + %3 = mhlo.add %2, %cst_0 : tensor + %4 = "mhlo.get_tuple_element"(%arg9) {index = 1 : i32} : (tuple, tensor, tensor>) -> tensor + %5 = "mhlo.get_tuple_element"(%arg9) {index = 2 : i32} : (tuple, tensor, tensor>) -> tensor + %6 = "mhlo.tuple"(%3, %4, %5) : (tensor, tensor, tensor) -> tuple, tensor, tensor> + "mhlo.return"(%6) : (tuple, tensor, tensor>) -> () + }) : (tuple, tensor, tensor>) -> tuple, tensor, tensor> + return %1 : tuple, tensor, tensor> +} + +// CHECK-LABEL: func @lt_loop( +// CHECK: %[[VAL_9:.*]] = constant dense<-1> : tensor +// CHECK: %[[VAL_10:.*]] = constant dense<1> : tensor +// CHECK: %[[VAL_11:.*]] = constant dense<0> : tensor +// CHECK: %[[VAL_12:.*]] = constant dense<1000> : tensor +// CHECK: %[[VAL_14:.*]] = index_cast %[[VAL_11]] : tensor to tensor +// CHECK: %[[VAL_15:.*]] = extract_element %[[VAL_14]][] : tensor +// CHECK: %[[VAL_16:.*]] = index_cast %[[VAL_12]] : tensor to tensor +// CHECK: %[[VAL_17:.*]] = extract_element %[[VAL_16]][] : tensor +// CHECK: %[[VAL_18:.*]] = index_cast %[[VAL_10]] : tensor to tensor +// CHECK: %[[VAL_19:.*]] = extract_element %[[VAL_18]][] : tensor +// CHECK: scf.for %[[VAL_21:.*]] = %[[VAL_15]] to %[[VAL_17]] step %[[VAL_19]] 
iter_args(%[[VAL_22:.*]] = %[[VAL_9]], %[[VAL_23:.*]] = %[[VAL_12]]) diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo-copy-removal.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-copy-removal.mlir deleted file mode 100644 index 3271595900d..00000000000 --- a/tensorflow/compiler/mlir/hlo/tests/lhlo-copy-removal.mlir +++ /dev/null @@ -1,115 +0,0 @@ -// RUN: mlir-hlo-opt -lhlo-copy-removal %s -o - | FileCheck %s - -// CHECK-LABEL: func @remove_simple -func @remove_simple(%arg0: memref<2x2xf32>) { - %0 = alloc() {temp = true} : memref<2x2xf32> - "lmhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> () - dealloc %0 : memref<2x2xf32> - // CHECK-NEXT: "lmhlo.terminator"() : () -> () - "lmhlo.terminator"() : () -> () -} - -// ----- - -// CHECK-LABEL: func @remove_without_dealloc -func @remove_without_dealloc(%arg0: memref<2x2xf32>) { - %0 = alloc() {temp = true} : memref<2x2xf32> - "lmhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "lmhlo.terminator"() : () -> () - "lmhlo.terminator"() : () -> () -} - -// ----- - -// CHECK-LABEL: func @replace_dependency -func @replace_dependency(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) { - %0 = alloc() {temp = true} : memref<2x2xf32> - "lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "lmhlo.copy"(%0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () - dealloc %0 : memref<2x2xf32> - // CHECK-NEXT: "lmhlo.terminator"() : () -> () - "lmhlo.terminator"() : () -> () -} - -// ----- - -// CHECK-LABEL: func @keep_copies -func @keep_copies(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) { - // CHECK-NEXT: "lmhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "lmhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "lmhlo.terminator"() : () -> () - "lmhlo.terminator"() : () -> () -} - -// ----- - -// CHECK-LABEL: func @must_not_be_removed -func @must_not_be_removed(%arg0: memref<2x2xf32>, - %arg1: memref<2x2xf32>, - %arg2: memref<2x2xf32>) { - // CHECK-NEXT: %[[ALLOC:.*]] = alloc() {temp = true} : memref<2x2xf32> - %0 = alloc() {temp = true} : memref<2x2xf32> - // CHECK-NEXT: "lmhlo.exponential"(%arg0, %[[ALLOC]]) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "lmhlo.copy"(%[[ALLOC]], %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - dealloc %0 : memref<2x2xf32> - "lmhlo.terminator"() : () -> () -} - -// ----- - -// CHECK-LABEL: func @must_be_removed_first -func @must_be_removed_first(%arg0: memref<2x2xf32>, - %arg1: memref<2x2xf32>, - %arg2: memref<2x2xf32>) { - %0 = alloc() {temp = true} : memref<2x2xf32> - // CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - dealloc %0 : memref<2x2xf32> - "lmhlo.terminator"() : () -> () -} - -// ----- - -// CHECK-LABEL: func 
@must_be_removed_second -func @must_be_removed_second(%arg0: memref<2x2xf32>, - %arg1: memref<2x2xf32>, - %arg2: memref<2x2xf32>) { - %0 = alloc() {temp = true} : memref<2x2xf32> - // CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - dealloc %0 : memref<2x2xf32> - "lmhlo.terminator"() : () -> () -} - -// ----- - -// CHECK-LABEL: func @reduce -func @reduce(%arg0: memref<1x8xf32>, %arg1: memref, %arg2: memref<1xf32>) { - %0 = alloc() : memref<1xf32> - "lmhlo.reduce"(%arg0, %arg1, %0) ( { - // CHECK: ^bb0(%[[ARG0:.*]]: memref, %[[ARG1:.*]]: memref, - // CHECK-SAME: %[[ARG2:.*]]: memref) - ^bb0(%arg3: memref, %arg4: memref, %arg5: memref): - %1 = alloc() : memref - // CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) - "lmhlo.add"(%arg3, %arg4, %1) - : (memref, memref, memref) -> () - // CHECK-NOT; lmhlo.copy - "lmhlo.copy"(%1, %arg5) : (memref, memref) -> () - "lmhlo.terminator"() : () -> () - }) {dimensions = dense<1> : tensor<1xi64>} - : (memref<1x8xf32>, memref, memref<1xf32>) -> () - "lmhlo.copy"(%0, %arg2) : (memref<1xf32>, memref<1xf32>) -> () - return -} diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo-fuse-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-fuse-linalg.mlir index 6a674664a36..e51bdfec6f7 100644 --- a/tensorflow/compiler/mlir/hlo/tests/lhlo-fuse-linalg.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-fuse-linalg.mlir @@ -3,20 +3,25 @@ // RUN: mlir-hlo-opt -lhlo-fuse-linalg=use-parallel-loops %s -split-input-file | FileCheck %s -check-prefix=PLOOP #map0 = affine_map<(d0, d1) -> (d0, d1)> -#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} +#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], + iterator_types = ["parallel", "parallel"]} func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>, %summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) { %temp_result = alloc() : memref<6x6xf32> - linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result { + linalg.generic #pointwise_2d_trait + ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>) + outs(%temp_result : memref<6x6xf32>) { ^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32): %out = addf %summand_1_in, %summand_2_in : f32 linalg.yield %out : f32 - } : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32> - linalg.generic #pointwise_2d_trait %temp_result, %multiplier, %result { + } + linalg.generic #pointwise_2d_trait + ins(%temp_result, %multiplier : memref<6x6xf32>, memref<6x6xf32>) + outs(%result : memref<6x6xf32>) { ^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32): %out = mulf %temp_result_in, %multiplier_in : f32 linalg.yield %out : f32 - } : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32> + } dealloc %temp_result : memref<6x6xf32> return } @@ -59,36 +64,37 @@ func @fusion_of_three(%arg0: memref<100x10xf32>, %arg2: memref<100x10xf32>) { %0 = alloc() : memref<100x10xf32> linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, - indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"] - } %arg1, %0 { - ^bb0(%arg3: f32, 
%arg4: f32): // no predecessors - linalg.yield %arg3 : f32 - }: memref<100xf32>, memref<100x10xf32> + indexing_maps = [affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%arg1 : memref<100xf32>) + outs(%0 : memref<100x10xf32>) { + ^bb0(%arg3: f32, %arg4: f32): // no predecessors + linalg.yield %arg3 : f32 + } %1 = alloc() : memref<100x10xf32> linalg.generic { - args_in = 2 : i64, - args_out = 1 : i64, - indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"] - } %arg0, %0, %1 { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%arg0, %0 : memref<100x10xf32>, memref<100x10xf32>) + outs(%1 : memref<100x10xf32>) { ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors %2 = subf %arg3, %arg4 : f32 linalg.yield %2 : f32 - }: memref<100x10xf32>, memref<100x10xf32>, memref<100x10xf32> + } dealloc %0 : memref<100x10xf32> linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, - indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"] - } %1, %arg2 { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%1 : memref<100x10xf32>) + outs(%arg2 : memref<100x10xf32>) { ^bb0(%arg3: f32, %arg4: f32): // no predecessors %2 = exp %arg3 : f32 linalg.yield %2 : f32 - }: memref<100x10xf32>, memref<100x10xf32> + } dealloc %1 : memref<100x10xf32> return } @@ -130,20 +136,26 @@ func @fusion_of_three(%arg0: memref<100x10xf32>, // ----- #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -#pointwise_4d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} +#pointwise_4d_trait = {indexing_maps = [#map0, #map0, #map0], + iterator_types = ["parallel", "parallel", "parallel", + "parallel"]} func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>, %summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) { %temp_result = alloc() : memref<6x6x6x6xf32> - linalg.generic #pointwise_4d_trait %summand_1, %summand_2, %temp_result { + linalg.generic #pointwise_4d_trait + ins(%summand_1, %summand_2 : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>) + outs(%temp_result : memref<6x6x6x6xf32>) { ^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32): %out = addf %summand_1_in, %summand_2_in : f32 linalg.yield %out : f32 - } : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>, memref<6x6x6x6xf32> - linalg.generic #pointwise_4d_trait %temp_result, %multiplier, %result { + } + linalg.generic #pointwise_4d_trait + ins(%temp_result, %multiplier : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>) + outs(%result : memref<6x6x6x6xf32>) { ^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32): %out = mulf %temp_result_in, %multiplier_in : f32 linalg.yield %out : f32 - } : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>, memref<6x6x6x6xf32> + } dealloc %temp_result : memref<6x6x6x6xf32> return } @@ -184,21 +196,26 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32 // ----- #map0 = affine_map<(d0, d1) -> (d0, d1)> -#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", 
"parallel"]} +#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], + iterator_types = ["parallel", "parallel"]} func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>, %summand_2: memref<6x6xf32>) -> memref<6x6xf32> { %temp_result = alloc() : memref<6x6xf32> - linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result { + linalg.generic #pointwise_2d_trait + ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>) + outs(%temp_result : memref<6x6xf32>) { ^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32): %out = addf %summand_1_in, %summand_2_in : f32 linalg.yield %out : f32 - } : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32> + } %result = alloc() : memref<6x6xf32> - linalg.generic #pointwise_2d_trait %temp_result, %multiplier, %result { + linalg.generic #pointwise_2d_trait + ins(%temp_result, %multiplier : memref<6x6xf32>, memref<6x6xf32>) + outs(%result : memref<6x6xf32>) { ^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32): %out = mulf %temp_result_in, %multiplier_in : f32 linalg.yield %out : f32 - } : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32> + } dealloc %temp_result : memref<6x6xf32> return %result : memref<6x6xf32> } @@ -234,3 +251,51 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>, // PLOOP: addf // PLOOP: linalg.generic // PLOOP: mulf + +// ----- + +func @view_result(%arg0: memref, %arg1: memref, %arg2: index) + -> memref<*xf32> { + %c1 = constant 1 : index + %c0 = constant 0 : index + %1 = alloc(%arg2) : memref + linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, + affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"]} + ins(%arg0 : memref) outs(%1 : memref) { + ^bb0(%arg3: f32, %arg4: f32): // no predecessors + %13 = absf %arg3 : f32 + linalg.yield %13 : f32 + } + %2 = lmhlo.reshape_memref_cast %1(%arg1) + : (memref, memref) -> memref<*xf32> + return %2 : memref<*xf32> +} + +// CHECK-LABEL: func @view_result +// CHECK: %[[C1:.*]] = constant 1 +// CHECK-NOT: linalg.generic +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK-NOT: scf.for +// CHECK: linalg.generic +// CHECK: absf +// CHECK: reshape_memref_cast + +// TILED-LABEL: func @view_result +// TILED-DAG: %[[C2:.*]] = constant 2 +// TILED-NOT: linalg.generic +// TILED: scf.for {{.*}} step %[[C2]] +// TILED-NOT: scf.for +// TILED: linalg.generic +// TILED: absf +// TILED: reshape_memref_cast + + +// PLOOP-LABEL: func @view_result +// PLOOP-NOT: linalg.generic +// PLOOP: scf.parallel +// PLOOP-NOT: scf.parallel +// PLOOP: linalg.generic +// PLOOP: absf +// PLOOP: reshape_memref_cast + diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-affine.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-affine.mlir index 87818045993..d020f7a083b 100644 --- a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-affine.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-affine.mlir @@ -158,7 +158,14 @@ func @float_dot_op(%lhs: memref<7x3xf32>, %rhs: // CHECK-NEXT: %[[ADD:.*]] = addf %[[MULT]], %[[RESULT]] : f32 // CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xf32> // CHECK: return - "lmhlo.dot"(%lhs, %rhs, %result) : + "lmhlo.dot"(%lhs, %rhs, %result) { + dot_dimension_numbers = { + lhs_batching_dimensions = dense<> : tensor<0xi64>, + rhs_batching_dimensions = dense<> : tensor<0xi64>, + lhs_contracting_dimensions = dense<1> : tensor<1xi64>, + rhs_contracting_dimensions = dense<0> : tensor<1xi64> + } + } : (memref<7x3xf32>, memref<3x4xf32>, memref<7x4xf32>) -> () 
return } @@ -175,7 +182,14 @@ func @int_dot_op(%lhs: memref<7x3xi32>, %rhs: // CHECK-NEXT: %[[ADD:.*]] = addi %[[MULT]], %[[RESULT]] : i32 // CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xi32> // CHECK: return - "lmhlo.dot"(%lhs, %rhs, %result) : + "lmhlo.dot"(%lhs, %rhs, %result) { + dot_dimension_numbers = { + lhs_batching_dimensions = dense<> : tensor<0xi64>, + rhs_batching_dimensions = dense<> : tensor<0xi64>, + lhs_contracting_dimensions = dense<1> : tensor<1xi64>, + rhs_contracting_dimensions = dense<0> : tensor<1xi64> + } + } : (memref<7x3xi32>, memref<3x4xi32>, memref<7x4xi32>) -> () return } diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir index f174b005a8d..47151089ccb 100644 --- a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir @@ -125,6 +125,20 @@ func @copy(%in: memref<2x4x8xf32>, %out: memref<2x4x8xf32>) { // ----- +// CHECK-LABEL: func @is_finte +func @is_finte(%input: memref<2x2xf32>, %result: memref<2x2xi1>) { + "lmhlo.is_finite"(%input, %result) : (memref<2x2xf32>, memref<2x2xi1>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: %[[POS_INF:.+]] = constant 0x7F800000 : f32 +// CHECK-NEXT: %[[ABS_X:.+]] = absf %[[OPERAND_IN]] : f32 +// CHECK-NEXT: %[[RESULT:.+]] = cmpf "one", %[[ABS_X]], %[[POS_INF]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i1 + +// ----- + // CHECK-LABEL: func @float_cmp func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xi1>) { @@ -263,7 +277,8 @@ func @static_broadcast_in_dim_expansion(%operand: memref<1x5xf32>, // CHECK: %[[RESHAPED_ARG:.*]] = linalg.reshape %{{.*}}#[[REASSOCIATION]]] // CHECK-SAME: memref<1x5xf32> into memref<5xf32> // CHECK: linalg.generic {{{.*}}indexing_maps = -// CHECK-SAME: [#[[OPERAND_MAP]], #[[RESULT_MAP]]]{{.*}} %[[RESHAPED_ARG]] +// CHECK-SAME: [#[[OPERAND_MAP]], #[[RESULT_MAP]]] +// CHECK-SAME: ins(%[[RESHAPED_ARG]] : // CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32): // CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 @@ -496,6 +511,18 @@ func @sin(%input: memref<2x2xf32>, // ----- +// CHECK-LABEL: func @floor +func @floor(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "lmhlo.floor"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: %[[RESULT:.*]] = floorf %[[OPERAND_IN]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + // CHECK-LABEL: func @negf func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { "lmhlo.negate"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () @@ -521,6 +548,19 @@ func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { // ----- +// CHECK-LABEL: func @not +func @not(%input: memref<2x2xi64>, %result: memref<2x2xi64>) { + "lmhlo.not"(%input, %result) : (memref<2x2xi64>, memref<2x2xi64>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i64, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: %[[N1:.*]] = constant -1 : i64 +// CHECK-NEXT: %[[RESULT:.*]] = xor %[[N1]], %[[OPERAND_IN]] : i64 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i64 + +// ----- + // CHECK-LABEL: func @rem func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) { @@ -560,6 +600,37 @@ 
func @sign(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // ----- +// CHECK-LABEL: func @sign_bf16 +func @sign_bf16(%input: memref<2x2xbf16>, %result: memref<2x2xbf16>) { + "lmhlo.sign"(%input, %result) : (memref<2x2xbf16>, memref<2x2xbf16>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: bf16, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: %[[CST:.*]] = constant 1.000000e+00 : bf16 +// CHECK-NEXT: %[[RESULT:.*]] = copysign %[[CST]], %[[OPERAND_IN]] : bf16 +// CHECK-NEXT: linalg.yield %[[RESULT]] : bf16 + +// ----- + +// CHECK-LABEL: func @sign_i16 +func @sign_i16(%input: memref<2x2xi16>, %result: memref<2x2xi16>) { + "lmhlo.sign"(%input, %result) : (memref<2x2xi16>, memref<2x2xi16>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i16, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: %[[C0:.*]] = constant 0 : i16 +// CHECK-NEXT: %[[C15:.*]] = constant 15 : i16 +// CHECK-NEXT: %[[C1:.*]] = constant 1 : i16 +// CHECK-NEXT: %[[CMP:.*]] = cmpi "eq", %[[OPERAND_IN]], %[[C0]] : i16 +// CHECK-NEXT: %[[ASHR:.*]] = shift_right_signed %[[OPERAND_IN]], %[[C15]] : i16 +// CHECK-NEXT: %[[OR:.*]] = or %[[ASHR]], %[[C1]] : i16 +// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[C0]], %[[OR]] : i16 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i16 + +// ----- + // CHECK-LABEL: func @sqrt func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { "lmhlo.sqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () @@ -702,6 +773,32 @@ func @reshape_3D_4D(%arg0: memref<1x49x16xf32>, %arg1: memref<1x784x1x1xf32>) { // ----- +// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func @reshape1_4D_4D +func @reshape1_4D_4D(%arg0: memref<4x512x1x1xi32>, + %arg1: memref<1x4x1x512xi32>) { + "lmhlo.reshape"(%arg0, %arg1) + : (memref<4x512x1x1xi32>, memref<1x4x1x512xi32>) -> () + return +} +// CHECK: linalg.reshape %{{.*}} [#[[MAP]]] +// CHECK: linalg.reshape %{{.*}} [#[[MAP]]] + +// ----- + +// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func @reshape2_4D_4D +func @reshape2_4D_4D(%arg0: memref<4x1x1x1024xi32>, + %arg1: memref<4x1024x1x1xi32>) { + "lmhlo.reshape"(%arg0, %arg1) + : (memref<4x1x1x1024xi32>, memref<4x1024x1x1xi32>) -> () + return +} +// CHECK: linalg.reshape %{{.*}} [#[[MAP]]] +// CHECK: linalg.reshape %{{.*}} [#[[MAP]]] + +// ----- + // CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 2)> // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @reverse @@ -736,3 +833,16 @@ func @conv(%input: memref<3x5x5x3xf32>, %filter: memref<2x2x3x4xf32>, %output: m "lmhlo.copy"(%0, %output) : (memref<3x5x5x4xf32>, memref<3x5x5x4xf32>) -> () "lmhlo.terminator"() : () -> () } + +// ----- + +// CHECK-DAG: #[[TRANSPOSE_INPUT_MAP:.*]] = affine_map<(d0, d1) -> (d1, d0)> +// CHECK-DAG: #[[TRANSPOSE_OUTPUT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @transpose +func @transpose(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) { + "lmhlo.transpose"(%arg0, %arg1) { + permutation = dense<[1, 0]> : tensor<2xi64> + } : (memref<2x2xf32>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[TRANSPOSE_INPUT_MAP]], #[[TRANSPOSE_OUTPUT_MAP]]] diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo_gpu_ops.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo_gpu_ops.mlir new file mode 100644 index 00000000000..9e5ce67f39a --- /dev/null +++ 
b/tensorflow/compiler/mlir/hlo/tests/lhlo_gpu_ops.mlir @@ -0,0 +1,99 @@ +// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | mlir-hlo-opt | FileCheck %s + +// CHECK-LABEL: func @batch_norm_grad_memrefs +func @batch_norm_grad_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>, + %arg3: memref<8xf32>, %arg4: memref<8x8x8x8xf32>, + %grad_operand: memref<8x8x8x8xf32>, %grad_scale: memref<8xf32>, + %grad_offset: memref<8xf32>) -> () { + "lmhlo_gpu.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4, %grad_operand, %grad_scale, %grad_offset) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} + : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>, + memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> () + return +} + +// CHECK-LABEL: func @batch_norm_inference_memrefs +func @batch_norm_inference_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>, + %arg3: memref<8xf32>, %arg4: memref<8xf32>, %arg_out: memref<8x8x8x8xf32>) -> () { + "lmhlo_gpu.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} + : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>) -> () + return +} + +// CHECK-LABEL: func @batch_norm_training_memrefs +func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>, + %output: memref<8x8x8x8xf32>, %batch_mean: memref<8xf32>, + %batch_var: memref<8xf32>) -> () { + "lmhlo_gpu.batch_norm_training"(%arg0, %arg1, %arg2, %output, %batch_mean, %batch_var) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} + : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> () + return +} + +// CHECK-LABEL: func @conv_forward +func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %output: memref<1x1x7x7xf16>) { + %scratch = alloc() : memref<32xi8> + // This defines a 2D convolution over an 8x8 single-channel input using a 2x2 + // filter, with an output of 7x7xf16.
The 1x1x8x8 is (N, C, H, W) + "lmhlo_gpu.conv_forward"(%input, %filter, %output, %scratch) + { dimension_numbers = {input_batch_dimension = 0 : i64, + input_feature_dimension = 1 : i64, + input_spatial_dimensions = dense<[2,3]> : tensor<2xi64>, + kernel_input_feature_dimension = 0 : i64, + kernel_output_feature_dimension = 1 : i64, + kernel_spatial_dimensions = dense<[2,3]> : tensor<2xi64>, + output_batch_dimension = 0 : i64, + output_feature_dimension = 1 : i64, + output_spatial_dimensions = dense<[2,3]> : tensor<2xi64>}, + window_strides = dense<[1, 1]> : tensor<2xi64>, + padding = dense<[0,0]> : tensor<2xi64>, + lhs_dilation = dense<[1,1]> : tensor<2xi64>, + rhs_dilation = dense<[1,1]> : tensor<2xi64>, + feature_group_count = 1, + batch_group_count = 1, + result_scale = 1.0, + backend_config = {algorithm=0, tensor_ops_enabled = true } + } + : (memref<1x1x8x8xf16>, memref<1x1x2x2xf16>, memref<1x1x7x7xf16>, memref<32xi8>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @gemm +func @gemm(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, %output:memref<5x5xf32>) { + "lmhlo_gpu.gemm"(%lhs, %rhs, %output) { dot_dimension_numbers = { + lhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>, + rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>, + lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>, + rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>}, + alpha = 0.5, + batch_size = 1, + algorithm = 0} + : (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>) -> () + return +} + + +// CHECK-LABEL: func @gemm_bias +func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, + %bias: memref<5x5xf32>, %output:memref<5x5xf32>) { + "lmhlo_gpu.gemm_bias"(%lhs, %rhs, %bias, %output) { dot_dimension_numbers = { + lhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>, + rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>, + lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>, + rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>}, + alpha = 0.5, + beta = 1.0, + batch_size = 1, + algorithm = 0} + : (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>, memref<5x5xf32>) -> () + return +} + +// CHECK-LABEL: func @cholesky +func @cholesky(%arg : memref<10x10xf32>, %out: memref<10x10xf32>) { + %scratch = alloc() : memref<32xi8> + %info = alloc() : memref<32xi32> + "lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_upper = true } + : (memref<10x10xf32>, memref<10x10xf32>, memref<32xi8>, memref<32xi32>) -> () + return +} diff --git a/tensorflow/compiler/mlir/hlo/tests/lower-complex.mlir b/tensorflow/compiler/mlir/hlo/tests/lower-complex.mlir index a7bd21257a6..b9c91d61377 100644 --- a/tensorflow/compiler/mlir/hlo/tests/lower-complex.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lower-complex.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt %s -mhlo-test-chlo-legalize-to-hlo -mhlo-test-lower-complex | FileCheck %s +// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -mhlo-test-lower-complex | FileCheck %s // CHECK-LABEL: @add func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { diff --git a/tensorflow/compiler/mlir/hlo/tests/mhlo_infer_shape_type_methods.mlir b/tensorflow/compiler/mlir/hlo/tests/mhlo_infer_shape_type_methods.mlir new file mode 100644 index 00000000000..d626f520824 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/tests/mhlo_infer_shape_type_methods.mlir @@ -0,0 +1,37 @@ +// RUN: mlir-hlo-opt --mhlo-test-infer-shaped-type-methods --allow-unregistered-dialect --split-input-file %s | 
FileCheck %s + +// ----- +// CHECK-LABEL: @select +// CHECK-SAME: (%[[PRED:.*]]: tensor<2x?xi1>, +func @select(%pred : tensor<2x?xi1>, %a : tensor<2x?xf32>, %b : tensor<2x?xf32>) + -> tensor<2xi64> { + // CHECK: %[[C2:.*]] = constant 2 : i64 + // CHECK: %[[C1:.*]] = constant 1 : index + // CHECK: %[[DIM_AS_INDEX:.*]] = dim %[[PRED]], %[[C1]] : tensor<2x?xi1> + // CHECK: %[[DIM:.*]] = index_cast %[[DIM_AS_INDEX]] : index to i64 + // CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[C2]], %[[DIM]] : tensor<2xi64> + // CHECK: return %[[SHAPE]] : tensor<2xi64> + %0 = "mhlo.select"(%pred, %a, %b) + : (tensor<2x?xi1>, tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xf32> + %1 = "mhlo_test.reify_return_type_shapes"(%0) + : (tensor<2x?xf32>) -> tensor<2xi64> + return %1 : tensor<2xi64> +} + +// ----- +// CHECK-LABEL: @compare +// CHECK-SAME: (%[[A:.*]]: tensor<2x?xf32>, +func @compare(%a : tensor<2x?xf32>, %b : tensor<2x?xf32>) -> tensor<2xi64> { + // CHECK: %[[C2:.*]] = constant 2 : i64 + // CHECK: %[[C1:.*]] = constant 1 : index + // CHECK: %[[DIM_AS_INDEX:.*]] = dim %[[A]], %[[C1]] : tensor<2x?xf32> + // CHECK: %[[DIM:.*]] = index_cast %[[DIM_AS_INDEX]] : index to i64 + // CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[C2]], %[[DIM]] : tensor<2xi64> + // CHECK: return %[[SHAPE]] : tensor<2xi64> + %0 = "mhlo.compare"(%a, %b) { comparison_direction = "NE" } + : (tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xi1> + %1 = "mhlo_test.reify_return_type_shapes"(%0) + : (tensor<2x?xi1>) -> tensor<2xi64> + return %1 : tensor<2xi64> +} + diff --git a/tensorflow/compiler/mlir/hlo/tests/ops.mlir b/tensorflow/compiler/mlir/hlo/tests/ops.mlir index a8f16c403ae..fb4ab62371f 100644 --- a/tensorflow/compiler/mlir/hlo/tests/ops.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/ops.mlir @@ -328,6 +328,14 @@ func @collective_permute_duplicate_sources(%arg0: tensor<128x32xf32>) -> tensor< // ----- +func @concat_0D(%arg0: tensor, %arg1: tensor) -> tensor<2xi32> { + // expected-error@+1 {{rank-0 values cannot be concatenated}} + %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor, tensor) -> tensor<2xi32> + return %0 : tensor<2xi32> +} + +// ----- + // CHECK-LABEL: @concat_1D func @concat_1D(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xi32> { %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32> @@ -600,6 +608,14 @@ func @recv_non_token_second_result(%token: !mhlo.token) -> tuple // ----- +// CHECK-LABEL: func @replica_id +func @replica_id() -> tensor { + %0 = "mhlo.replica_id"() : () -> tensor + return %0 : tensor +} + +// ----- + func @rng_uniform_invalid_type(%mu: tensor>, %sigma: tensor) -> tensor<2x3x5xf32> { %shape = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64> // expected-error@+1 {{but got 'tensor>'}} @@ -731,7 +747,7 @@ func @dynamic_slice_mismatch_element_types(%arg0: tensor<3x4xi32>, %arg1: tensor // ----- func @dynamic_slice_invalid_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { - // expected-error@+1 {{operand #1 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}} + // expected-error@+1 {{operand #1 must be 0D tensor of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}} %0 = "mhlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32> return %0 : tensor<1x4xi32> } @@ -747,7 +763,7 @@ func 
@dynamic_update_slice(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %sta // ----- func @dynamic_update_slice_invalid_start(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %start: tensor<2xi64>) -> tensor<3x4xi64> { - // expected-error@+1 {{operand #2 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}} + // expected-error@+1 {{operand #2 must be 0D tensor of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}} %0 = "mhlo.dynamic-update-slice"(%input, %update, %start) : (tensor<3x4xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<3x4xi64> return %0 : tensor<3x4xi64> } @@ -1002,34 +1018,34 @@ func @constant_invalid() -> () { func @sort(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { // CHECK: mhlo.sort - %0 = "mhlo.sort"(%input0, %input1) ( { + %0:2 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () - }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) return } // ----- func @sort_no_operands() { - // expected-error @+1 {{op requires at least one input}} - %0 = "mhlo.sort"() ( { + // expected-error @+1 {{expected named operation to have atleast 1 result}} + %0:0 = "mhlo.sort"() ( { ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor): %7 = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () - }) {dimension = 1 : i64, is_stable = true} : () -> tuple<> + }) {dimension = 1 : i64, is_stable = true} : () -> () return } // ----- func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) { - %0 = "mhlo.sort"(%input0, %input1) ( { + %0:2 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () - }) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + }) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) return } @@ -1037,23 +1053,23 @@ func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) { func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) { // expected-error @+1 {{comparator block argument #0 should be of type 'tensor' but got 'tensor'}} - %0 = "mhlo.sort"(%input0, %input1) ( { + %0:2 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () - }) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + }) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) return } // ----- func @sort_different_dims(%input0: tensor<16x8xf32>, %input1: tensor<16x16xi32>) { - // expected-error @+1 {{op requires all inputs to have the same dimensions}} - %0 = "mhlo.sort"(%input0, %input1) ( { + // 
expected-error @+1 {{op requires the same shape for all operands and results}} + %0:2 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () - }) {dimension = 1 : i64, is_stable = true} : (tensor<16x8xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + }) {dimension = 1 : i64, is_stable = true} : (tensor<16x8xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) return } @@ -1061,11 +1077,11 @@ func @sort_different_dims(%input0: tensor<16x8xf32>, %input1: tensor<16x16xi32>) func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { // expected-error @+1 {{dimension attribute value must be in range [-2, 2), but found 10}} - %0 = "mhlo.sort"(%input0, %input1) ( { + %0:2 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () - }) {dimension = 10 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + }) {dimension = 10 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) return } @@ -1073,11 +1089,11 @@ func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi3 func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { // expected-error @+1 {{dimension attribute value must be in range [-2, 2), but found -3}} - %0 = "mhlo.sort"(%input0, %input1) ( { + %0:2 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () - }) {dimension = -3 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + }) {dimension = -3 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) return } @@ -1085,11 +1101,11 @@ func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi3 func @sort_wrong_block_arg_count(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { // expected-error @+1 {{op comparator block should have 4 arguments}} - %0 = "mhlo.sort"(%input0, %input1) ( { + %0:2 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor): %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () - }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) return } @@ -1097,11 +1113,11 @@ func @sort_wrong_block_arg_count(%input0: tensor<16x16xf32>, %input1: tensor<16x func @sort_wrong_block_arg_type(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { // expected-error @+1 {{op comparator block argument #3 should be of type 'tensor' but got 'tensor'}} - %0 = "mhlo.sort"(%input0, %input1) ( { + %0:2 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () - }) {dimension = 1 
: i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) return } @@ -1185,3 +1201,24 @@ func @incompatible_shapes(%arg0: tensor, %shape: tensor<2xindex>) -> tens %0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor, tensor<2xindex>) -> tensor return %0 : tensor } + +// ----- + +func @cbrt(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> { + %0 = "mhlo.cbrt"(%arg) : (tensor<2x4xf32>) -> tensor<2x4xf32> + return %0 : tensor<2x4xf32> +} + +// ----- + +func @bitcast(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> { + %0 = "mhlo.bitcast"(%arg) : (tensor<2x4xf32>) -> tensor<2x4xf32> + return %0 : tensor<2x4xf32> +} + +// ----- + +func @bitcast(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> { + %0 = "mhlo.reduce_precision"(%arg) {exponent_bits=2 : i32, mantissa_bits=3 : i32} : (tensor<2x4xf32>) -> tensor<2x4xf32> + return %0 : tensor<2x4xf32> +} diff --git a/tensorflow/compiler/mlir/hlo/tests/unfuse_batch_norm.mlir b/tensorflow/compiler/mlir/hlo/tests/unfuse_batch_norm.mlir index f903dbb7080..53ee94f8d1a 100644 --- a/tensorflow/compiler/mlir/hlo/tests/unfuse_batch_norm.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/unfuse_batch_norm.mlir @@ -109,7 +109,7 @@ func @batchNormInference_dynamic_shape( // CHECK-DAG: %[[C3:.*]] = constant 3 : index // CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e-03> : tensor // CHECK-DAG: %[[DIM:.+]] = dim %[[VARIANCE]], %[[C0]] : tensor - // CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = tensor_from_elements(%[[DIM]]) : tensor<1xindex> + // CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = tensor_from_elements %[[DIM]] : tensor<1xindex> // CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor // CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor // CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor) -> tensor @@ -117,7 +117,7 @@ func @batchNormInference_dynamic_shape( // CHECK-DAG: %[[INPUT_DIM_1:.+]] = dim %[[X]], %[[C1]] : tensor // CHECK-DAG: %[[INPUT_DIM_2:.+]] = dim %[[X]], %[[C2]] : tensor // CHECK-DAG: %[[INPUT_DIM_3:.+]] = dim %[[X]], %[[C3]] : tensor - // CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = tensor_from_elements(%[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]]) : tensor<4xindex> + // CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = tensor_from_elements %[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]] : tensor<4xindex> // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor // CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor diff --git a/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cpp b/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cpp index d0c0e3c51e1..ed96dd5ffd8 100644 --- a/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cpp +++ b/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cpp @@ -15,6 
+15,7 @@ limitations under the License. #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h" #include "mlir/InitAllDialects.h" @@ -31,6 +32,7 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); return failed( mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n", registry)); diff --git a/tensorflow/compiler/mlir/init_mlir.cc b/tensorflow/compiler/mlir/init_mlir.cc index 54f8a57d8a6..fac9f51d8ba 100644 --- a/tensorflow/compiler/mlir/init_mlir.cc +++ b/tensorflow/compiler/mlir/init_mlir.cc @@ -20,6 +20,11 @@ limitations under the License. namespace tensorflow { InitMlir::InitMlir(int *argc, char ***argv) : init_llvm_(*argc, *argv) { + llvm::setBugReportMsg( + "TensorFlow crashed, please file a bug on " + "https://github.com/tensorflow/tensorflow/issues with the trace " + "below.\n"); + constexpr char kSeparator[] = "--"; // Find index of separator between two sets of flags. diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 2d3a58b5b9d..eff591895e1 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -1,3 +1,9 @@ +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "filegroup") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "get_compatible_with_cloud") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_native_cc_binary") load( "//third_party/mlir:tblgen.bzl", @@ -37,6 +43,7 @@ filegroup( gentbl( name = "tensorflow_lite_ops_inc_gen", + compatible_with = get_compatible_with_cloud(), tbl_outs = [ ( "-gen-op-decls", @@ -68,6 +75,7 @@ gentbl( gentbl( name = "tensorflow_lite_op_interfaces_inc_gen", + compatible_with = get_compatible_with_cloud(), tbl_outs = [ ( "-gen-op-interface-decls", @@ -87,6 +95,7 @@ gentbl( gentbl( name = "tensorflow_lite_prepare_tf_inc_gen", + compatible_with = get_compatible_with_cloud(), tbl_outs = [ ( "-gen-rewriters", @@ -105,6 +114,7 @@ gentbl( gentbl( name = "tensorflow_lite_lower_static_tensor_list_inc_gen", + compatible_with = get_compatible_with_cloud(), tbl_outs = [ ( "-gen-rewriters", @@ -122,6 +132,7 @@ gentbl( gentbl( name = "tensorflow_lite_legalize_tf_inc_gen", + compatible_with = get_compatible_with_cloud(), tbl_outs = [ ( "-gen-rewriters", @@ -139,6 +150,7 @@ gentbl( gentbl( name = "tensorflow_lite_optimize_inc_gen", + compatible_with = get_compatible_with_cloud(), tbl_outs = [ ( "-gen-rewriters", @@ -157,6 +169,7 @@ gentbl( gentbl( name = "tensorflow_lite_quantize_inc_gen", + compatible_with = get_compatible_with_cloud(), tbl_outs = [ ( "-gen-rewriters", @@ -173,6 +186,7 @@ gentbl( gentbl( name = "tensorflow_lite_post_quantize_inc_gen", + compatible_with = get_compatible_with_cloud(), tbl_outs = [ ( "-gen-rewriters", @@ -280,6 +294,28 @@ cc_library( ], ) +cc_library( + name = "nms_utils", + srcs = [ + "utils/nms_utils.cc", + ], + hdrs = [ + "utils/nms_utils.h", + ], + copts = ["-std=c++14"], + deps = [ + ":tensorflow_lite", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes", + "//tensorflow/core:framework", + "@flatbuffers", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardOps", + 
"@llvm-project//mlir:Support", + ], +) + cc_library( name = "tftext_utils", srcs = [ @@ -373,6 +409,7 @@ cc_library( deps = [ ":constant_utils", ":lstm_utils", + ":nms_utils", ":stateful_ops_utils", ":tensorflow_lite", ":tftext_utils", @@ -384,6 +421,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:convert_tensor", "//tensorflow/compiler/mlir/tensorflow:mangling_util", "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo", "//tensorflow/compiler/mlir/tensorflow:unroll_batch_matmul_pass", @@ -439,7 +477,6 @@ cc_library( "transforms/default_quant_params.cc", "transforms/generated_post_quantize.inc", "transforms/generated_quantize.inc", - "transforms/load_quantization_recipe.cc", "transforms/post_quantize.cc", "transforms/prepare_quantize.cc", "transforms/quantize.cc", @@ -498,6 +535,7 @@ filegroup( gentbl( name = "op_quant_spec_getters_inc", + compatible_with = get_compatible_with_cloud(), tbl_outs = [("", "utils/generated_op_quant_spec_getters.inc")], tblgen = "//tensorflow/compiler/mlir/lite/quantization:op_quant_spec_getters_gen", td_file = "ir/tfl_ops.td", @@ -509,19 +547,6 @@ gentbl( ], ) -# Library with tensorflow Lite dialect static initialization. -cc_library( - name = "tensorflow_lite_dialect_registration", - srcs = [ - "ir/dialect_registration.cc", - ], - deps = [ - ":tensorflow_lite", - "@llvm-project//mlir:IR", - ], - alwayslink = 1, -) - tf_native_cc_binary( name = "converter-gen", srcs = [ @@ -536,6 +561,7 @@ tf_native_cc_binary( gentbl( name = "converter_inc", + compatible_with = get_compatible_with_cloud(), tbl_outs = [ ( "--gen-operator-converters", @@ -628,12 +654,10 @@ cc_library( ":flatbuffer_tflite_operator_lib", ":stateful_ops_utils", ":tensorflow_lite", - ":tensorflow_lite_dialect_registration", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:convert_tensor", "//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:protos_all_cc", @@ -645,6 +669,7 @@ cc_library( "//tensorflow/lite/delegates/flex:allowlisted_flex_ops_lib", "//tensorflow/lite/kernels/internal:kernel_utils", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "//tensorflow/lite/tools/versioning", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -672,7 +697,6 @@ cc_library( ":convert_type", ":flatbuffer_tflite_operator_lib", ":tensorflow_lite", - ":tensorflow_lite_dialect_registration", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:mangling_util", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", @@ -682,6 +706,7 @@ cc_library( "//tensorflow/core/platform:status", "//tensorflow/lite:framework", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -867,7 +892,6 @@ cc_library( "//tensorflow/compiler/mlir/lite/quantization:quantization_passes", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:decode_constant_pass", - 
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass", "//tensorflow/compiler/mlir/tensorflow:tf_saved_model_passes", diff --git a/tensorflow/compiler/mlir/lite/converter_gen.cc b/tensorflow/compiler/mlir/lite/converter_gen.cc index edead2037a3..44eba0d5e6f 100644 --- a/tensorflow/compiler/mlir/lite/converter_gen.cc +++ b/tensorflow/compiler/mlir/lite/converter_gen.cc @@ -513,7 +513,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) { continue; } if (trait.getDef().getValueAsString("trait") != - "OpTrait::TFLRuntimeOpTrait") { + "::mlir::OpTrait::TFLRuntimeOpTrait") { continue; } diff --git a/tensorflow/compiler/mlir/lite/experimental/estimators/BUILD b/tensorflow/compiler/mlir/lite/experimental/estimators/BUILD index 373c95f6bf5..3b80b871790 100644 --- a/tensorflow/compiler/mlir/lite/experimental/estimators/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/estimators/BUILD @@ -1,3 +1,5 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + package( default_visibility = [ "//visibility:public", diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index 34200fb88b6..a98e83b7e1e 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -75,6 +75,7 @@ limitations under the License. #include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h" #include "tensorflow/lite/kernels/internal/kernel_utils.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/string_util.h" #include "tensorflow/lite/tools/versioning/op_version.h" #include "tensorflow/lite/tools/versioning/runtime_version.h" @@ -325,6 +326,21 @@ static Optional GetTflitePoolParams(Operation* inst, namespace { +// Helper struct that wraps inputs/outputs of a single SignatureDef. +struct SignatureDefData { + // Note, we are using maps here to make order deterministic + // for easily testing only. + + // Inputs defined in the signature def mapped to tensor names. + std::map inputs; + // Outputs defined in the signature def mapped to tensor names. + std::map outputs; + // Method name exported by the signature def. + std::string method_name; + // SignatureDef key. + std::string signature_def_key; +}; + // Translates an MLIR module in TFLite dialect to TFLite FlatBuffer. class Translator { public: @@ -333,16 +349,19 @@ class Translator { // internal error. static Optional Translate( ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, - bool emit_custom_ops, OpOrArgNameMapper* op_or_arg_name_mapper); + bool emit_custom_ops, const std::unordered_set& tags, + OpOrArgNameMapper* op_or_arg_name_mapper); private: enum class OpType : char { kTfliteBuiltin, kSelectTf, kCustomOp }; explicit Translator(ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, + const std::unordered_set& saved_model_tags, OpOrArgNameMapper* op_or_arg_name_mapper) : module_(module), name_mapper_(*op_or_arg_name_mapper), - builder_(kInitialBufferSize) { + builder_(kInitialBufferSize), + saved_model_tags_(saved_model_tags) { // The first buffer must be empty according to the schema definition. 
empty_buffer_ = tflite::CreateBuffer(builder_); buffers_.push_back(empty_buffer_); @@ -449,6 +468,17 @@ class Translator { Optional>> CreateMetadataVector(); + // Builds and returns list of tfl.SignatureDef sections in the model. + Optional>> + CreateSignatureDefs(const std::vector& signature_defs); + + // Returns list of offsets for the passed 'items' in TensorMap structure + // inside the flatbuffer. + // 'items' is a map from tensor name in signatureDef to tensor name in + // the model. + std::vector> GetList( + const std::map& items); + // Uses the tf.entry_function attribute (if set) to initialize the op to name // mapping. void InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr); @@ -471,6 +501,8 @@ class Translator { BufferOffset empty_buffer_; std::vector> buffers_; + // Maps tensor name in the graph to the tensor index. + absl::flat_hash_map tensor_index_map_; // Maps op name to index of the corresponding OperatorCode in opcodes_ vector. absl::flat_hash_map opcode_index_map_; @@ -489,6 +521,9 @@ class Translator { // The failed ops during legalization. std::set failed_flex_ops_; std::set failed_custom_ops_; + + // Set of saved model tags, if any. + const std::unordered_set saved_model_tags_; }; std::string Translator::UniqueName(mlir::Value val) { @@ -1130,6 +1165,7 @@ Optional> Translator::BuildSubGraph( } tensor_index_map.insert({value, tensors.size()}); + tensor_index_map_[name] = tensors.size(); auto tensor_or = BuildTensor(value, name, buffers_.size()); if (!tensor_or) return false; tensors.push_back(*tensor_or); @@ -1285,6 +1321,149 @@ Translator::CreateMetadataVector() { return builder_.CreateVector(metadata); } +// Helper method that returns list of all strings in a StringAttr identified +// by 'attr_key' and values are separated by a comma. +llvm::SmallVector GetStringsFromAttrWithSeparator( + mlir::DictionaryAttr attr, const std::string& attr_key) { + llvm::SmallVector result; + if (auto str = attr.get(attr_key).dyn_cast_or_null()) { + str.getValue().split(result, ',', /*MaxSplit=*/-1, + /*KeepEmpty=*/false); + } + return result; +} + +// Helper method that return list of string for all the StringAttr in the +// Attribute identified by 'attr_name'. +std::vector GetStringsFromDictionaryAttr( + const llvm::SmallVector& dict_attrs, + const std::string& attr_name) { + std::vector result; + for (const auto& arg_attr : dict_attrs) { + auto attrs = arg_attr.getAttrs(); + for (const auto attr : attrs) { + if (attr.first.str() == attr_name) { + auto array_attr = attr.second.dyn_cast_or_null(); + if (!array_attr || array_attr.empty()) continue; + auto string_attr = array_attr[0].dyn_cast_or_null(); + if (!string_attr) continue; + result.push_back(string_attr.getValue().str()); + } + } + } + return result; +} + +std::vector BuildSignaturedef( + FuncOp main_op, const std::string& saved_model_tag) { + static const char kSignatureDefIndexPath[] = "tf_saved_model.index_path"; + static const char kEntryFunctionAttributes[] = "tf.entry_function"; + + // Fetch inputs and outputs from the signature. + llvm::SmallVector arg_attrs, res_attrs; + main_op.getAllArgAttrs(arg_attrs); + main_op.getAllResultAttrs(res_attrs); + std::vector sig_def_inputs = + GetStringsFromDictionaryAttr(arg_attrs, kSignatureDefIndexPath); + std::vector sig_def_outputs = + GetStringsFromDictionaryAttr(res_attrs, kSignatureDefIndexPath); + + // If no defined saved model signature, then return empty list. + // This can happen when we are converting model not from SavedModel. 
+ if (sig_def_inputs.empty() || sig_def_outputs.empty()) return {}; + + // Fetch function inputs and outputs tensor names. + auto dict_attr = + main_op.getAttrOfType(kEntryFunctionAttributes); + if (!dict_attr) return {}; + + // Get Input and output tensor names from attribute. + llvm::SmallVector input_names = + GetStringsFromAttrWithSeparator(dict_attr, /*attr_key=*/"inputs"); + llvm::SmallVector output_names = + GetStringsFromAttrWithSeparator(dict_attr, /*attr_key=*/"outputs"); + + // Verify input size match the number of arguments. + if (input_names.size() != main_op.getNumArguments()) { + main_op.emitWarning() << "invalid entry function specification"; + return {}; + } + // Verify output size match the number of arguments. + auto term = main_op.back().getTerminator(); + if (output_names.size() != term->getNumOperands()) { + main_op.emitWarning() << "output names (" << output_names.size() + << ") != terminator operands (" + << term->getNumOperands() << ")"; + return {}; + } + // Verify number of tensors for inputs and outputs matches size + // of the list in the signature def. + if (input_names.size() != sig_def_inputs.size() || + output_names.size() != sig_def_outputs.size()) { + main_op.emitWarning( + "Mismatch between signature def inputs/outputs and main function " + "arguments."); + return {}; + } + // Exported method name. + auto exported_name = + main_op.getAttrOfType("tf_saved_model.exported_names"); + if (exported_name.empty()) { + main_op.emitError("Empty exported names for main Function"); + return {}; + } + // Fill the SignatureDefData container. + // We create vector of size 1 as TFLite now supports only 1 signatureDef. + std::vector result(1); + for (int i = 0; i < input_names.size(); ++i) { + result[0].inputs[sig_def_inputs[i]] = input_names[i].str(); + } + for (int i = 0; i < output_names.size(); ++i) { + result[0].outputs[sig_def_outputs[i]] = output_names[i].str(); + } + if (auto name_attr = exported_name[0].dyn_cast_or_null()) + result[0].method_name = name_attr.getValue().str(); + result[0].signature_def_key = saved_model_tag; + return result; +} + +std::vector> Translator::GetList( + const std::map& items) { + std::vector> result; + for (const auto& item : items) { + auto name_buf = builder_.CreateString(item.first); + tflite::TensorMapBuilder tensor_map_builder(builder_); + tensor_map_builder.add_name(name_buf); + tensor_map_builder.add_tensor_index(tensor_index_map_[item.second]); + result.push_back(tensor_map_builder.Finish()); + } + return result; +} + +Optional>> +Translator::CreateSignatureDefs( + const std::vector& signature_defs) { + std::vector> signature_defs_buffer; + for (const auto& signature_def_data : signature_defs) { + auto inputs = GetList(signature_def_data.inputs); + auto outputs = GetList(signature_def_data.outputs); + auto inputs_buf = builder_.CreateVector(inputs); + auto outputs_buf = builder_.CreateVector(outputs); + auto method_name_buf = + builder_.CreateString(signature_def_data.method_name); + auto signature_def_key_buf = + builder_.CreateString(signature_def_data.signature_def_key); + tflite::SignatureDefBuilder sig_def_builder(builder_); + sig_def_builder.add_inputs(inputs_buf); + sig_def_builder.add_outputs(outputs_buf); + sig_def_builder.add_method_name(method_name_buf); + sig_def_builder.add_key(signature_def_key_buf); + signature_defs_buffer.push_back(sig_def_builder.Finish()); + } + + return builder_.CreateVector(signature_defs_buffer); +} + bool UpdateEntryFunction(ModuleOp module) { if (module.lookupSymbol("main") != nullptr) 
{ // We already have an entry function. @@ -1311,11 +1490,12 @@ bool UpdateEntryFunction(ModuleOp module) { Optional Translator::Translate( ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, - bool emit_custom_ops, OpOrArgNameMapper* op_or_arg_name_mapper) { + bool emit_custom_ops, const std::unordered_set& tags, + OpOrArgNameMapper* op_or_arg_name_mapper) { if (!UpdateEntryFunction(module)) return llvm::None; if (!IsValidTFLiteMlirModule(module)) return llvm::None; Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops, - emit_custom_ops, op_or_arg_name_mapper); + emit_custom_ops, tags, op_or_arg_name_mapper); return translator.TranslateInternal(); } @@ -1391,10 +1571,17 @@ Optional Translator::TranslateInternal() { auto metadata = CreateMetadataVector(); if (!metadata) return llvm::None; - auto model = tflite::CreateModel( - builder_, TFLITE_SCHEMA_VERSION, builder_.CreateVector(opcodes_), - builder_.CreateVector(subgraphs), description, - builder_.CreateVector(buffers_), metadata_buffer, *metadata); + // Build SignatureDef + // We only have 1 entry point 'main' function, so build only 1 signature def. + auto main_fn_signature_def = BuildSignaturedef( + main_fn, saved_model_tags_.empty() ? "" : *saved_model_tags_.begin()); + auto signature_defs = CreateSignatureDefs(main_fn_signature_def); + + auto model = tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, + builder_.CreateVector(opcodes_), + builder_.CreateVector(subgraphs), + description, builder_.CreateVector(buffers_), + metadata_buffer, *metadata, *signature_defs); tflite::FinishModelBuffer(builder_, model); tflite::UpdateOpVersion(builder_.GetBufferPointer()); tflite::UpdateMinimumRuntimeVersionForModel(builder_.GetBufferPointer()); @@ -1518,12 +1705,10 @@ bool tflite::MlirToFlatBufferTranslateFunction( ModuleOp module, std::string* serialized_flatbuffer, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, OpOrArgNameMapper* op_or_arg_name_mapper) { - auto maybe_translated = - Translator::Translate(module, emit_builtin_tflite_ops, emit_select_tf_ops, - emit_custom_ops, op_or_arg_name_mapper); - if (!maybe_translated) return true; - *serialized_flatbuffer = std::move(*maybe_translated); - return false; + return MlirToFlatBufferTranslateFunction( + module, serialized_flatbuffer, emit_builtin_tflite_ops, + emit_select_tf_ops, emit_custom_ops, /*saved_model_tags=*/{}, + op_or_arg_name_mapper); } bool tflite::MlirToFlatBufferTranslateFunction( @@ -1533,5 +1718,30 @@ bool tflite::MlirToFlatBufferTranslateFunction( OpOrArgLocNameMapper op_or_arg_name_mapper; return MlirToFlatBufferTranslateFunction( module, serialized_flatbuffer, emit_builtin_tflite_ops, - emit_select_tf_ops, emit_custom_ops, &op_or_arg_name_mapper); + emit_select_tf_ops, emit_custom_ops, /*saved_model_tags=*/{}, + &op_or_arg_name_mapper); +} + +bool tflite::MlirToFlatBufferTranslateFunction( + mlir::ModuleOp module, std::string* serialized_flatbuffer, + bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, + const std::unordered_set& saved_model_tags) { + OpOrArgLocNameMapper op_or_arg_name_mapper; + return MlirToFlatBufferTranslateFunction( + module, serialized_flatbuffer, emit_builtin_tflite_ops, + emit_select_tf_ops, emit_custom_ops, saved_model_tags, + &op_or_arg_name_mapper); +} + +bool tflite::MlirToFlatBufferTranslateFunction( + mlir::ModuleOp module, std::string* serialized_flatbuffer, + bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, + const 
std::unordered_set& saved_model_tags, + OpOrArgNameMapper* op_or_arg_name_mapper) { + auto maybe_translated = Translator::Translate( + module, emit_builtin_tflite_ops, emit_select_tf_ops, emit_custom_ops, + saved_model_tags, op_or_arg_name_mapper); + if (!maybe_translated) return true; + *serialized_flatbuffer = std::move(*maybe_translated); + return false; } diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.h b/tensorflow/compiler/mlir/lite/flatbuffer_export.h index 0fbf2f07dfb..0888d2a4a41 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.h +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_ #include +#include #include "mlir/IR/Module.h" // from @llvm-project #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" @@ -33,11 +34,24 @@ bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module, bool emit_select_tf_ops, bool emit_custom_ops); +// Same as above but takes SavedModel tags of the model. +bool MlirToFlatBufferTranslateFunction( + mlir::ModuleOp module, std::string* serialized_flatbuffer, + bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, + const std::unordered_set& saved_model_tags); + // Same as the above but with a custom op name mapper. bool MlirToFlatBufferTranslateFunction( mlir::ModuleOp module, std::string* serialized_flatbuffer, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper); + +// Same as above but takes SavedModel tags of the model. +bool MlirToFlatBufferTranslateFunction( + mlir::ModuleOp module, std::string* serialized_flatbuffer, + bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, + const std::unordered_set& saved_model_tags, + tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper); } // namespace tflite #endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_ diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index 230383729c4..7d64e268063 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -75,6 +75,7 @@ limitations under the License. 
#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" using llvm::ArrayRef; using mlir::Builder; @@ -271,18 +272,18 @@ StatusOr GetMlirOpName(const tflite::OperatorT& op, return std::string("tfl.basic_lstm"); } - if (op_code.builtin_code == tflite::BuiltinOperator_CUSTOM) { + auto builtin_code = tflite::GetBuiltinCode(&op_code); + if (builtin_code == tflite::BuiltinOperator_CUSTOM) { return std::string("tfl.custom"); } - if (op_code.builtin_code == tflite::BuiltinOperator_IF) { + if (builtin_code == tflite::BuiltinOperator_IF) { return std::string("tf.If"); } - if (op_code.builtin_code == tflite::BuiltinOperator_WHILE) { + if (builtin_code == tflite::BuiltinOperator_WHILE) { return std::string("tf.While"); } - llvm::StringRef op_name( - tflite::EnumNameBuiltinOperator(op_code.builtin_code)); + llvm::StringRef op_name(tflite::EnumNameBuiltinOperator(builtin_code)); return llvm::Twine("tfl.", op_name.lower()).str(); } @@ -637,7 +638,8 @@ StatusOr ConvertOp( } llvm::SmallVector attrs; - if (op_code.builtin_code == tflite::BuiltinOperator_CUSTOM) { + auto builtin_code = tflite::GetBuiltinCode(&op_code); + if (builtin_code == tflite::BuiltinOperator_CUSTOM) { auto status = mlir::CustomOptionsToAttributes( op_code.custom_code, op.custom_options, builder, loc, &attrs); if (!status.ok()) { @@ -784,7 +786,7 @@ static StatusOr PostProcessFuncOp(FuncOp func) { auto new_output_type = new_qtype.castFromExpressedType( mlir::quant::UniformQuantizedType::castToExpressedType( value.getType())); - builder.setInsertionPointAfter(cst); + builder.setInsertionPointAfter(cst.getOperation()); auto new_op = builder.create( cst.getLoc(), new_output_type, mlir::TypeAttr::get(new_output_type), cst.valueAttr()); diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc index 5accb419e83..60fd1160be2 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc @@ -127,12 +127,12 @@ static tflite::TensorType ConvertDerivedTypeAttrForOptionWriter( // I32Attr already returns an int as required by flatbuffer builders. static int ConvertI32AttrForOptionWriter( - llvm::APInt i, flatbuffers::FlatBufferBuilder* builder) { - return i.getSExtValue(); + int i, flatbuffers::FlatBufferBuilder* builder) { + return i; } static int ConvertPositiveI32AttrForOptionWriter( - llvm::APInt i, flatbuffers::FlatBufferBuilder* builder) { + int i, flatbuffers::FlatBufferBuilder* builder) { return ConvertI32AttrForOptionWriter(i, builder); } diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index 403b3dd18ad..2894af9b97e 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -41,10 +41,10 @@ limitations under the License. 
#include "mlir/Transforms/FoldUtils.h" // from @llvm-project #include "mlir/Transforms/InliningUtils.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" namespace mlir { -#include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc" namespace TFL { // Returns true when the given operand arguments have the same shape or @@ -569,7 +569,7 @@ namespace { int64_t GetConcatenationOpAxis(ConcatenationOp op) { auto output_type = op.output().getType().cast(); - int64_t axis = op.axis().getSExtValue(); + int32_t axis = op.axis(); if (axis < 0) axis += output_type.getRank(); return axis; } @@ -1027,13 +1027,13 @@ static LogicalResult Verify(PackOp op) { // Check axis bounds. if (input_type.hasRank()) { - int64_t axis_value = op.axis().getSExtValue(); + int32_t axis_value = op.axis(); if (axis_value < 0) axis_value += input_type.getRank() + 1; if (axis_value < 0 || axis_value >= input_type.getRank() + 1) return op.emitOpError() << "op attribute 'axis' should be in range [-rank - 1, rank + 1), " << "got rank = " << input_type.getRank() - << ", and axis = " << op.axis().getSExtValue(); + << ", and axis = " << op.axis(); } // Make sure all inputs have the same shape and element type. @@ -1545,7 +1545,7 @@ static LogicalResult VerifySplitOpOutputTypes( } static LogicalResult Verify(SplitOp op) { - int64_t num_splits = op.num_splits().getSExtValue(); + int64_t num_splits = op.num_splits(); if (op.getNumResults() != num_splits) return op.emitOpError("output count should match 'num_splits' attribute"); @@ -1581,7 +1581,7 @@ static LogicalResult Verify(SplitOp op) { } static LogicalResult Verify(SplitVOp op) { - int64_t num_splits = op.num_splits().getSExtValue(); + int64_t num_splits = op.num_splits(); if (op.getNumResults() != num_splits) return op.emitOpError("output count should match 'num_splits' attribute"); @@ -2377,8 +2377,16 @@ LogicalResult WhileOp::moveOutOfLoop(llvm::ArrayRef ops) { //===----------------------------------------------------------------------===// #include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.cc.inc" + +} // namespace TFL +} // namespace mlir + #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc" + +namespace mlir { +namespace TFL { + #include "tensorflow/compiler/mlir/lite/runtime_verifiers.inc" Operation *TensorFlowLiteDialect::materializeConstant(OpBuilder &builder, diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h index d2d8442155b..589f18d789d 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h @@ -30,11 +30,11 @@ limitations under the License. 
#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_structs.h.inc" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/lite/schema/schema_generated.h" namespace mlir { -#include "tensorflow/compiler/mlir/lite/ir/tfl_structs.h.inc" namespace TFL { class TensorFlowLiteDialect : public Dialect { @@ -50,10 +50,11 @@ class TensorFlowLiteDialect : public Dialect { }; #include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.h.inc" -#define GET_OP_CLASSES -#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc" } // end namespace TFL } // end namespace mlir +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc" + #endif // TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_OPS_H_ diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index f1cdfec631d..f7ee323957d 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -39,7 +39,7 @@ def TFL_Dialect : Dialect { represented using zero-dimensional tensors); }]; - let cppNamespace = "TFL"; + let cppNamespace = "::mlir::TFL"; } //===----------------------------------------------------------------------===// @@ -385,28 +385,27 @@ def BinaryOpSameElementTypeConstraint : //===----------------------------------------------------------------------===// def TFL_BroadcastableBinaryBuilder : OpBuilder< - "OpBuilder &builder, OperationState &result, Value lhs, Value rhs", + "Value lhs, Value rhs", [{ auto resultType = OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType()); if (!resultType) - mlir::emitError(result.location, "non-broadcastable operands"); - result.addOperands({lhs, rhs}); - result.types.push_back(resultType); + mlir::emitError($_state.location, "non-broadcastable operands"); + $_state.addOperands({lhs, rhs}); + $_state.types.push_back(resultType); }]>; def TFL_FusedBroadcastableBinaryBuilder : OpBuilder< - "OpBuilder &builder, OperationState &result, Value lhs, Value rhs, " - "StringAttr fusedActivationFunction", + "Value lhs, Value rhs, StringAttr fusedActivationFunction", [{ buildFusedBroadcastableBinOp( - &builder, result, lhs, rhs, fusedActivationFunction); + &$_builder, $_state, lhs, rhs, fusedActivationFunction); }]>; def TFL_ComparisonBinaryBuilder : OpBuilder< - "OpBuilder &builder, OperationState &result, Value lhs, Value rhs", + "Value lhs, Value rhs", [{ - buildComparisonBinOp(&builder, result, lhs, rhs); + buildComparisonBinOp(&$_builder, $_state, lhs, rhs); }]>; //===----------------------------------------------------------------------===// @@ -520,7 +519,11 @@ def TFL_AddOp : TFL_Op<"add", [ let hasOptions = 1; } -def TFL_AddNOp : TFL_Op<"add_n", [Commutative, NoSideEffect, SameOperandsAndResultsScale]> { +def TFL_AddNOp : TFL_Op<"add_n", [ + Commutative, + NoSideEffect, + SameOperandsAndResultsScale, + NoQuantizableResult]> { let summary = "add_n operator"; let description = [{ @@ -536,7 +539,9 @@ def TFL_AddNOp : TFL_Op<"add_n", [Commutative, NoSideEffect, SameOperandsAndResu ); } -def TFL_ReduceAnyOp : TFL_Op<"reduce_any", [NoSideEffect]> { +def TFL_ReduceAnyOp : TFL_Op<"reduce_any", [ + NoSideEffect, + NoQuantizableResult]> { let summary = [{ Computes the "logical or" of elements across dimensions of a tensor. 
}]; @@ -693,7 +698,8 @@ def TFL_ArgMinOp : TFL_Op<"arg_min", [NoSideEffect]> { def TFL_CeilOp: TFL_Op<"ceil", [ NoSideEffect, SameOperandsAndResultShape, - SameOperandsAndResultType]> { + SameOperandsAndResultType, + NoQuantizableResult]> { let summary = "Ceil operator"; let description = [{ @@ -720,14 +726,14 @@ def TFL_ConcatenationOp : TFL_Op<"concatenation", let arguments = ( ins TFL_VariadicTensorOf< - [F32, I64, I32, I16, I8, QI8, QUI8, UI8]>:$values, + [F32, I64, I32, I16, I8, QI8, QUI8, UI8, I1]>:$values, I32Attr:$axis, TFL_AFAttr:$fused_activation_function ); let results = (outs TFL_TensorOf< - [F32, I64, I32, I16, I8, QI8, QUI8, UI8]>:$output + [F32, I64, I32, I16, I8, QI8, QUI8, UI8, I1]>:$output ); let hasOptions = 1; @@ -765,10 +771,10 @@ def TFL_ConstOp : Op ]; } @@ -817,13 +823,12 @@ def TFL_SparseConstOp : Op ]; } @@ -889,7 +894,7 @@ def TFL_DepthwiseConv2DOp : let extraClassDeclaration = [{ // AffineQuantizedOpInterface: int GetChannelDimIndex() { return 3; } - int GetQuantizationDimIndex() { return 3; } + int GetQuantizationDimIndex() { return 3; } // SparseOpInterface: std::vector GetSparseOperands() { return {1}; } std::vector> GetFloatBlockSize() { return {}; } @@ -1002,9 +1007,8 @@ def TFL_GatherOp : TFL_Op<"gather", [ let builders = [ - OpBuilder<"OpBuilder &builder, OperationState &result, " - "Value params, Value indices, IntegerAttr axis", - [{ BuildGatherOp(&builder, result, params, indices, axis); }]> + OpBuilder<"Value params, Value indices, IntegerAttr axis", + [{ BuildGatherOp(&$_builder, $_state, params, indices, axis); }]> ]; let results = (outs @@ -1093,7 +1097,8 @@ def TFL_LocalResponseNormalizationOp : TFL_Op<"local_response_normalization", [ TFL_OperandHasRank<0, 4>, SameOperandsAndResultShape, SameOperandsAndResultType, - NoSideEffect]> { + NoSideEffect, + NoQuantizableResult]> { let summary = "Local Response Normalization."; let description = [{ @@ -1220,7 +1225,8 @@ def TFL_NonMaxSuppressionV4Op : TFL_Op<"non_max_suppression_v4", [ TFL_OperandHasRank<1, 1>, // Other operands are scalar params. TFL_OperandHasRank<2, 0>, TFL_OperandHasRank<3, 0>, - TFL_OperandHasRank<4, 0>]> { + TFL_OperandHasRank<4, 0>, + NoQuantizableResult]> { let summary = [{ Greedily selects a subset of bounding boxes in descending order of score, }]; @@ -1269,7 +1275,8 @@ def TFL_NonMaxSuppressionV5Op : TFL_Op<"non_max_suppression_v5", [ TFL_OperandHasRank<1, 1>, // Other operands are scalar params. 
TFL_OperandHasRank<2, 0>, TFL_OperandHasRank<3, 0>, - TFL_OperandHasRank<4, 0>, TFL_OperandHasRank<5, 0>]> { + TFL_OperandHasRank<4, 0>, TFL_OperandHasRank<5, 0>, + NoQuantizableResult]> { let summary = [{ Greedily selects a subset of bounding boxes in descending order of score, }]; @@ -1336,10 +1343,9 @@ def TFL_NotEqualOp : TFL_Op<"not_equal", [ let builders = [ - OpBuilder< - "OpBuilder &builder, OperationState &result, Value lhs, Value rhs", + OpBuilder<"Value lhs, Value rhs", [{ - buildComparisonBinOp(&builder, result, lhs, rhs); + buildComparisonBinOp(&$_builder, $_state, lhs, rhs); }]> ]; @@ -1383,7 +1389,8 @@ def TFL_DivOp : TFL_Op<"div", [ def TFL_EluOp: TFL_Op<"elu", [ NoSideEffect, SameOperandsAndResultShape, - SameOperandsAndResultType]> { + SameOperandsAndResultType, + NoQuantizableResult]> { let summary = "Exponential Linear Unit operator"; let description = [{ Computes the exponential linear @@ -1441,8 +1448,10 @@ def TFL_EqualOp: TFL_Op<"equal", [ let builders = [TFL_ComparisonBinaryBuilder]; } -def TFL_ExpOp: TFL_Op<"exp", [NoSideEffect, - SameOperandsAndResultType]> { +def TFL_ExpOp: TFL_Op<"exp", [ + NoSideEffect, + SameOperandsAndResultType, + NoQuantizableResult]> { let summary = "Natural exponentiation operator"; let description = [{ @@ -1546,7 +1555,8 @@ shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1] def TFL_FillOp: TFL_Op<"fill", [ NoSideEffect, PredOpTrait<"input and result must have same element type", - TFL_TCresVTEtIsSameAsOp<0, 1>>]> { + TFL_TCresVTEtIsSameAsOp<0, 1>>, + NoQuantizableResult]> { let summary = "Fill the tensor with given value."; let description = [{ Fill the tensor with given value. @@ -1563,7 +1573,8 @@ def TFL_FillOp: TFL_Op<"fill", [ def TFL_FloorOp: TFL_Op<"floor", [ NoSideEffect, SameOperandsAndResultShape, - SameOperandsAndResultType]> { + SameOperandsAndResultType, + NoQuantizableResult]> { let summary = "Floor operator"; let description = [{ @@ -1581,7 +1592,8 @@ def TFL_FloorDivOp : TFL_Op<"floor_div", [ BinaryOpSameElementTypeConstraint, PredOpTrait<"lhs and output must have same element type", TFL_TCresVTEtIsSameAsOp<0, 0>>, - TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>]> { + TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>, + NoQuantizableResult]> { let summary = "Floor div operator"; let description = [{ @@ -1606,7 +1618,8 @@ def TFL_FloorModOp : TFL_Op<"floor_mod", [ BinaryOpSameElementTypeConstraint, PredOpTrait<"lhs and output must have same element type", TFL_TCresVTEtIsSameAsOp<0, 0>>, - TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>]> { + TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>, + NoQuantizableResult]> { let summary = "Division reminder"; let description = [{ @@ -1745,7 +1758,9 @@ def TFL_LessOp : TFL_Op<"less", [ let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }]; } -def TFL_LogicalAndOp : TFL_Op<"logical_and", [NoSideEffect]> { +def TFL_LogicalAndOp : TFL_Op<"logical_and", [ + NoSideEffect, + NoQuantizableResult]> { let summary = "Logical AND operator"; let description = [{ @@ -1778,7 +1793,9 @@ def TFL_LogicalNotOp : TFL_Op<"logical_not", [ let results = (outs TFL_BoolTensor:$output); } -def TFL_LogicalOrOp : TFL_Op<"logical_or", [NoSideEffect]> { +def TFL_LogicalOrOp : TFL_Op<"logical_or", [ + NoSideEffect, + NoQuantizableResult]> { let summary = "Logical OR operator"; let description = [{ @@ -2005,7 +2022,8 @@ def TFL_OneHotOp : TFL_Op<"one_hot", [NoSideEffect]> { def TFL_RoundOp: TFL_Op<"round", [ NoSideEffect, SameOperandsAndResultShape, - 
SameOperandsAndResultType]> { + SameOperandsAndResultType, + NoQuantizableResult]> { let summary = "Round operator"; let description = [{ @@ -2213,7 +2231,8 @@ def TFL_MulOp : TFL_Op<"mul", [ def TFL_NegOp: TFL_Op<"neg", [ NoSideEffect, SameOperandsAndResultShape, - SameOperandsAndResultType]> { + SameOperandsAndResultType, + NoQuantizableResult]> { let summary = "Negation operator"; let description = [{ @@ -2447,8 +2466,7 @@ def TFL_ReluOp: TFL_Op<"relu", [ PredOpTrait<"x and y must have same element type", TFL_TCresVTEtIsSameAsOp<0, 0>>, NoSideEffect, - SameOperandsAndResultShape, - SameOperandsAndResultsScale]> { + SameOperandsAndResultShape]> { let summary = "Relu operator"; let description = [{ @@ -2463,11 +2481,10 @@ def TFL_ReluOp: TFL_Op<"relu", [ // This builder doesn't work with quantized type, so it can only be used by // non-quantization tablegen patterns. Currently, it is used by the // elementwise-move reordering pattern in the optimize_patterns.td - let builders = [OpBuilder< - "OpBuilder &, OperationState &state, Value input", + let builders = [OpBuilder<"Value input", [{ - state.addOperands({input}); - state.addTypes(input.getType()); + $_state.addOperands({input}); + $_state.addTypes(input.getType()); }]> ]; } @@ -2476,8 +2493,7 @@ def TFL_Relu6Op: TFL_Op<"relu6", [ PredOpTrait<"x and y must have same element type", TFL_TCresVTEtIsSameAsOp<0, 0>>, NoSideEffect, - SameOperandsAndResultShape, - SameOperandsAndResultsScale]> { + SameOperandsAndResultShape]> { let summary = "Relu6 operator"; let description = [{ @@ -2492,11 +2508,10 @@ def TFL_Relu6Op: TFL_Op<"relu6", [ // This builder doesn't work with quantized type, so it can only be used by // non-quantization tablegen patterns. Currently, it is used by the // elementwise-move reordering pattern in the optimize_patterns.td - let builders = [OpBuilder< - "OpBuilder &, OperationState &state, Value input", + let builders = [OpBuilder<"Value input", [{ - state.addOperands({input}); - state.addTypes(input.getType()); + $_state.addOperands({input}); + $_state.addTypes(input.getType()); }]> ]; } @@ -2505,8 +2520,7 @@ def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [ PredOpTrait<"x and y must have same element type", TFL_TCresVTEtIsSameAsOp<0, 0>>, NoSideEffect, - SameOperandsAndResultShape, - SameOperandsAndResultsScale]> { + SameOperandsAndResultShape]> { let summary = "Relu1 operator"; let description = [{ @@ -2522,10 +2536,10 @@ def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [ // non-quantization tablegen patterns. Currently, it is used by the // elementwise-move reordering pattern in the optimize_patterns.td let builders = [OpBuilder< - "OpBuilder &, OperationState &state, Value input", + "Value input", [{ - state.addOperands({input}); - state.addTypes(input.getType()); + $_state.addOperands({input}); + $_state.addTypes(input.getType()); }]> ]; } @@ -2625,7 +2639,8 @@ def TFL_RangeOp: TFL_Op<"range", [ TFL_OperandHasRank<2, 0>, PredOpTrait<"operands and output must have same element type", And<[TCresVTEtIsSameAsOp<0, 0>, TCresVTEtIsSameAsOp<0, 1>, - TCresVTEtIsSameAsOp<0, 2>]>>]> { + TCresVTEtIsSameAsOp<0, 2>]>>, + NoQuantizableResult]> { let summary = "Range operator"; let description = [{ @@ -2704,12 +2719,11 @@ def TFL_SelectOp : TFL_Op<"select", [ TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$output); // TODO(jpienaar): autogenerate this. 
- let builders = [OpBuilder<"OpBuilder &builder, OperationState &result, " - "Value condition, Value x, Value y", + let builders = [OpBuilder<"Value condition, Value x, Value y", [{ auto resultType = x.getType(); - result.addOperands({condition, x, y}); - result.types.push_back(resultType); + $_state.addOperands({condition, x, y}); + $_state.types.push_back(resultType); }]>]; let hasOptions = 1; @@ -2740,10 +2754,9 @@ def TFL_SelectV2Op : TFL_Op<"select_v2", [ let results = (outs TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$output); - let builders = [OpBuilder<"OpBuilder &builder, OperationState &result, " - "Value cond, Value x, Value y", + let builders = [OpBuilder<"Value cond, Value x, Value y", [{ - BuildSelectV2Op(&builder, result, cond, x, y); + BuildSelectV2Op(&$_builder, $_state, cond, x, y); }]>]; let hasOptions = 1; @@ -2918,11 +2931,10 @@ def TFL_TanhOp: TFL_Op<"tanh", [ // This builder doesn't work with quantized type, so it can only be used by // non-quantization tablegen patterns. Currently, it is used by the // elementwise-move reordering pattern in the optimize_patterns.td - let builders = [OpBuilder< - "OpBuilder &, OperationState &state, Value input", + let builders = [OpBuilder<"Value input", [{ - state.addOperands({input}); - state.addTypes(input.getType()); + $_state.addOperands({input}); + $_state.addTypes(input.getType()); }]> ]; @@ -2992,9 +3004,8 @@ def TFL_TopKV2Op: TFL_Op<"topk_v2", [ TFL_TensorOf<[F32, I8, I32, I64, UI8, QI8, QUI8]>:$values, TFL_I32Tensor:$indices); - let builders = [OpBuilder<"OpBuilder &builder, OperationState &result, " - "Value input, Value k", - [{ BuildTopKOp(&builder, result, input, k); }]>]; + let builders = [OpBuilder<"Value input, Value k", + [{ BuildTopKOp(&$_builder, $_state, input, k); }]>]; let hasOptions = 1; } @@ -3069,7 +3080,8 @@ def TFL_ZerosLikeOp: TFL_Op<"zeros_like", [ TFL_TCresVTEtIsSameAsOp<0, 0>>, SameOperandsAndResultType, SameOperandsAndResultShape, - NoSideEffect]> { + NoSideEffect, + NoQuantizableResult]> { let summary = "ZerosLike operator"; let description = [{ @@ -3526,11 +3538,11 @@ def TFL_QConstOp : Op:$output); let builders = [OpBuilder< - "OpBuilder &, OperationState &state, TypeAttr qtype, Attribute value", + "TypeAttr qtype, Attribute value", [{ - state.addAttribute("qtype", qtype); - state.addAttribute("value", value); - state.addTypes(qtype.getValue()); + $_state.addAttribute("qtype", qtype); + $_state.addAttribute("value", value); + $_state.addTypes(qtype.getValue()); }]> ]; } @@ -3555,14 +3567,14 @@ def TFL_SparseQConstOp : Op:$output); let builders = [OpBuilder< - "OpBuilder &, OperationState &state, TypeAttr qtype, " - "Attribute value, SparsityParameterAttr s_param, Attribute compressed_data", + "TypeAttr qtype, Attribute value, SparsityParameterAttr s_param, " + "Attribute compressed_data", [{ - state.addTypes(qtype.getValue()); - state.addAttribute("qtype", qtype); - state.addAttribute("value", value); - state.addAttribute("s_param", s_param); - state.addAttribute("compressed_data", compressed_data); + $_state.addTypes(qtype.getValue()); + $_state.addAttribute("qtype", qtype); + $_state.addAttribute("value", value); + $_state.addAttribute("s_param", s_param); + $_state.addAttribute("compressed_data", compressed_data); }]> ]; } @@ -4243,7 +4255,8 @@ def TFL_SVDFOp : def TFL_SegmentSumOp: TFL_Op<"segment_sum", [ NoSideEffect, PredOpTrait<"input and output must have same element type", - TFL_TCresVTEtIsSameAsOp<0, 0>>]> { + TFL_TCresVTEtIsSameAsOp<0, 0>>, + 
NoQuantizableResult]> { let summary = "SegmentSum operator"; let description = [{ diff --git a/tensorflow/compiler/mlir/lite/python/BUILD b/tensorflow/compiler/mlir/lite/python/BUILD index ceca156e07e..caa5605b00b 100644 --- a/tensorflow/compiler/mlir/lite/python/BUILD +++ b/tensorflow/compiler/mlir/lite/python/BUILD @@ -1,3 +1,5 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + licenses(["notice"]) # Apache 2.0 package(default_visibility = [":friends"]) @@ -86,6 +88,7 @@ cc_library( "//tensorflow/lite/toco:toco_flags_proto_cc", "//tensorflow/lite/toco:types_proto_cc", "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc index e786bedc86d..005c5123906 100644 --- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc @@ -90,9 +90,10 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops; pass_config.lower_tensor_list_ops = true; - return internal::ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module), - pass_config, result, - /*session=*/llvm::None); + return internal::ConvertMLIRToTFLiteFlatBuffer( + toco_flags, std::move(module), pass_config, /*saved_model_tags=*/{}, + result, + /*session=*/llvm::None); } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index 529c9ee9238..7bbd3209dfe 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -16,6 +16,7 @@ limitations under the License. #include +#include "absl/types/span.h" #include "llvm/ADT/None.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/ToolOutputFile.h" @@ -118,9 +119,9 @@ Status HandleInputOutputArraysWithModule(const toco::ModelFlags& model_flags, return Status::OK(); } -Status ConvertSavedModelToTFLiteFlatBuffer( - const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags, - string* result) { +Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, + const toco::TocoFlags& toco_flags, + string* result) { mlir::MLIRContext context; mlir::TFL::QuantizationSpecs quant_specs; @@ -156,9 +157,12 @@ Status ConvertSavedModelToTFLiteFlatBuffer( tensorflow::GraphImportConfig specs; specs.upgrade_legacy = true; + std::vector custom_opdefs(toco_flags.custom_opdefs().begin(), + toco_flags.custom_opdefs().end()); TF_ASSIGN_OR_RETURN(auto module, ImportSavedModel(model_flags.saved_model_dir(), model_flags.saved_model_version(), tags, + absl::MakeSpan(custom_opdefs), exported_names, specs, &context)); if (!model_flags.input_arrays().empty() || @@ -173,7 +177,7 @@ Status ConvertSavedModelToTFLiteFlatBuffer( // TODO(b/153507667): Pass the session object when importing logic is removed. 
auto status = internal::ConvertMLIRToTFLiteFlatBuffer( - toco_flags, std::move(module), pass_config, result, + toco_flags, std::move(module), pass_config, tags, result, /*session=*/llvm::None); return status; } diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc index a4e58123e05..ae2454dcf1e 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc @@ -273,7 +273,8 @@ Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) { Status ConvertMLIRToTFLiteFlatBuffer( const toco::TocoFlags& toco_flags, mlir::OwningModuleRef module, - const mlir::TFL::PassConfig& pass_config, string* result, + const mlir::TFL::PassConfig& pass_config, + const std::unordered_set& saved_model_tags, string* result, llvm::Optional session) { bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops(); bool emit_select_tf_ops = toco_flags.enable_select_tf_ops(); @@ -297,8 +298,8 @@ Status ConvertMLIRToTFLiteFlatBuffer( auto status = ConvertTFExecutorToTFLOrFlatbuffer( module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops, - emit_select_tf_ops, emit_custom_ops, pass_config.quant_specs, result, - &pm); + emit_select_tf_ops, emit_custom_ops, pass_config.quant_specs, + saved_model_tags, result, &pm); if (toco_flags.has_dump_graphviz_dir()) { TF_RETURN_IF_ERROR(DumpOpGraphToFile( // rename once we enable the new converter feature flag. diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h index d79bdc6df67..d4f9e739121 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h @@ -16,6 +16,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_TF_TFL_FLATBUFFER_HELPERS_H_ #include +#include #include #include "llvm/ADT/Optional.h" @@ -48,7 +49,8 @@ Status PopulateQuantizationSpecs( // This will also run relevant passes as well. Status ConvertMLIRToTFLiteFlatBuffer( const toco::TocoFlags& toco_flags, mlir::OwningModuleRef module, - const mlir::TFL::PassConfig& pass_config, string* result, + const mlir::TFL::PassConfig& pass_config, + const std::unordered_set& saved_model_tags, string* result, llvm::Optional session); // Give a warning for any unused flags that have been specified. 
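Editor's note: taken together, the flatbuffer_export and converter-helper changes above thread the SavedModel tag set from ConvertSavedModelToTFLiteFlatBuffer down into the new Translator, where one tag (the first element of the unordered set) becomes the SignatureDef key. The sketch below shows how a caller might use the new public overload and read the signature back out of the serialized model. It is a minimal sketch, not part of this patch: the helper name ExportWithSignature, the "serve" tag value, and the flag values are illustrative, and the read-back accessors (signature_defs(), method_name(), key(), ...) are assumed to follow the usual FlatBuffers naming for the fields populated by SignatureDefBuilder/TensorMapBuilder above.

#include <string>
#include <unordered_set>

#include "mlir/IR/Module.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
#include "tensorflow/lite/schema/schema_generated.h"

// Hypothetical caller; `module` is assumed to be a TFLite-dialect ModuleOp
// produced by the SavedModel import path elsewhere in this patch.
bool ExportWithSignature(mlir::ModuleOp module, std::string* out) {
  // With a single tag, that tag becomes the SignatureDef key
  // (see BuildSignaturedef / TranslateInternal above).
  const std::unordered_set<std::string> saved_model_tags = {"serve"};

  // Like the other overloads, this returns true on failure, false on success.
  if (tflite::MlirToFlatBufferTranslateFunction(
          module, out, /*emit_builtin_tflite_ops=*/true,
          /*emit_select_tf_ops=*/false, /*emit_custom_ops=*/false,
          saved_model_tags)) {
    return true;
  }

  // Read the SignatureDef back from the serialized flatbuffer. Accessor names
  // are assumptions mirroring the builder fields populated above.
  const tflite::Model* model = tflite::GetModel(out->data());
  if (const auto* sigs = model->signature_defs()) {
    for (const tflite::SignatureDef* sig : *sigs) {
      (void)sig;  // e.g. sig->key()->str() == "serve", sig->inputs(), ...
    }
  }
  return false;
}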
diff --git a/tensorflow/compiler/mlir/lite/quantization/BUILD b/tensorflow/compiler/mlir/lite/quantization/BUILD index aec0d8da34f..7e7020997ef 100644 --- a/tensorflow/compiler/mlir/lite/quantization/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/BUILD @@ -1,3 +1,9 @@ +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "filegroup") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "get_compatible_with_cloud") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_native_cc_binary") load( "//tensorflow/core/platform:build_config.bzl", @@ -41,6 +47,7 @@ filegroup( gentbl( name = "quantization_interfaces_inc_gen", + compatible_with = get_compatible_with_cloud(), tbl_outs = [ ( "-gen-op-interface-decls", diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD index 38c7ad86e05..905426ab952 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") package( diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h index 6e356acbbdf..eb9843f6e4a 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h @@ -106,9 +106,9 @@ struct ConvertStatsToQDQs : public OpRewritePattern { mins.push_back(FloatAttr::getValueAsDouble(*it++)); maxs.push_back(FloatAttr::getValueAsDouble(*it)); } - quant_type = quant::fakeQuantAttrsToType( - op.getLoc(), num_bits, op.axis()->getSExtValue(), mins, maxs, - narrow_range, expressed, is_signed); + quant_type = + quant::fakeQuantAttrsToType(op.getLoc(), num_bits, *op.axis(), mins, + maxs, narrow_range, expressed, is_signed); } else if (auto stats = op.layerStats().dyn_cast()) { double rmin = FloatAttr::getValueAsDouble(stats.getValue({0})); double rmax = FloatAttr::getValueAsDouble(stats.getValue({1})); @@ -119,7 +119,7 @@ struct ConvertStatsToQDQs : public OpRewritePattern { return failure(); } - rewriter.setInsertionPointAfter(op); + rewriter.setInsertionPointAfter(op.getOperation()); Type result_type = quant_type.castFromExpressedType(op.getType()); auto q = rewriter.create(op.getLoc(), result_type, op.arg()); auto dq = rewriter.create(op.getLoc(), op.getType(), q); diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD b/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD index 38ea69c51d6..76fd75e18ea 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD @@ -1,3 +1,5 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + package( default_visibility = [ ":friends", diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/BUILD b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/BUILD index 4faa8d2efe8..d7d01eb59a3 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow:tensorflow.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package(licenses = ["notice"]) diff --git 
a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc index 0826b3265f6..b043834188c 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc +++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc @@ -106,9 +106,8 @@ struct InsertQuantOpsAfterTFFakeQuantOp } // Use the min/max from the operands and the num_bits and narrow_range // attribute to create the quantization parameter for the new quantize op. - rewriter.setInsertionPointAfter(tf_op); - IntegerAttr num_bits = - rewriter.getI64IntegerAttr(tf_op.num_bits().getSExtValue()); + rewriter.setInsertionPointAfter(tf_op.getOperation()); + IntegerAttr num_bits = rewriter.getI64IntegerAttr(tf_op.num_bits()); BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.narrow_range()); Type res_type = tf_op.getType(); TypeAttr qtype = quant::GetQuantizedTypeAttr( diff --git a/tensorflow/compiler/mlir/lite/quantization/tests/BUILD b/tensorflow/compiler/mlir/lite/quantization/tests/BUILD index 4faa8d2efe8..d7d01eb59a3 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tests/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/tests/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow:tensorflow.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package(licenses = ["notice"]) diff --git a/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc b/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc index 208fb4c8a56..fc56ad05535 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc +++ b/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc @@ -55,7 +55,7 @@ static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) { for (const auto t : op.getTraits()) { if (auto opTrait = llvm::dyn_cast(&t)) { auto trait = opTrait->getTrait(); - if (!trait.consume_front("OpTrait::quant::")) continue; + if (!trait.consume_front("::mlir::OpTrait::quant::")) continue; OUT(2) << "if (auto tfl = llvm::dyn_cast<" << op.getQualCppClassName() << ">(op)) {\n"; @@ -65,7 +65,7 @@ static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) { OUT(4) << "for (int i = 0, e = op->getNumResults(); i != e; ++i)\n"; OUT(6) << "spec->restricted_output_params[std::make_pair(" << matches[1] << ", " << matches[2] - << ")].push_back(tfl.OpTrait::quant::" << trait << "<" + << ")].push_back(tfl.::mlir::OpTrait::quant::" << trait << "<" << op.getQualCppClassName() << ">::GetResultQuantizedType(i));\n"; matches.clear(); diff --git a/tensorflow/compiler/mlir/lite/sparsity/BUILD b/tensorflow/compiler/mlir/lite/sparsity/BUILD index 9ced3220c9b..7f9f06455cb 100644 --- a/tensorflow/compiler/mlir/lite/sparsity/BUILD +++ b/tensorflow/compiler/mlir/lite/sparsity/BUILD @@ -1,3 +1,5 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + package( default_visibility = [ ":friends", diff --git a/tensorflow/compiler/mlir/lite/tests/BUILD b/tensorflow/compiler/mlir/lite/tests/BUILD index 58d5afb5864..d34fb991b71 100644 --- a/tensorflow/compiler/mlir/lite/tests/BUILD +++ b/tensorflow/compiler/mlir/lite/tests/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow:tensorflow.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package(licenses = ["notice"]) diff --git a/tensorflow/compiler/mlir/lite/tests/debuginfo/BUILD b/tensorflow/compiler/mlir/lite/tests/debuginfo/BUILD index 
1f746c528d6..6ea272745bd 100644 --- a/tensorflow/compiler/mlir/lite/tests/debuginfo/BUILD +++ b/tensorflow/compiler/mlir/lite/tests/debuginfo/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow:tensorflow.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") licenses(["notice"]) diff --git a/tensorflow/compiler/mlir/lite/tests/dilated-conv.mlir b/tensorflow/compiler/mlir/lite/tests/dilated-conv.mlir index 6b3c5b04aa4..d92bdc3f460 100644 --- a/tensorflow/compiler/mlir/lite/tests/dilated-conv.mlir +++ b/tensorflow/compiler/mlir/lite/tests/dilated-conv.mlir @@ -2,7 +2,7 @@ func @testDilatedConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> - %cst_0 = constant dense<2> : tensor<2x2xi32> + %cst_0 = constant dense<4> : tensor<2x2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> %1 = "tf.Conv2D"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> %2 = "tf.BatchToSpaceND"(%1, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> @@ -29,8 +29,8 @@ func @testDilatedConvWithNonConstantPadAndCrops(%arg0: tensor<1x128x128x3xf32>, func @testDilatedConvWithNonZeroBasePadding(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> - %cst_0 = constant dense<2> : tensor<2x2xi32> - %cst_1 = constant dense<1> : tensor<2x2xi32> + %cst_0 = constant dense<4> : tensor<2x2xi32> + %cst_1 = constant dense<0> : tensor<2x2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> %1 = "tf.Conv2D"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> %2 = "tf.BatchToSpaceND"(%1, %cst, %cst_1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> @@ -44,10 +44,11 @@ func @testDilatedConvWithNonZeroBasePadding(%arg0: tensor<1x128x128x3xf32>, %arg func @testDilatedConvWithNonTrivialDilations(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> - %cst_0 = constant dense<2> : tensor<2x2xi32> + %cst_0 = constant dense<4> : tensor<2x2xi32> + %cst_1 = constant dense<0> : tensor<2x2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> %1 = "tf.Conv2D"(%0, %arg1) {padding = "VALID", dilations = [1, 2, 2, 1], strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> - %2 = "tf.BatchToSpaceND"(%1, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> + %2 = "tf.BatchToSpaceND"(%1, %cst, %cst_1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> return %2 : tensor<1x128x128x8xf32> // CHECK-LABEL: testDilatedConvWithNonTrivialDilations @@ -59,80 +60,85 @@ func @testDilatedConvWithNonTrivialDilations(%arg0: tensor<1x128x128x3xf32>, %ar func @testDilatedDepthWiseConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> - %cst_0 = constant dense<2> : tensor<2x2xi32> + %cst_0 = constant dense<4> : 
tensor<2x2xi32> + %cst_1 = constant dense<0> : tensor<2x2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> %1 = "tf.DepthwiseConv2dNative"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> - %2 = "tf.BatchToSpaceND"(%1, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> + %2 = "tf.BatchToSpaceND"(%1, %cst, %cst_1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> return %2 : tensor<1x128x128x8xf32> // CHECK-LABEL: testDilatedDepthWiseConv // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>) - // CHECK-NEXT: [[RESULT:%.*]] = "tf.DepthwiseConv2dNative"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> + // CHECK-NEXT: [[RESULT:%.*]] = "tf.DepthwiseConv2dNative"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32> } func @testDilatedConvWithPad(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>, %arg3: tensor<8xf32>) -> tensor<1x128x128x8xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> - %cst_0 = constant dense<2> : tensor<2x2xi32> + %cst_0 = constant dense<4> : tensor<2x2xi32> + %cst_1 = constant dense<0> : tensor<2x2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> %1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> %2 = "tf.Pad"(%1, %arg1) : (tensor<4x64x64x8xf32>, tensor<2x2xi32>) -> tensor<4x64x64x8xf32> - %3 = "tf.BatchToSpaceND"(%2, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> + %3 = "tf.BatchToSpaceND"(%2, %cst, %cst_1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> %4 = "tf.BiasAdd"(%3, %arg3) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32> return %4 : tensor<1x128x128x8xf32> // CHECK-LABEL: testDilatedConvWithPad // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>) - // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> + // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[CONV]], [[BIAS]]) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32> // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32> } func @testDilatedDepthWiseConvWithPad(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>, %arg3: tensor<8xf32>) -> tensor<1x128x128x8xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> - %cst_0 = constant dense<2> : tensor<2x2xi32> + %cst_0 = constant dense<4> : tensor<2x2xi32> + %cst_1 
= constant dense<0> : tensor<2x2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> %1 = "tf.DepthwiseConv2dNative"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> %2 = "tf.Pad"(%1, %arg1) : (tensor<4x64x64x8xf32>, tensor<2x2xi32>) -> tensor<4x64x64x8xf32> - %3 = "tf.BatchToSpaceND"(%2, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> + %3 = "tf.BatchToSpaceND"(%2, %cst, %cst_1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> %4 = "tf.BiasAdd"(%3, %arg3) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32> return %4 : tensor<1x128x128x8xf32> // CHECK-LABEL: testDilatedDepthWiseConvWithPad // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>) - // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> + // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[CONV]], [[BIAS]]) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32> // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32> } func @testDilatedConvWithBiasAdd(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>, %arg2: tensor<8xf32>) -> tensor<1x128x128x8xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> - %cst_0 = constant dense<2> : tensor<2x2xi32> + %cst_0 = constant dense<4> : tensor<2x2xi32> + %cst_1 = constant dense<0> : tensor<2x2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> %1 = "tf.Conv2D"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> - %2 = "tf.BatchToSpaceND"(%1, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> + %2 = "tf.BatchToSpaceND"(%1, %cst, %cst_1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> %3 = "tf.BiasAdd"(%2, %arg2) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32> return %3 : tensor<1x128x128x8xf32> // CHECK-LABEL: testDilatedConvWithBiasAdd // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>) - // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> + // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[CONV]], [[BIAS]]) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32> // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32> } func @testDilatedDepthWiseConvWithBiasAdd(%arg0: tensor<1x128x128x3xf32>, %arg1: 
tensor<5x5x3x8xf32>, %arg2: tensor<8xf32>) -> tensor<1x128x128x8xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> - %cst_0 = constant dense<2> : tensor<2x2xi32> + %cst_0 = constant dense<4> : tensor<2x2xi32> + %cst_1 = constant dense<0> : tensor<2x2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> %1 = "tf.DepthwiseConv2dNative"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> - %2 = "tf.BatchToSpaceND"(%1, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> + %2 = "tf.BatchToSpaceND"(%1, %cst, %cst_1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> %3 = "tf.BiasAdd"(%2, %arg2) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32> return %3 : tensor<1x128x128x8xf32> // CHECK-LABEL: testDilatedDepthWiseConvWithBiasAdd // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>) - // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> + // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[CONV]], [[BIAS]]) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32> // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32> } @@ -140,12 +146,13 @@ func @testDilatedDepthWiseConvWithBiasAdd(%arg0: tensor<1x128x128x3xf32>, %arg1: func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<5x5x1x1xf32>, %arg2: tensor<128xf32>) -> tensor<1x128x128xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> %cst_0 = "tf.Const"() { value = dense<3> : tensor } : () -> tensor - %cst_1 = constant dense<2> : tensor<2x2xi32> + %cst_1 = constant dense<4> : tensor<2x2xi32> + %cst_2 = constant dense<0> : tensor<2x2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32> %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor) -> tensor<4x68x68x1xf32> %2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32> %3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32> - %4 = "tf.BatchToSpaceND"(%3, %cst, %cst_1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> + %4 = "tf.BatchToSpaceND"(%3, %cst, %cst_2) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> %5 = "tf.BiasAdd"(%4, %arg2) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32> return %5 : tensor<1x128x128xf32> @@ -153,7 +160,7 @@ func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: ten // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>) // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor} : () -> tensor // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor) -> tensor<1x128x128x1xf32> - // 
CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> + // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32> // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32> // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32> @@ -162,12 +169,13 @@ func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: ten func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<5x5x1x1xf32>, %arg2: tensor<128xf32>) -> tensor<1x128x128xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> %cst_0 = "tf.Const"() { value = dense<3> : tensor } : () -> tensor - %cst_1 = constant dense<2> : tensor<2x2xi32> + %cst_1 = constant dense<4> : tensor<2x2xi32> + %cst_2 = constant dense<0> : tensor<2x2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32> %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor) -> tensor<4x68x68x1xf32> %2 = "tf.DepthwiseConv2dNative"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32> %3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32> - %4 = "tf.BatchToSpaceND"(%3, %cst, %cst_1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> + %4 = "tf.BatchToSpaceND"(%3, %cst, %cst_2) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> %5 = "tf.BiasAdd"(%4, %arg2) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32> return %5 : tensor<1x128x128xf32> @@ -175,7 +183,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, % // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>) // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor} : () -> tensor // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor) -> tensor<1x128x128x1xf32> - // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> + // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32> // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32> // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32> @@ -184,20 +192,21 @@ func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, % func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<5x5x1x1xf32>, %arg2: tensor) -> 
tensor<1x128x128xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> %cst_0 = "tf.Const"() { value = dense<3> : tensor } : () -> tensor - %cst_1 = constant dense<2> : tensor<2x2xi32> + %cst_1 = constant dense<4> : tensor<2x2xi32> + %cst_2 = constant dense<0> : tensor<2x2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32> %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x?x?xf32>, tensor) -> tensor<4x?x?x1xf32> %2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32> %3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x?x?x1xf32>) -> tensor<4x?x?xf32> %4 = "tf.BiasAdd"(%3, %arg2) : (tensor<4x?x?xf32>, tensor) -> tensor<4x?x?xf32> - %5 = "tf.BatchToSpaceND"(%4, %cst, %cst_1) : (tensor<4x?x?xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> + %5 = "tf.BatchToSpaceND"(%4, %cst, %cst_2) : (tensor<4x?x?xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> return %5 : tensor<1x128x128xf32> // CHECK-LABEL: testDilatedConvWithExpandSqueeze2 // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor) // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor} : () -> tensor // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor) -> tensor<1x128x128x1xf32> - // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> + // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32> // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor) -> tensor<1x128x128xf32> // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32> @@ -206,20 +215,21 @@ func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: ten func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<5x5x1x1xf32>, %arg2: tensor) -> tensor<1x128x128xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> %cst_0 = "tf.Const"() { value = dense<3> : tensor } : () -> tensor - %cst_1 = constant dense<2> : tensor<2x2xi32> + %cst_1 = constant dense<4> : tensor<2x2xi32> + %cst_2 = constant dense<0> : tensor<2x2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32> %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x?x?xf32>, tensor) -> tensor<4x?x?x1xf32> %2 = "tf.DepthwiseConv2dNative"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32> %3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x?x?x1xf32>) -> tensor<4x?x?xf32> %4 = "tf.BiasAdd"(%3, %arg2) : (tensor<4x?x?xf32>, tensor) -> tensor<4x?x?xf32> - %5 = "tf.BatchToSpaceND"(%4, %cst, %cst_1) : (tensor<4x?x?xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> + %5 = "tf.BatchToSpaceND"(%4, %cst, %cst_2) : (tensor<4x?x?xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> return %5 : tensor<1x128x128xf32> // CHECK-LABEL: 
testDilatedDepthWiseConvWithExpandSqueeze2 // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor) // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor} : () -> tensor // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor) -> tensor<1x128x128x1xf32> - // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> + // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32> // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor) -> tensor<1x128x128xf32> // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32> @@ -228,7 +238,8 @@ func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, % func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> %cst_0 = "tf.Const"() { value = dense<3> : tensor } : () -> tensor - %cst_1 = constant dense<2> : tensor<2x2xi32> + %cst_1 = constant dense<4> : tensor<2x2xi32> + %cst_2 = constant dense<0> : tensor<2x2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32> %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor) -> tensor<4x68x68x1xf32> %2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32> @@ -251,13 +262,14 @@ func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: ten func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> %cst_0 = "tf.Const"() { value = dense<3> : tensor } : () -> tensor - %cst_1 = constant dense<2> : tensor<2x2xi32> + %cst_1 = constant dense<4> : tensor<2x2xi32> + %cst_2 = constant dense<0> : tensor<2x2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32> %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor) -> tensor<4x68x68x1xf32> %2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32> %3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32> %4 = "tf.Pad"(%3, %arg1) : (tensor<4x64x64xf32>, tensor<2x2xi32>) -> tensor<4x64x64xf32> - %5 = "tf.BatchToSpaceND"(%4, %cst, %cst_1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> + %5 = "tf.BatchToSpaceND"(%4, %cst, %cst_2) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> %6 = "tf.BiasAdd"(%5, %arg3) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32> return %6 : tensor<1x128x128xf32> @@ -265,7 +277,7 @@ func 
@testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, % // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>) // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor} : () -> tensor // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor) -> tensor<1x128x128x1xf32> - // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> + // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32> // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32> // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32> @@ -274,12 +286,13 @@ func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, % func @testDilatedConvWithDifferentExpandSqueezeAxis(%arg0: tensor<1x128x128xf32>, %arg1: tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> %cst_0 = "tf.Const"() { value = dense<3> : tensor } : () -> tensor - %cst_1 = constant dense<2> : tensor<2x2xi32> + %cst_1 = constant dense<4> : tensor<2x2xi32> + %cst_2 = constant dense<0> : tensor<2x2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32> %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor) -> tensor<4x68x68x1xf32> %2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32> %3 = "tf.Squeeze"(%2) {squeeze_dims = [2]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64x1xf32> - %4 = "tf.BatchToSpaceND"(%3, %cst, %cst_1) : (tensor<4x64x64x1xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x1xf32> + %4 = "tf.BatchToSpaceND"(%3, %cst, %cst_2) : (tensor<4x64x64x1xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x1xf32> return %4 : tensor<1x128x128x1xf32> // CHECK-LABEL: testDilatedConvWithDifferentExpandSqueezeAxis diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/BUILD b/tensorflow/compiler/mlir/lite/tests/end2end/BUILD index 25bd761f99e..b0b794034ea 100644 --- a/tensorflow/compiler/mlir/lite/tests/end2end/BUILD +++ b/tensorflow/compiler/mlir/lite/tests/end2end/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow:tensorflow.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") licenses(["notice"]) diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/add.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/add.pbtxt index facd6005e7d..9f8d82eb184 100644 --- a/tensorflow/compiler/mlir/lite/tests/end2end/add.pbtxt +++ b/tensorflow/compiler/mlir/lite/tests/end2end/add.pbtxt @@ -96,5 +96,6 @@ versions { # CHECK-NEXT: metadata: [ { # CHECK-NEXT: name: "min_runtime_version", # CHECK-NEXT: buffer: 4 -# CHECK-NEXT: } ] +# CHECK-NEXT: } ], +# CHECK-NEXT: signature_defs: [ ] # CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/conv_2d_nchw.pbtxt 
b/tensorflow/compiler/mlir/lite/tests/end2end/conv_2d_nchw.pbtxt new file mode 100644 index 00000000000..5f498a404a9 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/end2end/conv_2d_nchw.pbtxt @@ -0,0 +1,232 @@ +# RUN: tf_tfl_translate -tf-input-arrays=input -tf-input-shapes=1,8,8,2 -tf-input-data-types=DT_FLOAT -tf-output-arrays=output_0 -print-function-result-mapping %s -o - 2>&1 | FileCheck %s + +node { + name: "input" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1 + } + dim { + size: 8 + } + dim { + size: 8 + } + dim { + size: 2 + } + } + } + } +} +node { + name: "conv_net_2d/conv_2d_0/w" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 3 + } + dim { + size: 3 + } + dim { + size: 2 + } + dim { + size: 2 + } + } + tensor_content: ";;\177<5\241i\275\312f\211>#\346j>\033W\325\275\253>\210=Vr\r\276\304\222\313\276\374\346\214>\016e\211>)\253\000>\3241\337\275\235g-\276*(\216\276\326#\367\274\023\213\300\276\227\031\206>PUF=\253\330\263<\337IL\276\334\320\215>\377\306v\276\372C\302\273baM>H\314\270<2\221\352=J\026{\276\221\243\245\276?\314\240=UW2\2755\207\253\274\256\207\333\273\335\372\227>\246\232;\276%\r\374" + } + } + } +} +node { + name: "conv_net_2d/conv_2d_0/w/read" + op: "Identity" + input: "conv_net_2d/conv_2d_0/w" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@conv_net_2d/conv_2d_0/w" + } + } + } +} +node { + name: "conv_net_2d_1/conv_2d_0/convolution" + op: "Conv2D" + input: "input" + input: "conv_net_2d/conv_2d_0/w/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NCHW" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "explicit_paddings" + value { + list { + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "use_cudnn_on_gpu" + value { + b: true + } + } +} +node { + name: "conv_net_2d/conv_2d_0/b" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\315\314\314=\315\314\314=" + } + } + } +} +node { + name: "conv_net_2d/conv_2d_0/b/read" + op: "Identity" + input: "conv_net_2d/conv_2d_0/b" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@conv_net_2d/conv_2d_0/b" + } + } + } +} +node { + name: "conv_net_2d_1/conv_2d_0/BiasAdd" + op: "BiasAdd" + input: "conv_net_2d_1/conv_2d_0/convolution" + input: "conv_net_2d/conv_2d_0/b/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } +} +node { + name: "conv_net_2d_1/Relu" + op: "Relu" + input: "conv_net_2d_1/conv_2d_0/BiasAdd" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "output_0" + op: "Identity" + input: "conv_net_2d_1/Relu" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +library { +} + +# CHECK: 'main' inputs: +# CHECK-NEXT: name: 'input' +# CHECK-NEXT: 'main' outputs: +# CHECK-NEXT: name: 'output_0' diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/fake_quant_per_channel.pbtxt 
b/tensorflow/compiler/mlir/lite/tests/end2end/fake_quant_per_channel.pbtxt index adfcd93b4bc..117edd02beb 100644 --- a/tensorflow/compiler/mlir/lite/tests/end2end/fake_quant_per_channel.pbtxt +++ b/tensorflow/compiler/mlir/lite/tests/end2end/fake_quant_per_channel.pbtxt @@ -459,11 +459,13 @@ node { # CHECK-LABEL: { # CHECK: version: 3, # CHECK: operator_codes: [ { -# CHECK: builtin_code: CONV_2D, -# CHECK: version: 3 +# CHECK: deprecated_builtin_code: 3, +# CHECK: version: 3, +# CHECK: builtin_code: CONV_2D # CHECK: }, { -# CHECK: builtin_code: RESHAPE, +# CHECK: deprecated_builtin_code: 22, # CHECK: version: 1 +# CHECK: builtin_code: RESHAPE # CHECK: } ], # CHECK: subgraphs: [ { # CHECK: tensors: [ { diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD index e21b268279c..41fbbbcb9c5 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow:tensorflow.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") load("//tensorflow:tensorflow.bzl", "tf_native_cc_binary") diff --git a/tensorflow/compiler/mlir/lite/tests/fuse-tftext.mlir b/tensorflow/compiler/mlir/lite/tests/fuse-tftext.mlir index 138614d81e6..d56c2cc221a 100644 --- a/tensorflow/compiler/mlir/lite/tests/fuse-tftext.mlir +++ b/tensorflow/compiler/mlir/lite/tests/fuse-tftext.mlir @@ -3442,8 +3442,8 @@ func @sgnn_projection(%arg0: tensor {tf._user_specified_name = "va %0 = "tf.Const"() {value = dense<[[1902835825], [-1475704015], [473120514], [1254202069], [1558833093], [1756181982], [1906603252], [-1034142694], [542842690], [535515822]]> : tensor<10x1xi64>} : () -> tensor<10x1xi64> %1 = "tf.StringToHashBucketFast"(%arg0) {device = "", num_buckets = 2147483647 : i64} : (tensor) -> tensor %2 = "tf.Sgnn"(%1, %0) {device = ""} : (tensor, tensor<10x1xi64>) -> tensor<10x?xf64> - %3 = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64> - %4 = "tf.Reshape"(%2, %3) : (tensor<10x?xf64>, tensor<1xi64>) -> tensor + %3 = "tf.Const"() {value = dense<[-1, 10]> : tensor<2xi64>} : () -> tensor<2xi64> + %4 = "tf.Reshape"(%2, %3) : (tensor<10x?xf64>, tensor<2xi64>) -> tensor return %4 : tensor } diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index d02e4e705f4..4de278ee324 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -tfl-legalize-tf | FileCheck %s +// RUN: tf-opt %s -tfl-legalize-tf --cse | FileCheck %s func @add(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { %0 = "tf.Add"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> @@ -196,7 +196,6 @@ func @shape(%arg0: tensor) -> tensor<2xi32> { // CHECK-LABEL: shape // CHECK: "tfl.shape"(%arg0) : (tensor) -> tensor<2xi32> -// CHECK: %1 = "tfl.shape"(%arg0) : (tensor) -> tensor<2xi32> } func @fill(%arg0: tensor<3xi32>, %arg1: tensor) -> tensor { @@ -719,9 +718,8 @@ func @matrix_diag_v2_no_match(%arg0: tensor<8x16xf32>) -> tensor<8x16x16xf32> { // CHECK-SAME: [[VAL_0:%.*]]: tensor<8x16xf32>) -> tensor<8x16x16xf32> { // CHECK: [[VAL_1:%.*]] = constant dense<1> : tensor<1xi32> // CHECK: [[VAL_2:%.*]] = constant dense<-1> : tensor<1xi32> -// CHECK: [[VAL_5:%.*]] = constant dense<-1> : tensor<1xi32> // CHECK: [[VAL_3:%.*]] = constant dense<0> : tensor<2xi32> -// 
CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV2"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_5]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32> +// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV2"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_2]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32> // CHECK: return [[VAL_4]] : tensor<8x16x16xf32> } @@ -753,9 +751,8 @@ func @matrix_diag_v3_no_match(%arg0: tensor<8x16xf32>) -> tensor<8x16x16xf32> { // CHECK-SAME: [[VAL_0:%.*]]: tensor<8x16xf32>) -> tensor<8x16x16xf32> { // CHECK: [[VAL_1:%.*]] = constant dense<1> : tensor<1xi32> // CHECK: [[VAL_2:%.*]] = constant dense<-1> : tensor<1xi32> -// CHECK: [[VAL_5:%.*]] = constant dense<-1> : tensor<1xi32> // CHECK: [[VAL_3:%.*]] = constant dense<0> : tensor<2xi32> -// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV3"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_5]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32> +// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV3"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_2]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32> // CHECK: return [[VAL_4]] : tensor<8x16x16xf32> } @@ -1006,11 +1003,11 @@ func @batch_to_space_nd_unsupported(%arg0: tensor, %arg1: tensor< // CHECK: "tf.BatchToSpaceND" } -func @space_to_batch_nd(%arg0: tensor<1x4x4x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<2x2xi32>) -> tensor { - %0 = "tf.SpaceToBatchND"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor - return %0 : tensor +func @space_to_batch_nd(%arg0: tensor<1x4x4x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<2x2xi32>) -> tensor<*xf32> { + %0 = "tf.SpaceToBatchND"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<*xf32> + return %0 : tensor<*xf32> // CHECK-LABEL: space_to_batch_nd - // CHECK: "tfl.space_to_batch_nd"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor + // CHECK: "tfl.space_to_batch_nd"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<*xf32> } func @split(%arg0: tensor, %arg1: tensor<1x4x3x3xf32>) -> tensor<1x4x3xf32> { @@ -1029,32 +1026,75 @@ func @splitv(%arg0: tensor<1x4x3x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor, tensor<2xi32>, tensor) -> (tensor<1x4x2x3xf32>, tensor<1x4x1x3xf32>) } -func @matmul_transposed(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> { +func @matmul(%arg0: tensor<40x37xf32>, %arg1: tensor<37x40xf32>) -> tensor<40x40xf32> { + %0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = false, transpose_b = false} : +(tensor<40x37xf32>, tensor<37x40xf32>) -> tensor<40x40xf32> + return %0 : tensor<40x40xf32> +// CHECK-LABEL: matmul +// CHECK: %[[CST:.*]] = constant dense<[1, 0]> : tensor<2xi32> +// CHECK: %[[ARG:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32> +// CHECK: %[[CST_0:.*]] = constant unit +// CHECK: "tfl.fully_connected"(%arg0, %[[ARG]], %[[CST_0]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32> +} + +func @matmul_transposed_a(%arg0: tensor<37x40xf32>, %arg1: tensor<37x40xf32>) -> tensor<40x40xf32> { + %0 = "tf.MatMul"(%arg0, %arg1) {T = 
"tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = true, transpose_b = false} : +(tensor<37x40xf32>, tensor<37x40xf32>) -> tensor<40x40xf32> + return %0 : tensor<40x40xf32> +// CHECK-LABEL: matmul_transposed_a +// CHECK: %[[CST_0:.*]] = constant dense<[1, 0]> : tensor<2xi32> +// CHECK: %[[ARG_0:.*]] = "tfl.transpose"(%arg0, %[[CST_0]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32> +// CHECK: %[[ARG_1:.*]] = "tfl.transpose"(%arg1, %[[CST_0]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32> +// CHECK: %[[CST_2:.*]] = constant unit +// CHECK: "tfl.fully_connected"(%[[ARG_0]], %[[ARG_1]], %[[CST_2]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32> +} + +func @matmul_transposed_b(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> { %0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = false, transpose_b = true} : (tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> return %0 : tensor<40x40xf32> -// CHECK-LABEL: matmul_transposed +// CHECK-LABEL: matmul_transposed_b // CHECK: "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32> } -func @concatv2With3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> { +func @matmul_transposed_ab(%arg0: tensor<37x40xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> { + %0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = true, transpose_b = true} : +(tensor<37x40xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> + return %0 : tensor<40x40xf32> +// CHECK-LABEL: matmul_transposed_ab +// CHECK: %[[CST_0:.*]] = constant dense<[1, 0]> : tensor<2xi32> +// CHECK: %[[ARG_0:.*]] = "tfl.transpose"(%arg0, %[[CST_0]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32> +// CHECK: %[[CST_1:.*]] = constant unit +// CHECK: "tfl.fully_connected"(%[[ARG_0]], %arg1, %[[CST_1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32> +} + +func @concat_v2_with_3_tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> { %0 = "tf.Const"() { value = dense<-1> : tensor } : () -> tensor %1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor) -> tensor<2x3xi32> return %1 : tensor<2x3xi32> -// CHECK-LABEL: concatv2With3Tensors +// CHECK-LABEL: concat_v2_with_3_tensors // CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32> } -func @concatv2I64Axis(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> { +func @concat_v2_i64_axis(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> { %0 = "tf.Const"() { value = dense<-1> : tensor } : () -> tensor %1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor) -> tensor<2x3xi32> return %1 : tensor<2x3xi32> -// CHECK-LABEL: concatv2I64Axis +// CHECK-LABEL: concat_v2_i64_axis // CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) 
{axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32> } +func @concat_v2_with_bool_type(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "tf.Const"() { value = dense<-1> : tensor } : () -> tensor + %1 = "tf.ConcatV2"(%arg0, %arg1, %0) : (tensor, tensor, tensor) -> tensor + return %1 : tensor + +// CHECK-LABEL: concat_v2_with_bool_type +// CHECK: "tfl.concatenation"(%arg0, %arg1) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor, tensor) -> tensor +} + func @resize_with_bilinear(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi32>) -> tensor { %0 = "tf.ResizeBilinear"(%arg0, %arg1) {align_corners = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor return %0 : tensor @@ -1324,10 +1364,7 @@ func @conv2d_backprop_input(%arg0: tensor<4xi32>, %arg1: tensor<3x3x1x32xf32>, % // CHECK: %[[ARG0:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32> // CHECK: %[[CST_0:.*]] = constant unit // CHECK: %[[ARG1:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32> - // CHECK: %[[CST_1:.*]] = constant dense<[2, 0, 1, 3]> : tensor<4xi32> - // CHECK: %[[ARG2:.*]] = "tfl.transpose"(%arg1, %[[CST_1]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32> - // CHECK: %[[CST_2:.*]] = constant unit - // CHECK: %[[ARG3:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG2]], %arg2, %[[CST_2]]) {padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32> + // CHECK: %[[ARG3:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32> // CHECK: %[[RESULT:.*]] = tfl.add %[[ARG1]], %[[ARG3]] {fused_activation_function = "NONE"} : tensor<15x28x28x1xf32> // CHECK: return %[[RESULT]] : tensor<15x28x28x1xf32> } @@ -1533,3 +1570,27 @@ func @add_with_int32_5d_inputs(%arg0: tensor<1x1x1x3x1xi32>, %arg1 : tensor<1x1x // CHECK-LABEL: add_with_int32_5d_inputs // CHECK: "tf.Add"(%arg0, %arg1) } + +func @tranpose_int32_perm(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { + %cst = "tf.Const"() { value = dense<[1, 0]> : tensor<2xi32> } : () -> tensor<2xi32> + %0 = "tf.Transpose"(%arg0, %cst): (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> + return %0 : tensor<3x2xf32> + // CHECK-LABEL: tranpose_int32_perm + // CHECK: "tfl.transpose" +} + +func @tranpose_int64_perm(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { + %cst = "tf.Const"() { value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64> + %0 = "tf.Transpose"(%arg0, %cst): (tensor<2x3xf32>, tensor<2xi64>) -> tensor<3x2xf32> + return %0 : tensor<3x2xf32> + // CHECK-LABEL: tranpose_int64_perm + // CHECK: "tfl.transpose" +} + +func @tranpose_arg(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>) -> tensor<3x2xf32> { + %0 = "tf.Transpose"(%arg0, %arg1): (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> + return %0 : tensor<3x2xf32> + // CHECK-LABEL: tranpose_arg + // CHECK: "tfl.transpose" +} + diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2exec/BUILD b/tensorflow/compiler/mlir/lite/tests/mlir2exec/BUILD index 745d9eacf15..35e0a376384 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2exec/BUILD +++ 
b/tensorflow/compiler/mlir/lite/tests/mlir2exec/BUILD @@ -6,6 +6,7 @@ # runtime behavior, but the majority of runtime tests should be TFLite side and # invariants only verified in the converter/compiler. +load("//tensorflow:tensorflow.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") licenses(["notice"]) diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/BUILD b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/BUILD index c0ae9570225..e77b8d8fbd5 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/BUILD +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow:tensorflow.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") licenses(["notice"]) diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/basic_lstm.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/basic_lstm.mlir index 8389045fc57..b2f684e6be8 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/basic_lstm.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/basic_lstm.mlir @@ -4,8 +4,9 @@ func @main(tensor<1x384xf32>, tensor<1x96xf32>, tensor<384x480xf32>, tensor<384x // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: LSTM, +// CHECK-NEXT: deprecated_builtin_code: 16, // CHECK-NEXT: version: 2 +// CHECK-NEXT: builtin_code: LSTM // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { @@ -115,6 +116,7 @@ func @main(tensor<1x384xf32>, tensor<1x96xf32>, tensor<384x480xf32>, tensor<384x // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 10 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT:} ^bb0(%arg0: tensor<1x384xf32>, %arg1: tensor<1x96xf32>, %arg2: tensor<384x480xf32>, %arg3: tensor<384xf32>, %arg4: tensor<1x96xf32>): diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/custom_op_with_tflite_op.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/custom_op_with_tflite_op.mlir index 2d906d6901e..a067826f86d 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/custom_op_with_tflite_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/custom_op_with_tflite_op.mlir @@ -6,14 +6,17 @@ func @main(tensor<4xf32>) -> tensor<4xf32> { // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: MUL, +// CHECK-NEXT: deprecated_builtin_code: 18, // CHECK-NEXT: version: 1 +// CHECK-NEXT: builtin_code: MUL // CHECK-NEXT: }, { -// CHECK-NEXT: builtin_code: CUSTOM, -// CHECK-NEXT: custom_code: "MyCustomOp" +// CHECK-NEXT: deprecated_builtin_code: 32, +// CHECK-NEXT: custom_code: "MyCustomOp", +// CHECK-NEXT: builtin_code: CUSTOM // CHECK-NEXT: }, { -// CHECK-NEXT: builtin_code: EXP, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 47, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: EXP // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { @@ -97,6 +100,7 @@ func @main(tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 6 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT:} %0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/depthwise_conv2d.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/depthwise_conv2d.mlir index 
98c3eb154e1..ef82175a47d 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/depthwise_conv2d.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/depthwise_conv2d.mlir @@ -5,11 +5,13 @@ func @main(tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> { // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { - // CHECK-NEXT: builtin_code: DEQUANTIZE, + // CHECK-NEXT: deprecated_builtin_code: 6, // CHECK-NEXT: version: 1 + // CHECK-NEXT: builtin_code: DEQUANTIZE // CHECK-NEXT: }, { - // CHECK-NEXT: builtin_code: DEPTHWISE_CONV_2D, + // CHECK-NEXT: deprecated_builtin_code: 4, // CHECK-NEXT: version: 1 + // CHECK-NEXT: builtin_code: DEPTHWISE_CONV_2D // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { @@ -89,6 +91,7 @@ func @main(tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> { // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 6 // CHECK-NEXT: } ] + // CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT:} %0 = "tfl.pseudo_const" () {value = dense<-1.23697901> : tensor<32xf32>} : () -> tensor<32xf32> loc("Const") diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/depthwise_conv2d_v2.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/depthwise_conv2d_v2.mlir index 86f27936946..f4bc10b2fe2 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/depthwise_conv2d_v2.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/depthwise_conv2d_v2.mlir @@ -5,11 +5,13 @@ func @main(tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> { // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { - // CHECK-NEXT: builtin_code: DEQUANTIZE, - // CHECK-NEXT: version: 1 + // CHECK-NEXT: deprecated_builtin_code: 6, + // CHECK-NEXT: version: 1, + // CHECK-NEXT: builtin_code: DEQUANTIZE // CHECK-NEXT: }, { - // CHECK-NEXT: builtin_code: DEPTHWISE_CONV_2D, - // CHECK-NEXT: version: 2 + // CHECK-NEXT: deprecated_builtin_code: 4, + // CHECK-NEXT: version: 2, + // CHECK-NEXT: builtin_code: DEPTHWISE_CONV_2D // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { @@ -91,6 +93,7 @@ func @main(tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> { // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 6 // CHECK-NEXT: } ] + // CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT:} %0 = "tfl.pseudo_const" () {value = dense<-1.23697901> : tensor<32xf32>} : () -> tensor<32xf32> loc("Const") diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_flex_enable_builtin.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_flex_enable_builtin.mlir index c034fa7e462..f7ff99b117d 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_flex_enable_builtin.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_flex_enable_builtin.mlir @@ -5,11 +5,13 @@ func @main(tensor<4xf32>) -> tensor<4xf32> { // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: MUL, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 18, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: MUL // CHECK-NEXT: }, { -// CHECK-NEXT: builtin_code: EXP, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 47, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: EXP // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { @@ -95,6 +97,7 @@ func @main(tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: name: "min_runtime_version", // 
CHECK-NEXT: buffer: 6 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT:} %0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fake_quant.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fake_quant.mlir index 6d8c54b783a..9aca1ecb47d 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fake_quant.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fake_quant.mlir @@ -6,8 +6,9 @@ func @main(tensor<4xf32>) -> tensor<4xf32> { // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: FAKE_QUANT, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 80, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: FAKE_QUANT // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { @@ -53,6 +54,7 @@ func @main(tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 3 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } // IMPORT: "tfl.fake_quant"(%arg0) {max = 1.400000e+00 : f32, min = 3.000000e-01 : f32, narrow_range = false, num_bits = 6 : i32} diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_exclusively.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_exclusively.mlir index 018d99fc74d..b2d7f611ede 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_exclusively.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_exclusively.mlir @@ -4,8 +4,9 @@ func @main(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32> { // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: CUSTOM, +// CHECK-NEXT: deprecated_builtin_code: 32, // CHECK-NEXT: custom_code: "FlexAddV2" +// CHECK-NEXT: builtin_code: CUSTOM // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { @@ -46,6 +47,7 @@ func @main(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32> { // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 3 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } %0 = "tf.AddV2"(%arg0, %arg0) : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_complex128.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_complex128.mlir index a5e6d4aabb5..b8749b4b76c 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_complex128.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_complex128.mlir @@ -5,8 +5,9 @@ func @main(tensor<4xcomplex>, tensor<4xcomplex>) -> tensor<4xcomplex>, tensor<4xcomplex>) -> tensor<4xcomplex>, tensor<4xcomplex>) -> tensor<4xcomplex> loc("add") diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_f64.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_f64.mlir index 4b75d3e8ff4..c8f3949500e 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_f64.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_f64.mlir @@ -5,8 +5,9 @@ func @main(tensor<4xf64>, tensor<4xf64>) -> tensor<4xf64> { // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: CUSTOM, -// CHECK-NEXT: custom_code: "FlexAdd" +// CHECK-NEXT: deprecated_builtin_code: 32, +// CHECK-NEXT: custom_code: "FlexAdd", +// 
CHECK-NEXT: builtin_code: CUSTOM // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { @@ -59,6 +60,7 @@ func @main(tensor<4xf64>, tensor<4xf64>) -> tensor<4xf64> { // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 4 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT:} %0 = "tf.Add"(%arg0, %arg1) : (tensor<4xf64>, tensor<4xf64>) -> tensor<4xf64> loc("add") diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_tflite_op.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_tflite_op.mlir index 8a9175b5c59..059cfc0d54e 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_tflite_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_tflite_op.mlir @@ -5,14 +5,17 @@ func @main(tensor<4xf32>) -> tensor<4xf32> { // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { +// CHECK-NEXT: deprecated_builtin_code: 18, +// CHECK-NEXT: version: 1, // CHECK-NEXT: builtin_code: MUL -// CHECK-NEXT: version: 1 // CHECK-NEXT: }, { -// CHECK-NEXT: builtin_code: CUSTOM, -// CHECK-NEXT: custom_code: "FlexDiv" +// CHECK-NEXT: deprecated_builtin_code: 32, +// CHECK-NEXT: custom_code: "FlexDiv", +// CHECK-NEXT: builtin_code: CUSTOM // CHECK-NEXT: }, { +// CHECK-NEXT: deprecated_builtin_code: 47, +// CHECK-NEXT: version: 1, // CHECK-NEXT: builtin_code: EXP -// CHECK-NEXT: version: 1 // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { @@ -96,6 +99,7 @@ func @main(tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 6 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT:} %0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fully_connected.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fully_connected.mlir index bbe4fdb8337..b01bafe4ea7 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fully_connected.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fully_connected.mlir @@ -5,8 +5,9 @@ func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> { // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { - // CHECK-NEXT: builtin_code: FULLY_CONNECTED, - // CHECK-NEXT: version: 1 + // CHECK-NEXT: deprecated_builtin_code: 9, + // CHECK-NEXT: version: 1, + // CHECK-NEXT: builtin_code: FULLY_CONNECTED // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { @@ -68,6 +69,7 @@ func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> { // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 5 // CHECK-NEXT: } ] + // CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT:} %cst = constant unit diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fully_connected_v2.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fully_connected_v2.mlir index 0abe720ccba..95bcc1547f7 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fully_connected_v2.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fully_connected_v2.mlir @@ -5,8 +5,9 @@ func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> { // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { - // CHECK-NEXT: builtin_code: FULLY_CONNECTED, - // CHECK-NEXT: version: 2 + // CHECK-NEXT: deprecated_builtin_code: 9, + // CHECK-NEXT: version: 2, + 
// CHECK-NEXT: builtin_code: FULLY_CONNECTED // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { @@ -68,6 +69,7 @@ func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> { // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 5 // CHECK-NEXT: } ] + // CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT:} %cst = constant unit diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/hashtable_resource.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/hashtable_resource.mlir index 3adee1dec77..2d5852dd83d 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/hashtable_resource.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/hashtable_resource.mlir @@ -3,8 +3,9 @@ // CHECK: { // CHECK: version: 3, // CHECK: operator_codes: [ { -// CHECK: builtin_code: CUSTOM, -// CHECK: custom_code: "HashTableV2" +// CHECK: deprecated_builtin_code: 32, +// CHECK: custom_code: "HashTableV2", +// CHECK: builtin_code: CUSTOM // CHECK: } ], // CHECK: subgraphs: [ { // CHECK: tensors: [ { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/if_op.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/if_op.mlir index 7290209cc4a..c89239c2e6f 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/if_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/if_op.mlir @@ -4,16 +4,19 @@ // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: LESS, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 58, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: LESS // CHECK-NEXT: }, { -// CHECK-NEXT: builtin_code: IF, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 118, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: IF // CHECK-NEXT: }, { // CHECK-NEXT: version: 1 // CHECK-NEXT: }, { -// CHECK-NEXT: builtin_code: MUL, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 18, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: MUL // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { @@ -163,6 +166,7 @@ // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 11 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/logical.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/logical.mlir index 84cbf48c099..f32fe880121 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/logical.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/logical.mlir @@ -5,11 +5,13 @@ func @main(tensor<4xi1>) -> tensor<4xi1> { // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { - // CHECK-NEXT: builtin_code: LOGICAL_OR, - // CHECK-NEXT: version: 1 + // CHECK-NEXT: deprecated_builtin_code: 84, + // CHECK-NEXT: version: 1, + // CHECK-NEXT: builtin_code: LOGICAL_OR // CHECK-NEXT: }, { - // CHECK-NEXT: builtin_code: LOGICAL_AND, - // CHECK-NEXT: version: 1 + // CHECK-NEXT: deprecated_builtin_code: 86, + // CHECK-NEXT: version: 1, + // CHECK-NEXT: builtin_code: LOGICAL_AND // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { @@ -85,6 +87,7 @@ func @main(tensor<4xi1>) -> tensor<4xi1> { // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 6 // CHECK-NEXT: } ] + // CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } // CHECK-EMPTY: diff --git 
a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir index 707bc926870..017870ca334 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir @@ -4,8 +4,9 @@ func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: LSTM, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 16, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: LSTM // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { @@ -257,6 +258,7 @@ func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 26 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } // CHECK-EMPTY: diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm_quantized.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm_quantized.mlir index 5985ffaa446..10332e45bec 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm_quantized.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm_quantized.mlir @@ -7,8 +7,9 @@ func @main(%arg0: tensor<1x528x!quant.uniform> // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: LSTM, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 16, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: LSTM // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { @@ -319,5 +320,6 @@ func @main(%arg0: tensor<1x528x!quant.uniform> // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 23 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } } diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/math.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/math.mlir index 297a8b8cb59..eeca4267524 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/math.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/math.mlir @@ -5,20 +5,25 @@ func @main(tensor<4xf32>) -> tensor<4xf32> { // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { - // CHECK-NEXT: builtin_code: SQUARED_DIFFERENCE, - // CHECK-NEXT: version: 1 + // CHECK-NEXT: deprecated_builtin_code: 99, + // CHECK-NEXT: version: 1, + // CHECK-NEXT: builtin_code: SQUARED_DIFFERENCE // CHECK-NEXT: }, { - // CHECK-NEXT: builtin_code: MUL, - // CHECK-NEXT: version: 1 + // CHECK-NEXT: deprecated_builtin_code: 18, + // CHECK-NEXT: version: 1, + // CHECK-NEXT: builtin_code: MUL // CHECK-NEXT: }, { - // CHECK-NEXT: builtin_code: DIV, - // CHECK-NEXT: version: 1 + // CHECK-NEXT: deprecated_builtin_code: 42, + // CHECK-NEXT: version: 1, + // CHECK-NEXT: builtin_code: DIV // CHECK-NEXT: }, { - // CHECK-NEXT: builtin_code: EXP, - // CHECK-NEXT: version: 1 + // CHECK-NEXT: deprecated_builtin_code: 47, + // CHECK-NEXT: version: 1, + // CHECK-NEXT: builtin_code: EXP // CHECK-NEXT: }, { - // CHECK-NEXT: builtin_code: NEG, - // CHECK-NEXT: version: 1 + // CHECK-NEXT: deprecated_builtin_code: 59, + // CHECK-NEXT: version: 1, + // CHECK-NEXT: builtin_code: NEG // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { @@ -135,6 +140,7 @@ func @main(tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: name: "min_runtime_version", // 
CHECK-NEXT: buffer: 8 // CHECK-NEXT: } ] + // CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } %0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/metadata.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/metadata.mlir index 49d71f24d2d..3fb00cf6024 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/metadata.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/metadata.mlir @@ -33,4 +33,5 @@ module attributes { // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 6 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/mul_v2.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/mul_v2.mlir index 15fce806a70..c8af68a190d 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/mul_v2.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/mul_v2.mlir @@ -5,8 +5,9 @@ func @main(tensor<3x!quant.uniform>) -> tensor<3x!quant.uniform>) -> tensor<3x!quant.uniform>, value = dense<2> : tensor<3xi8>} : () -> tensor<3x!quant.uniform> diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/mul_v3.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/mul_v3.mlir index 2e0d76b511a..441dbd8f925 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/mul_v3.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/mul_v3.mlir @@ -5,8 +5,9 @@ func @main(tensor<3x!quant.uniform>) -> tensor<3x!quant.uniform>) -> tensor<3x!quant.uniform>, value = dense<2> : tensor<3xi8>} : () -> tensor<3x!quant.uniform> diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/nn.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/nn.mlir index ffa13532679..ec0fd07c25a 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/nn.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/nn.mlir @@ -5,8 +5,9 @@ func @main(tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> { // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { - // CHECK-NEXT: builtin_code: AVERAGE_POOL_2D, - // CHECK-NEXT: version: 1 + // CHECK-NEXT: deprecated_builtin_code: 1, + // CHECK-NEXT: version: 1, + // CHECK-NEXT: builtin_code: AVERAGE_POOL_2D // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { @@ -54,6 +55,7 @@ func @main(tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> { // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 3 // CHECK-NEXT: } ] + // CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } %0 = "tfl.average_pool_2d"(%arg0) {filter_height = 3 : i32, filter_width = 6 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 3 : i32, stride_w = 1 : i32} : (tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> loc("avgpool") diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/numeric_verify.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/numeric_verify.mlir index 4f28ad327df..60360c7ded6 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/numeric_verify.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/numeric_verify.mlir @@ -3,8 +3,9 @@ // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: CUSTOM, -// CHECK-NEXT: custom_code: "NumericVerify" +// CHECK-NEXT: deprecated_builtin_code: 32, +// CHECK-NEXT: custom_code: "NumericVerify", +// CHECK-NEXT: builtin_code: CUSTOM // 
CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { @@ -47,6 +48,7 @@ // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 3 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT:} func @main(%arg0: tensor<4xf32>, %arg1: tensor<4x!quant.uniform>) -> tensor<4xf32> { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/quantization.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/quantization.mlir index dbe10a3f90c..93581e54f10 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/quantization.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/quantization.mlir @@ -4,20 +4,25 @@ func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> { // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: QUANTIZE, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 114, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: QUANTIZE // CHECK-NEXT: }, { -// CHECK-NEXT: builtin_code: CONV_2D, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 3, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: CONV_2D // CHECK-NEXT: }, { -// CHECK-NEXT: builtin_code: RESHAPE, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 22, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: RESHAPE // CHECK-NEXT: }, { -// CHECK-NEXT: builtin_code: SOFTMAX, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 25, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: SOFTMAX // CHECK-NEXT: }, { -// CHECK-NEXT: builtin_code: DEQUANTIZE, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 6, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: DEQUANTIZE // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { @@ -160,6 +165,7 @@ func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> { // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 10 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT:} %0 = "tfl.pseudo_const" () {value = dense<[1, 1001]> : tensor<2xi32>} : () -> tensor<2xi32> loc("Const") diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/reshape.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/reshape.mlir index 15defbc3957..af59475f6a1 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/reshape.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/reshape.mlir @@ -5,8 +5,9 @@ func @main(tensor<3x2xi32>) -> tensor<6xi32> { // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: RESHAPE, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 22, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: RESHAPE // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { @@ -58,6 +59,7 @@ func @main(tensor<3x2xi32>) -> tensor<6xi32> { // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 4 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } %0 = "tfl.pseudo_const" () {value = dense<[6]> : tensor<1xi32>} : () -> tensor<1xi32> loc("Const") diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/signature_def.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/signature_def.mlir new file mode 100644 index 00000000000..b9866b4696d --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/signature_def.mlir @@ -0,0 +1,117 @@ +// RUN: 
flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s + +// CHECK: { +// CHECK-NEXT: version: 3, +// CHECK-NEXT: operator_codes: [ { +// CHECK-NEXT: deprecated_builtin_code: 9, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: FULLY_CONNECTED +// CHECK-NEXT: } ], +// CHECK-NEXT: subgraphs: [ { +// CHECK-NEXT: tensors: [ { +// CHECK-NEXT: shape: [ 1, 384 ], +// CHECK-NEXT: buffer: 1, +// CHECK-NEXT: name: "serving_default_input2:0", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: }, +// CHECK-NEXT: shape_signature: [ -1, 384 ] +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 1, 384 ], +// CHECK-NEXT: buffer: 2, +// CHECK-NEXT: name: "serving_default_input1:0", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: }, +// CHECK-NEXT: shape_signature: [ -1, 384 ] +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 5 ], +// CHECK-NEXT: buffer: 3, +// CHECK-NEXT: name: "std.constant", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 5, 384 ], +// CHECK-NEXT: buffer: 4, +// CHECK-NEXT: name: "std.constant1", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 5, 384 ], +// CHECK-NEXT: buffer: 5, +// CHECK-NEXT: name: "std.constant2", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 1, 5 ], +// CHECK-NEXT: buffer: 6, +// CHECK-NEXT: name: "StatefulPartitionedCall:0", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: }, +// CHECK-NEXT: shape_signature: [ -1, 5 ] +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 1, 5 ], +// CHECK-NEXT: buffer: 7, +// CHECK-NEXT: name: "StatefulPartitionedCall:1", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: }, +// CHECK-NEXT: shape_signature: [ -1, 5 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: inputs: [ 0, 1 ], +// CHECK-NEXT: outputs: [ 6, 5 ], +// CHECK-NEXT: operators: [ { +// CHECK-NEXT: inputs: [ 0, 3, 2 ], +// CHECK-NEXT: outputs: [ 5 ], +// CHECK-NEXT: builtin_options_type: FullyConnectedOptions, +// CHECK-NEXT: builtin_options: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: inputs: [ 0, 4, 2 ], +// CHECK-NEXT: outputs: [ 6 ], +// CHECK-NEXT: builtin_options_type: FullyConnectedOptions, +// CHECK-NEXT: builtin_options: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: } ], +// CHECK-NEXT: name: "main" +// CHECK-NEXT: } ], +// CHECK-NEXT: description: "MLIR Converted.", + +// CHECK: metadata: [ { +// CHECK-NEXT: name: "min_runtime_version", +// CHECK-NEXT: buffer: 8 +// CHECK-NEXT: } ], +// CHECK-NEXT: signature_defs: [ { +// CHECK-NEXT: inputs: [ { +// CHECK-NEXT: name: "input1", +// CHECK-NEXT: tensor_index: 1 +// CHECK-NEXT: }, { +// CHECK-NEXT: name: "input2" +// CHECK-NEXT: } ], +// CHECK-NEXT: outputs: [ { +// CHECK-NEXT: name: "end_logits", +// CHECK-NEXT: tensor_index: 5 +// CHECK-NEXT: }, { +// CHECK-NEXT: name: "start_logits", +// CHECK-NEXT: tensor_index: 6 +// CHECK-NEXT: } ], +// CHECK-NEXT: method_name: "serving_default", +// CHECK-NEXT: key: "" +// CHECK-NEXT: } ] +// CHECK-NEXT:} +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 554 : i32}, tf_saved_model.semantics} { + func @main(%arg0: tensor {tf_saved_model.index_path = ["input2"]}, %arg1: tensor {tf_saved_model.index_path = ["input1"]}) -> (tensor {tf_saved_model.index_path = ["start_logits"]}, tensor {tf_saved_model.index_path = 
["end_logits"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input2:0,serving_default_input1:0", outputs = "StatefulPartitionedCall:1,StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = constant dense<0.000000e+00> : tensor<5xf32> + %cst_0 = constant dense<1.0> : tensor<5x384xf32> + %cst_1 = constant dense<1.0> : tensor<5x384xf32> + %0 = "tfl.fully_connected"(%arg0, %cst_0, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor, tensor<5x384xf32>, tensor<5xf32>) -> tensor + %1 = "tfl.fully_connected"(%arg0, %cst_1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor, tensor<5x384xf32>, tensor<5xf32>) -> tensor + return %1, %0 : tensor, tensor + } +} diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/simple.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/simple.mlir index 2182db1d39e..fd0e0386c46 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/simple.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/simple.mlir @@ -7,8 +7,9 @@ func @main(tensor<3x2xi32>) -> tensor<3x2xi32> // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: SUB, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 41, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: SUB // CHECK-NEXT: }, { // CHECK-NEXT: version: 1 // CHECK-NEXT: } ], @@ -104,6 +105,7 @@ func @main(tensor<3x2xi32>) -> tensor<3x2xi32> // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 6 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } %0 = "tfl.pseudo_const" () {value = dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> loc("Const") diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf.mlir index 3d29823c93c..3f48facd122 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf.mlir @@ -4,8 +4,9 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) - // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: SVDF, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 27, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: SVDF // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { @@ -86,6 +87,7 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) - // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 7 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } // CHECK-EMPTY: diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf_v2.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf_v2.mlir index 8dfa68798b8..1b6b66ed087 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf_v2.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf_v2.mlir @@ -4,8 +4,9 @@ func @main(tensor<4 x f32>, tensor<4 x i8>, tensor<4 x f32>, tensor<4 x f32>) -> // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: SVDF, -// CHECK-NEXT: version: 2 +// CHECK-NEXT: deprecated_builtin_code: 27, +// CHECK-NEXT: version: 2, +// CHECK-NEXT: builtin_code: SVDF // CHECK-NEXT: } ], // CHECK-NEXT: 
subgraphs: [ { // CHECK-NEXT: tensors: [ { @@ -87,6 +88,7 @@ func @main(tensor<4 x f32>, tensor<4 x i8>, tensor<4 x f32>, tensor<4 x f32>) -> // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 7 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } // CHECK-EMPTY: diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/tfl_while_op.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/tfl_while_op.mlir index 996543cc9c7..68b21765717 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/tfl_while_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/tfl_while_op.mlir @@ -3,14 +3,17 @@ // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: WHILE, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 119, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: WHILE // CHECK-NEXT: }, { -// CHECK-NEXT: builtin_code: GREATER, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 61, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: GREATER // CHECK-NEXT: }, { -// CHECK-NEXT: builtin_code: SUB, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 41, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: SUB // CHECK-NEXT: }, { // CHECK-NEXT: version: 1 // CHECK-NEXT: } ], @@ -196,6 +199,7 @@ // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 14 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } func @WhileOp_cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/transpose_conv_optional.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/transpose_conv_optional.mlir index ca335ebd000..1256dd19036 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/transpose_conv_optional.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/transpose_conv_optional.mlir @@ -4,8 +4,9 @@ func @main(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: TRANSPOSE_CONV, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 67, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: TRANSPOSE_CONV // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { @@ -69,6 +70,7 @@ func @main(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 5 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT:} %cst = constant unit diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/type_attr.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/type_attr.mlir index 01410d370d4..690331dec84 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/type_attr.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/type_attr.mlir @@ -3,8 +3,9 @@ // CHECK: { // CHECK: version: 3, // CHECK: operator_codes: [ { -// CHECK: builtin_code: CUSTOM, -// CHECK: custom_code: "SomeOperation" +// CHECK: deprecated_builtin_code: 32, +// CHECK: custom_code: "SomeOperation", +// CHECK: builtin_code: CUSTOM // CHECK: } ], // CHECK: subgraphs: [ { // CHECK: tensors: [ { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_lstm.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_lstm.mlir index 
9b0315e1e20..ffb5b2c781b 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_lstm.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_lstm.mlir @@ -4,8 +4,9 @@ func @main(tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: UNIDIRECTIONAL_SEQUENCE_LSTM, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 44, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: UNIDIRECTIONAL_SEQUENCE_LSTM // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { @@ -256,6 +257,7 @@ func @main(tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 26 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } // CHECK-EMPTY: diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_rnn.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_rnn.mlir index 67349b857f7..5b29c1ff050 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_rnn.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_rnn.mlir @@ -4,8 +4,9 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) - // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: UNIDIRECTIONAL_SEQUENCE_RNN, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 35, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: UNIDIRECTIONAL_SEQUENCE_RNN // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { @@ -86,6 +87,7 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) - // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 7 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } // CHECK-EMPTY: diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/while_op.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/while_op.mlir index d69e8f40311..51935676eed 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/while_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/while_op.mlir @@ -3,14 +3,17 @@ // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: WHILE, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 119, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: WHILE // CHECK-NEXT: }, { -// CHECK-NEXT: builtin_code: GREATER, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 61, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: GREATER // CHECK-NEXT: }, { -// CHECK-NEXT: builtin_code: SUB, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 41, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: SUB // CHECK-NEXT: }, { // CHECK-NEXT: version: 1 // CHECK-NEXT: } ], @@ -196,6 +199,7 @@ // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 14 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } func @main(%arg0: tensor, %arg1: tensor<1xf32>) -> tensor<1xf32> { diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir index cbb562c2e03..b62f5655183 100644 --- a/tensorflow/compiler/mlir/lite/tests/ops.mlir +++ 
b/tensorflow/compiler/mlir/lite/tests/ops.mlir @@ -1700,6 +1700,15 @@ func @testRelu6WithQuantizedTypes(%arg0 : tensor<10x!quant.uniform> // ----- +func @testReluWithDifferentScales(%arg0 : tensor<10x!quant.uniform>) -> tensor<10x!quant.uniform> { + %0 = "tfl.relu"(%arg0) : (tensor<10x!quant.uniform>) -> tensor<10x!quant.uniform> + %1 = "tfl.relu_n1_to_1"(%0) : (tensor<10x!quant.uniform>) -> tensor<10x!quant.uniform> + %2 = "tfl.relu6"(%1) : (tensor<10x!quant.uniform>) -> tensor<10x!quant.uniform> + return %2 : tensor<10x!quant.uniform> +} + +// ----- + func @testEmbeddingLookup(%arg0 : tensor, %arg1 : tensor) -> tensor { %0 = "tfl.embedding_lookup"(%arg0, %arg1) : (tensor,tensor) -> tensor return %0 : tensor diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index edbcef3d321..bedf77f726a 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -26,6 +26,26 @@ func @fusedDepthwiseConv2dRelu6(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16 // CHECK: return %0 } +// CHECK-LABEL: fusedMaxPool2dRelu +func @fusedMaxPool2dRelu(%arg0: tensor<1x147x147x16xf32>) -> tensor<1x73x73x16xf32> { + %0 = "tfl.max_pool_2d"(%arg0) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x147x147x16xf32>) -> tensor<1x73x73x16xf32> + %1 = "tfl.relu"(%0) : (tensor<1x73x73x16xf32>) -> tensor<1x73x73x16xf32> + return %1 : tensor<1x73x73x16xf32> + + // CHECK: %0 = "tfl.max_pool_2d"(%arg0) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "RELU", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x147x147x16xf32>) -> tensor<1x73x73x16xf32> + // CHECK: return %0 +} + +// CHECK-LABEL: fusedAvgPool2dRelu1 +func @fusedAvgPool2dRelu1(%arg0: tensor<1x147x147x16xf32>) -> tensor<1x73x73x16xf32> { + %0 = "tfl.average_pool_2d"(%arg0) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x147x147x16xf32>) -> tensor<1x73x73x16xf32> + %1 = "tfl.relu_n1_to_1"(%0) : (tensor<1x73x73x16xf32>) -> tensor<1x73x73x16xf32> + return %1 : tensor<1x73x73x16xf32> + + // CHECK: %0 = "tfl.average_pool_2d"(%arg0) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "RELU_N1_TO_1", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x147x147x16xf32>) -> tensor<1x73x73x16xf32> + // CHECK: return %0 +} + // CHECK-LABEL: fuseAddIntoConv2d func @fuseAddIntoConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> tensor<256x30x30x16xf32> { %cst = constant dense<1.5> : tensor<16xf32> @@ -50,6 +70,96 @@ func @fuseSubIntoConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf // CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %cst) } +// CHECK-LABEL: fuseAddIntoTransposeConv +func @fuseAddIntoTransposeConv(%arg0: tensor<1x32x42x128xf32>) -> tensor<1x64x84x32xf32> { + %cst = constant dense<1.5> : tensor<32xf32> + %cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32> + %cst_1 = constant dense<[1, 64, 84, 32]> : tensor<4xi32> + %cst_2 = constant dense<1.0> : tensor<32x4x4x128xf32> + %cst_3 = constant dense<[1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 
2.0, 1.0, 2.0]> : tensor<32xf32> + %0 = "tfl.transpose_conv"(%cst_1, %cst_2, %arg0, %cst_3) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<32xf32>) -> tensor<1x64x84x32xf32> + %1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<1x64x84x32xf32>, tensor<32xf32>) -> tensor<1x64x84x32xf32> + return %1 : tensor<1x64x84x32xf32> + + // CHECK: %[[SHAPE:.*]] = constant dense<[1, 64, 84, 32]> : tensor<4xi32> + // CHECK: %[[WEIGHTS:.*]] = constant dense<1.000000e+00> : tensor<32x4x4x128xf32> + // CHECK: %[[BIAS:.*]] = constant dense<[2.500000e+00, 3.500000e+00, 2.500000e+00, 3.500000e+00, 2.500000e+00, 3.500000e+00, 2.500000e+00, 3.500000e+00, 2.500000e+00, 3.500000e+00, 2.500000e+00, 3.500000e+00, 2.500000e+00, 3.500000e+00, 2.500000e+00, 3.500000e+00, 2.500000e+00, 3.500000e+00, 2.500000e+00, 3.500000e+00, 2.500000e+00, 3.500000e+00, 2.500000e+00, 3.500000e+00, 2.500000e+00, 3.500000e+00, 2.500000e+00, 3.500000e+00, 2.500000e+00, 3.500000e+00, 2.500000e+00, 3.500000e+00]> : tensor<32xf32> + // CHECK: %[[RESULT:.*]] = "tfl.transpose_conv"(%[[SHAPE]], %[[WEIGHTS]], %arg0, %[[BIAS]]) + // CHECK: return %[[RESULT]] +} + +// CHECK-LABEL: fuseSubIntoTransposeConv +func @fuseSubIntoTransposeConv(%arg0: tensor<1x32x42x128xf32>) -> tensor<1x64x84x32xf32> { + %cst = constant dense<1.5> : tensor<32xf32> + %cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32> + %cst_1 = constant dense<[1, 64, 84, 32]> : tensor<4xi32> + %cst_2 = constant dense<1.0> : tensor<32x4x4x128xf32> + %cst_3 = constant dense<[1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0]> : tensor<32xf32> + %0 = "tfl.transpose_conv"(%cst_1, %cst_2, %arg0, %cst_3) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<32xf32>) -> tensor<1x64x84x32xf32> + %1 = "tfl.sub"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<1x64x84x32xf32>, tensor<32xf32>) -> tensor<1x64x84x32xf32> + return %1 : tensor<1x64x84x32xf32> + + // CHECK: %[[SHAPE:.*]] = constant dense<[1, 64, 84, 32]> : tensor<4xi32> + // CHECK: %[[WEIGHTS:.*]] = constant dense<1.000000e+00> : tensor<32x4x4x128xf32> + // CHECK: %[[BIAS:.*]] = constant dense<[-5.000000e-01, 5.000000e-01, -5.000000e-01, 5.000000e-01, -5.000000e-01, 5.000000e-01, -5.000000e-01, 5.000000e-01, -5.000000e-01, 5.000000e-01, -5.000000e-01, 5.000000e-01, -5.000000e-01, 5.000000e-01, -5.000000e-01, 5.000000e-01, -5.000000e-01, 5.000000e-01, -5.000000e-01, 5.000000e-01, -5.000000e-01, 5.000000e-01, -5.000000e-01, 5.000000e-01, -5.000000e-01, 5.000000e-01, -5.000000e-01, 5.000000e-01, -5.000000e-01, 5.000000e-01, -5.000000e-01, 5.000000e-01]> : tensor<32xf32> + // CHECK: %[[RESULT:.*]] = "tfl.transpose_conv"(%[[SHAPE]], %[[WEIGHTS]], %arg0, %[[BIAS]]) + // CHECK: return %[[RESULT]] +} + +// CHECK-LABEL: fuseAddIntoTransposeConvNoBias +func @fuseAddIntoTransposeConvNoBias(%arg0: tensor<1x32x42x128xf32>) -> tensor<1x64x84x32xf32> { + %cst = constant dense<1.5> : tensor<32xf32> + %cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32> + %cst_1 = constant dense<[1, 64, 84, 32]> : tensor<4xi32> + %cst_2 = constant dense<1.0> : tensor<32x4x4x128xf32> + %cst_3 = constant unit + %0 = 
"tfl.transpose_conv"(%cst_1, %cst_2, %arg0, %cst_3) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<1x64x84x32xf32> + %1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<1x64x84x32xf32>, tensor<32xf32>) -> tensor<1x64x84x32xf32> + return %1 : tensor<1x64x84x32xf32> + + // CHECK: %[[SHAPE:.*]] = constant dense<[1, 64, 84, 32]> : tensor<4xi32> + // CHECK: %[[WEIGHTS:.*]] = constant dense<1.000000e+00> : tensor<32x4x4x128xf32> + // CHECK: %[[BIAS:.*]] = constant dense<1.500000e+00> : tensor<32xf32> + // CHECK: %[[RESULT:.*]] = "tfl.transpose_conv"(%[[SHAPE]], %[[WEIGHTS]], %arg0, %[[BIAS]]) + // CHECK: return %[[RESULT]] +} + +// CHECK-LABEL: fuseMulIntoTransposeConv +func @fuseMulIntoTransposeConv(%arg0: tensor<1x32x42x128xf32>) -> tensor<1x64x84x32xf32> { + %cst = constant dense<1.5> : tensor<32xf32> + %cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32> + %cst_1 = constant dense<[1, 64, 84, 32]> : tensor<4xi32> + %cst_2 = constant dense<1.0> : tensor<32x4x4x128xf32> + %cst_3 = constant dense<[1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0]> : tensor<32xf32> + %0 = "tfl.transpose_conv"(%cst_1, %cst_2, %arg0, %cst_3) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<32xf32>) -> tensor<1x64x84x32xf32> + %1 = "tfl.mul"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<1x64x84x32xf32>, tensor<32xf32>) -> tensor<1x64x84x32xf32> + return %1 : tensor<1x64x84x32xf32> + + // CHECK: %[[SHAPE:.*]] = constant dense<[1, 64, 84, 32]> : tensor<4xi32> + // CHECK: %[[WEIGHTS:.*]] = constant dense<1.500000e+00> : tensor<32x4x4x128xf32> + // CHECK: %[[BIAS:.*]] = constant dense<[1.500000e+00, 3.000000e+00, 1.500000e+00, 3.000000e+00, 1.500000e+00, 3.000000e+00, 1.500000e+00, 3.000000e+00, 1.500000e+00, 3.000000e+00, 1.500000e+00, 3.000000e+00, 1.500000e+00, 3.000000e+00, 1.500000e+00, 3.000000e+00, 1.500000e+00, 3.000000e+00, 1.500000e+00, 3.000000e+00, 1.500000e+00, 3.000000e+00, 1.500000e+00, 3.000000e+00, 1.500000e+00, 3.000000e+00, 1.500000e+00, 3.000000e+00, 1.500000e+00, 3.000000e+00, 1.500000e+00, 3.000000e+00]> : tensor<32xf32> + // CHECK: %[[RESULT:.*]] = "tfl.transpose_conv"(%[[SHAPE]], %[[WEIGHTS]], %arg0, %[[BIAS]]) + // CHECK: return %[[RESULT]] +} + +// CHECK-LABEL: fuseMulIntoTransposeConvNoBias +func @fuseMulIntoTransposeConvNoBias(%arg0: tensor<1x32x42x128xf32>) -> tensor<1x64x84x32xf32> { + %cst = constant dense<1.5> : tensor<32xf32> + %cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32> + %cst_1 = constant dense<[1, 64, 84, 32]> : tensor<4xi32> + %cst_2 = constant dense<1.0> : tensor<32x4x4x128xf32> + %cst_3 = constant unit + %0 = "tfl.transpose_conv"(%cst_1, %cst_2, %arg0, %cst_3) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<1x64x84x32xf32> + %1 = "tfl.mul"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<1x64x84x32xf32>, tensor<32xf32>) -> tensor<1x64x84x32xf32> + return %1 : tensor<1x64x84x32xf32> + + // CHECK: %[[SHAPE:.*]] = constant dense<[1, 64, 84, 32]> : tensor<4xi32> + // CHECK: %[[WEIGHTS:.*]] = constant 
dense<1.500000e+00> : tensor<32x4x4x128xf32> + // CHECK: %[[BIAS:.*]] = constant unit + // CHECK: %[[RESULT:.*]] = "tfl.transpose_conv"(%[[SHAPE]], %[[WEIGHTS]], %arg0, %[[BIAS]]) + // CHECK: return %[[RESULT]] +} + // CHECK-LABEL: fuseAddIntoFollowingConv2d func @fuseAddIntoFollowingConv2d(%arg0: tensor<256x32x32x3xf32>) -> tensor<256x30x30x16xf32> { %cst = constant dense<1.5> : tensor @@ -182,6 +292,22 @@ func @fuseMulIntoFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> { // CHECK: return %[[RES]] : tensor<4x2xf32> } +// CHECK-LABEL: @fuseBroadcastMulIntoFullyConnected +func @fuseBroadcastMulIntoFullyConnected(%arg0: tensor<1x10368xbf16>) -> tensor<32x1x256xbf16> { + %cst_0 = constant dense<2.0> : tensor<256x10368xbf16> + %cst_1 = constant unit + %cst_2 = constant dense<3.0> : tensor<32x1x256xbf16> + %0 = "tfl.fully_connected"(%arg0, %cst_0, %cst_1) { + fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT" + } : (tensor<1x10368xbf16>, tensor<256x10368xbf16>, none) -> tensor<1x256xbf16> + %1 = "tfl.mul"(%0, %cst_2) {fused_activation_function = "NONE"} : (tensor<1x256xbf16>, tensor<32x1x256xbf16>) -> tensor<32x1x256xbf16> + return %1 : tensor<32x1x256xbf16> + +// CHECK: %[[V0:.*]] = "tfl.fully_connected"(%arg0, {{.*}}) {{{.*}}} : (tensor<1x10368xbf16>, tensor<256x10368xbf16>, none) -> tensor<1x256xbf16> +// CHECK: %[[V1:.*]] = "tfl.mul"(%[[V0]], {{.*}}) {{{.*}}} : (tensor<1x256xbf16>, tensor<32x1x256xbf16>) -> tensor<32x1x256xbf16> +// CHECK: return %[[V1]] : tensor<32x1x256xbf16> +} + // CHECK-LABEL: @fuseAddIntoFollowingFullyConnectedWithQDQs func @fuseAddIntoFollowingFullyConnectedWithQDQs(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> { @@ -865,6 +991,16 @@ func @Relu(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // CHECK: return %[[RESULT]] } +// CHECK-LABEL: Relu_bf16 +func @Relu_bf16(%arg0: tensor<2x3xbf16>) -> tensor<2x3xbf16> { + %cst = constant dense<0.0> : tensor<2x3xbf16> + %0 = "tfl.maximum"(%arg0, %cst) : (tensor<2x3xbf16>, tensor<2x3xbf16>) -> tensor<2x3xbf16> + return %0 : tensor<2x3xbf16> + + // CHECK: %[[RESULT:.*]] = "tfl.relu"(%arg0) + // CHECK: return %[[RESULT]] +} + // CHECK-LABEL: Relu1 func @Relu1(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { %cst = constant dense<-1.0> : tensor @@ -1175,3 +1311,29 @@ func @FoldReduceProdKeepDim(%arg0: tensor<8x128xf32>) -> tensor<1x1xf32> { // CHECK: %[[RESULT:.*]] = "tfl.reduce_prod"(%arg0, %cst) {keep_dims = true} : (tensor<8x128xf32>, tensor<2xi32>) -> tensor<1x1xf32> // CHECK: return %[[RESULT]] : tensor<1x1xf32> } + +func @SoftMaxWithNormalization(%arg0: tensor<8x128xf32>) -> tensor<8x128xf32> { + %cst = constant dense<1> : tensor<1xi32> + %0 = "tfl.reduce_max"(%arg0, %cst) {keep_dims = true} : (tensor<8x128xf32>, tensor<1xi32>) -> tensor<8x1xf32> + %1 = "tfl.sub"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<8x128xf32>, tensor<8x1xf32>) -> tensor<8x128xf32> + %2 = "tfl.exp"(%1) : (tensor<8x128xf32>) -> tensor<8x128xf32> + %3 = "tfl.sum"(%2, %cst) {keep_dims = true} : (tensor<8x128xf32>, tensor<1xi32>) -> tensor<8x1xf32> + %4 = "tfl.div"(%2, %3) {fused_activation_function = "NONE"} : (tensor<8x128xf32>, tensor<8x1xf32>) -> tensor<8x128xf32> + return %4 : tensor<8x128xf32> + +// CHECK-LABEL: SoftMaxWithNormalization +// CHECK: %[[RESULT:.*]] = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<8x128xf32>) -> tensor<8x128xf32> +// CHECK: return %[[RESULT]] : tensor<8x128xf32> +} + +func @SoftMaxWithoutNormalization(%arg0: tensor<8x128xf32>) -> tensor<8x128xf32> { + %cst = 
constant dense<1> : tensor<1xi32> + %0 = "tfl.exp"(%arg0) : (tensor<8x128xf32>) -> tensor<8x128xf32> + %1 = "tfl.sum"(%0, %cst) {keep_dims = true} : (tensor<8x128xf32>, tensor<1xi32>) -> tensor<8x1xf32> + %2 = "tfl.div"(%0, %1) {fused_activation_function = "NONE"} : (tensor<8x128xf32>, tensor<8x1xf32>) -> tensor<8x128xf32> + return %2 : tensor<8x128xf32> + +// CHECK-LABEL: SoftMaxWithoutNormalization +// CHECK: %[[RESULT:.*]] = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<8x128xf32>) -> tensor<8x128xf32> +// CHECK: return %[[RESULT]] : tensor<8x128xf32> +} diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir index 9e8a957b34c..2b871769c81 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir @@ -520,3 +520,42 @@ func @func_with_call(%arg0: tensor<100xf32>) -> tensor<100xf32> { return %0 : tensor<100xf32> } } + +// ----- + +module { +func @tflite_custom_nms(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>, %arg2: tensor<100x4xf32>) -> (tensor, tensor, tensor, tensor) attributes {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, max_classes_per_detection = 1 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"} { + %0 = "tf.Const"() {value = dense<0.0> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<0.0> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<0.0> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<0.0> : tensor} : () -> tensor + return %0, %1, %2, %3 : tensor, tensor, tensor, tensor +} + +// CHECK-LABEL: func @tflite_custom_nms( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x100x4xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x100x91xf32>, +// CHECK-SAME: %[[VAL_2:.*]]: tensor<100x4xf32>) -> (tensor, tensor, tensor, tensor) attributes {tf._implements = "TFLite_Detection_PostProcess", tf._reference = "mlir"} { +// CHECK: %[[VAL_3:.*]]:4 = "tfl.custom"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {custom_code = "TFLite_Detection_PostProcess", custom_option = opaque<"tfl", "0x6D61785F646574656374696F6E73006D61785F636C61737365735F7065725F646574656374696F6E006E756D5F636C6173736573006E6D735F73636F72655F7468726573686F6C64006E6D735F696F755F7468726573686F6C6400795F7363616C6500785F7363616C6500685F7363616C6500775F7363616C65007573655F726567756C61725F6E6D73000A217E8E465B681720313A00000C000000010000000A0000000000803F010000000A0000009A99193F0000003F5B0000000000000000000040000020410000A0400E06060E0E06060E0E0E322601"> : tensor<217xi8>} : (tensor<1x100x4xf32>, tensor<1x100x91xf32>, tensor<100x4xf32>) -> (tensor, tensor, tensor, tensor) +// CHECK: return %[[VAL_3]]#0, %[[VAL_3]]#1, %[[VAL_3]]#2, %[[VAL_3]]#3 : tensor, tensor, tensor, tensor +// CHECK: } +} + +// ----- + +module { +// expected-error @+1 {{Invalid number of results from TFLite_Detection_PostProcess}} +func @tflite_custom_nms_invalid_results(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>, %arg2: tensor<100x4xf32>) -> (tensor, tensor, tensor) attributes {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, max_classes_per_detection = 1 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 
5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"} + +// expected-error @+1 {{Invalid number of arguments to TFLite_Detection_PostProcess}} +func @tflite_custom_nms_invalid_args(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>) -> (tensor, tensor, tensor, tensor) attributes {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, max_classes_per_detection = 1 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"} + +// expected-error @+1 {{max_classes_per_detection attribute is not set or not an integer}} +func @tflite_custom_nms_missing_func_args(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>, %arg2: tensor<100x4xf32>) -> (tensor, tensor, tensor, tensor) attributes {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"} { + %0 = "tf.Const"() {value = dense<0.0> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<0.0> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<0.0> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<0.0> : tensor} : () -> tensor + return %0, %1, %2, %3 : tensor, tensor, tensor, tensor +} +} diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir index 6a992d6dfe4..186c8631e56 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir @@ -1,4 +1,5 @@ // RUN: tf-opt -tfl-prepare-tf %s | FileCheck %s +// RUN: tf-opt %s -tf-layout-optimization=force-data-format=NHWC -tfl-prepare-tf | FileCheck --check-prefix=LAYOUT --dump-input=always %s module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { @@ -53,37 +54,12 @@ func @depthwiseConv2D(tensor<256x32x32x3xf32>, tensor<3x3x3x4xf32>, tensor<256x3 // CHECK: %5 = "tf.DepthwiseConv2dNative" } -func @fusedBatchNorm(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>) { -^bb0(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>): - // OK - %0:5 = "tf.FusedBatchNorm"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - // Unsupported training - %1:5 = "tf.FusedBatchNorm"( %0#0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - // Use other output - %2:5 = "tf.FusedBatchNorm"( %1#0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> 
(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) +func @Conv2dNCHW(%arg0: tensor<256x3x32x32xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x16x30x30xf32> { + %0 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<256x3x32x32xf32>, tensor<3x3x3x16xf32>) -> tensor<256x16x30x30xf32> + return %0 : tensor<256x16x30x30xf32> - return %2, %2#1 : tensor<8x8x8x8xf32>, tensor<8xf32> - -// CHECK-LABEL: fusedBatchNorm -// CHECK: %[[CONSTANT:.*]] = constant dense<1.000000e-03> -// variance + epsilon -// CHECK: %[[ADD1:.*]] = "tf.Add"(%[[ARG4:.*]], %[[CONSTANT]]) -// rsqrt(variance + epsilon) -// CHECK: %[[RSQRT:.*]] = "tf.Rsqrt"(%[[ADD1]]) -// scale * rsqrt(variance + epsilon) -// CHECK: %[[MUL1:.*]] = "tf.Mul"(%[[ARG1:.*]], %[[RSQRT]]) -// x * scale * rsqrt(variance + epsilon) -// CHECK: %[[MUL2:.*]] = "tf.Mul"(%[[ARG0:.*]], %[[MUL1]]) -// mean * scale * rsqrt(variance + epsilon) -// CHECK: %[[MUL3:.*]] = "tf.Mul"(%[[ARG3:.*]], %[[MUL1]]) -// offset - mean * scale * rsqrt(variance + epsilon) -// CHECK: %[[SUB:.*]] = "tf.Sub"(%[[ARG2:.*]], %[[MUL3]]) -// x * scale * rsqrt(variance + epsilon) + -// offset - mean * scale * rsqrt(variance + epsilon) -// CHECK: %[[ADD2:.*]] = "tf.Add"(%[[MUL2]], %[[SUB]]) - -// CHECK: %[[BATCHNORM1_a:[^,]+]], {{.*}} = "tf.FusedBatchNorm"(%[[ADD2]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]) -// CHECK: "tf.FusedBatchNorm"(%[[BATCHNORM1_a]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]) + // LAYOUT-LABEL: Conv2dNCHW + // LAYOUT: "tfl.conv_2d" } func @fusedBatchNormV3(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>) { @@ -483,6 +459,20 @@ func @StridedSliceEllipsisMaskBefore(%arg0: tensor<21x15x7xf32>) -> tensor<21x15 // CHECK: %[[STRIDED_SLICE:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST]], %[[CST_0]]) {begin_mask = 3 : i64, ellipsis_mask = 0 : i64, end_mask = 3 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<21x15x7xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<21x15x2xf32> } +// CHECK-LABEL: @StridedSliceEllipsisMaskBeforeWithBeginAndEndMask +func @StridedSliceEllipsisMaskBeforeWithBeginAndEndMask(%arg0: tensor<4x5x4xf32>) -> tensor<4x4x4xf32> { + %cst = constant dense<[0, 1, 0]> : tensor<3xi32> + %cst_0 = constant dense<0> : tensor<3xi32> + %cst_1 = constant dense<1> : tensor<3xi32> + %0 = "tf.StridedSlice"(%arg0, %cst, %cst_0, %cst_1) {begin_mask = 6 : i64, ellipsis_mask = 1 : i64, end_mask = 4 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4x5x4xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<4x4x4xf32> + return %0 : tensor<4x4x4xf32> + + // CHECK: %[[CST:.*]] = constant dense<[0, 1, 0]> : tensor<3xi32> + // CHECK: %[[CST_0:.*]] = constant dense<0> : tensor<3xi32> + // CHECK: %[[CST_1:.*]] = constant dense<1> : tensor<3xi32> + // CHECK: %[[STRIDED_SLICE:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST_0]], %[[CST_1]]) {begin_mask = 7 : i64, ellipsis_mask = 0 : i64, end_mask = 5 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4x5x4xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<4x4x4xf32> +} + // CHECK-LABEL: @StridedSliceEllipsisMaskAfter func @StridedSliceEllipsisMaskAfter(%arg0: tensor<21x15x7xf32>) -> tensor<5x15x7xf32> { %cst = constant dense<0> : tensor<2xi32> @@ -629,4 +619,24 @@ func @lower_rfft_to_rfft2d(%input: tensor<10x20x30xf32>, %fft_len: 
tensor<1xi32> // CHECK: %[[SQE:.*]] = "tf.Squeeze"(%[[RFF]]) {squeeze_dims = [-2]} : (tensor<10x20x1x30xcomplex>) -> tensor<10x20x30xcomplex> } +// CHECK-LABEL: xla_gather_to_slice +func @xla_gather_to_slice(%arg0 : tensor<1x9x104x768xf32>) -> tensor<*xf32> { + %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<[1, 9, 23, 768]> : tensor<4xi32>} : () -> tensor<4xi32> + %2 = "tf.XlaGather"(%arg0, %0, %1) {device = "", dimension_numbers = "\0A\04\00\01\02\03\1A\01\02", indices_are_sorted = false} : (tensor<1x9x104x768xf32>, tensor<1xi32>, tensor<4xi32>) -> tensor<*xf32> + return %2 : tensor<*xf32> + +// CHECK: %[[CST:.*]] = constant dense<0> : tensor<4xi64> +// CHECK: %[[CST0:.*]] = constant dense<[1, 9, 23, 768]> : tensor<4xi64> +// CHECK: %[[V0:.*]] = "tf.Slice"(%arg0, %[[CST]], %[[CST0]]) : (tensor<1x9x104x768xf32>, tensor<4xi64>, tensor<4xi64>) -> tensor<*xf32> +// CHECK: return %[[V0]] : tensor<*xf32> +} + +// CHECK-LABEL: DontMatchFusedBatchNormV3 +func @DontMatchFusedBatchNormV3(%arg0 :tensor, %arg1 : tensor<576xf32>, %arg2 : tensor<576xf32>, %arg3 : tensor<576xf32>,%arg4 : tensor<576xf32>) -> (tensor) { + %result:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {data_format = "NHWC", device = "", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = false} : (tensor, tensor<576xf32>, tensor<576xf32>, tensor<576xf32>, tensor<576xf32>) -> (tensor, tensor<576xf32>, tensor<576xf32>, tensor<576xf32>, tensor<576xf32>, tensor<*xf32>) + return %result : tensor + // CHECK: "tf.FusedBatchNormV3" +} + } diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index d63eb481376..2feb7fedb81 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -38,6 +38,10 @@ CreateTFExecutorToControlDialectConversion(); } // namespace mlir namespace tensorflow { +namespace { +// Data layout supported by TFLite. +const char kTFLiteDataLayout[] = "NHWC"; +} // namespace void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs, mlir::OpPassManager* pass_manager) { @@ -170,6 +174,12 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, if (pass_config.shape_inference) { pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass()); } + // Force layout supported by TFLite, this will transpose the data + // to match 'kTFLiteDataLayout' + mlir::TF::LayoutOptimizationPipelineOptions layout_optimization_options; + layout_optimization_options.force_data_format = kTFLiteDataLayout; + mlir::TF::CreateLayoutOptimizationPipeline(*pass_manager, + layout_optimization_options); // Prepare for TFLite dialect, rerun canonicalization, and then legalize to // the TFLite dialect. 
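Note on the tf_tfl_passes.cc hunk just above: the converter pipeline now forces the data layout TFLite expects before any TFL lowering runs, which is also why prepare-tf.mlir gains a second RUN line with -tf-layout-optimization=force-data-format=NHWC and the new Conv2dNCHW test is checked under the LAYOUT prefix. A minimal standalone sketch of the same wiring, assuming the usual mlir/TF pass headers (the helper name and the include path for the TF passes are illustrative, not part of this patch):

    #include "mlir/Pass/PassManager.h"
    // Assumed location of the TF layout-optimization pipeline declarations.
    #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"

    // Force the NHWC layout expected by TFLite before the TFL lowering passes run.
    void AddForcedNhwcLayout(mlir::OpPassManager& pm) {
      mlir::TF::LayoutOptimizationPipelineOptions options;
      options.force_data_format = "NHWC";  // kTFLiteDataLayout in tf_tfl_passes.cc
      mlir::TF::CreateLayoutOptimizationPipeline(pm, options);
    }

The pipeline inserts the transposes needed to rewrite NCHW convolutions into the NHWC form that tfl.conv_2d accepts, instead of leaving them to fail legalization.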
pass_manager->addPass( diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc index 046c7bbbcf0..aa3545d9beb 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc @@ -143,6 +143,7 @@ int main(int argc, char **argv) { mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context); StatusOr module; + std::unordered_set tags; tensorflow::GraphImportConfig specs; specs.upgrade_legacy = upgrade_legacy; @@ -161,8 +162,7 @@ int main(int argc, char **argv) { module = tensorflow::errors::InvalidArgument( "Importing saved model should not have input_mlir set"); - std::unordered_set tags = - absl::StrSplit(saved_model_tags, ','); + tags = absl::StrSplit(saved_model_tags, ','); std::vector exported_names_vector = absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty()); absl::Span exported_names(exported_names_vector); @@ -171,10 +171,11 @@ int main(int argc, char **argv) { llvm::errs() << "There should be only one exported name"; return kTrFailure; } - - module = - tensorflow::ImportSavedModel(input_file_name, saved_model_version, tags, - exported_names, specs, &context); + std::vector extra_opdefs(custom_opdefs.begin(), + custom_opdefs.end()); + module = tensorflow::ImportSavedModel(input_file_name, saved_model_version, + tags, extra_opdefs, exported_names, + specs, &context); } else { module = tensorflow::LoadFromGraphdefOrMlirSource( input_file_name, input_mlir, use_splatted_constant, custom_opdefs, @@ -240,7 +241,7 @@ int main(int argc, char **argv) { std::string result; auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer( module.ValueOrDie().get(), output_mlir, emit_builtin_tflite_ops, - emit_select_tf_ops, emit_custom_ops, quant_specs, &result, &pm); + emit_select_tf_ops, emit_custom_ops, quant_specs, tags, &result, &pm); if (!status.ok()) return kTrFailure; std::string error_msg; diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index c158f3a8e21..622e32c2766 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "llvm/Support/raw_ostream.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project @@ -70,6 +71,27 @@ mlir::LogicalResult IsValidGraph(mlir::ModuleOp module) { } return mlir::success(); } + +// Util that registers 'extra_tf_opdefs' to the TF global registry. +// Return OK on success, failure if registering failed. +Status RegisterExtraTfOpDefs(absl::Span extra_tf_opdefs) { + for (const auto& tf_opdefs_string : extra_tf_opdefs) { + tensorflow::OpDef opdef; + if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string, + &opdef)) { + LOG(ERROR) << "OpDef parsing failed for: " << tf_opdefs_string; + return errors::InvalidArgument("fail to parse extra OpDef"); + } + // Register extra opdefs. + // TODO(b/133770952): Support shape functions. 
+ tensorflow::OpRegistry::Global()->Register( + [opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status { + *op_reg_data = tensorflow::OpRegistrationData(opdef); + return Status::OK(); + }); + } + return Status::OK(); +} } // namespace StatusOr LoadFromGraphdefOrMlirSource( @@ -92,21 +114,9 @@ StatusOr LoadFromGraphdefOrMlirSource( return OwningModuleRef(mlir::parseSourceFile(*source_mgr, context)); } - for (const auto& tf_opdefs_string : extra_tf_opdefs) { - tensorflow::OpDef opdef; - if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string, - &opdef)) { - LOG(ERROR) << "OpDef parsing failed for: " << tf_opdefs_string; - return errors::InvalidArgument("fail to parse extra OpDef"); - } - // Register extra opdefs. - // TODO(b/133770952): Support shape functions. - tensorflow::OpRegistry::Global()->Register( - [opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status { - *op_reg_data = tensorflow::OpRegistrationData(opdef); - return Status::OK(); - }); - } + // Register extra TF ops passed as OpDef. + auto extra_opdefs_status = RegisterExtraTfOpDefs(extra_tf_opdefs); + if (!extra_opdefs_status.ok()) return extra_opdefs_status; if (use_splatted_constant) { return tensorflow::GraphdefToSplattedMlirTranslateFunction( @@ -127,8 +137,12 @@ StatusOr LoadFromGraphdefOrMlirSource( Status ConvertTFExecutorToTFLOrFlatbuffer( mlir::ModuleOp module, bool export_to_mlir, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, - const mlir::TFL::QuantizationSpecs& quant_specs, std::string* result, - mlir::PassManager* pass_manager) { + const mlir::TFL::QuantizationSpecs& quant_specs, + const std::unordered_set& saved_model_tags, + std::string* result, mlir::PassManager* pass_manager) { + // Explicitly disable dumping Op details on failures. + module.getContext()->printOpOnDiagnostic(false); + // Register a warning handler only log to std out. mlir::ScopedDiagnosticHandler s( module.getContext(), [](mlir::Diagnostic& diag) { @@ -158,7 +172,7 @@ Status ConvertTFExecutorToTFLOrFlatbuffer( if (!quant_specs.RunWeightQuantization()) { if (tflite::MlirToFlatBufferTranslateFunction( module, result, emit_builtin_tflite_ops, emit_select_tf_ops, - emit_custom_ops)) { + emit_custom_ops, saved_model_tags)) { return statusHandler.ConsumeStatus(); } } else { @@ -167,7 +181,7 @@ Status ConvertTFExecutorToTFLOrFlatbuffer( std::string pre_quantized_result; if (tflite::MlirToFlatBufferTranslateFunction( module, &pre_quantized_result, emit_builtin_tflite_ops, - emit_select_tf_ops, emit_custom_ops)) { + emit_select_tf_ops, emit_custom_ops, saved_model_tags)) { return statusHandler.ConsumeStatus(); } flatbuffers::FlatBufferBuilder q_builder(/*initial_size=*/10240); @@ -198,8 +212,13 @@ Status ConvertTFExecutorToTFLOrFlatbuffer( StatusOr ImportSavedModel( const std::string& input_filename, const int saved_model_version, const std::unordered_set& tags, + absl::Span extra_tf_opdefs, absl::Span exported_names, const GraphImportConfig& specs, mlir::MLIRContext* context) { + // Register extra TF ops passed as OpDef. 
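The RegisterExtraTfOpDefs helper introduced in this hunk is now shared by both import paths: LoadFromGraphdefOrMlirSource calls it in place of its old inline loop, and ImportSavedModel grows an extra_tf_opdefs parameter so saved-model conversion can register custom ops too (tf_tfl_translate.cc forwards its existing custom_opdefs values). The strings it consumes are OpDef protos in text format; a hypothetical example, with the op name and signature made up for illustration:

    #include <string>
    #include <vector>

    // A custom op definition in OpDef textproto form, the kind of string that
    // RegisterExtraTfOpDefs() parses with TextFormat and registers on
    // OpRegistry::Global(). "MyCustomOp" is a made-up op.
    const char kExampleOpDef[] = R"pb(
      name: "MyCustomOp"
      input_arg { name: "x" type: DT_FLOAT }
      output_arg { name: "y" type: DT_FLOAT }
    )pb";

    // Collected into the list handed to ImportSavedModel(...) or
    // LoadFromGraphdefOrMlirSource(...) before the MLIR import runs.
    std::vector<std::string> extra_opdefs = {kExampleOpDef};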
+ auto extra_opdefs_status = RegisterExtraTfOpDefs(extra_tf_opdefs); + if (!extra_opdefs_status.ok()) return extra_opdefs_status; + if (saved_model_version == 2) { auto module_or = tensorflow::SavedModelObjectGraphToMlirImport( input_filename, tags, exported_names, context); diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h index 8f1edec8879..95b6097e1eb 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h @@ -48,6 +48,7 @@ LoadFromGraphdefOrMlirSource( stream_executor::port::StatusOr ImportSavedModel( const std::string& input_filename, const int saved_model_version, const std::unordered_set& tags, + absl::Span extra_tf_opdefs, absl::Span exported_names, const GraphImportConfig& specs, mlir::MLIRContext* context); @@ -62,8 +63,9 @@ stream_executor::port::StatusOr ImportSavedModel( Status ConvertTFExecutorToTFLOrFlatbuffer( mlir::ModuleOp module, bool export_to_mlir, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, - const mlir::TFL::QuantizationSpecs& quant_specs, std::string* result, - mlir::PassManager* pass_manager); + const mlir::TFL::QuantizationSpecs& quant_specs, + const std::unordered_set& saved_model_tags, + std::string* result, mlir::PassManager* pass_manager); } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_ diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td index 47cfaecd3fb..322da815a47 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td @@ -27,6 +27,9 @@ def NonOpaqueElementsAttr : ElementsAttrBase< def F32ElementsAttr : ElementsAttrBase< CPred<"$_self.cast().getType().getElementType().isF32()">, "float constant tensor">; +def Int64ElementsAttr : ElementsAttrBase< + CPred<"$_self.cast().getType().getElementType().isInteger(64)">, "Int 64 constant tensor">; + // Extract the ith int element from an ArrayAttr $0 as an 32-bit IntegerAttr // with builder. class ExtractI32At : NativeCodeCall< @@ -50,6 +53,10 @@ def ExtractSingleElementAsInteger : NativeCodeCall< def ExtractSingleElementAsInt32 : NativeCodeCall< "$_builder.getI32IntegerAttr(ExtractSingleElementAsInteger($_self.cast()).getInt())">; +// Converts tensor with int64 to int32. +def CreateCastToInt32 : NativeCodeCall< + "CreateCastToInt32($0, $_loc, $_builder)">; + // Checks whether the given operation has static shapes and same shapes of all inputs. 
def HasSameStaticShapesPred : CPred<"HasSameStaticShapes($0.getDefiningOp())">; def HasSameStaticShapes : Constraint; @@ -149,6 +156,7 @@ def LegalizeMaxPool2D : Pat< IsIntList1XY1:$ksize, IsIntList1XY1:$strides, $padding, + $explicit_paddings, IsDataFormatNHWC:$format), (TFL_MaxPool2DOp $value, /*padding=*/$padding, @@ -207,8 +215,14 @@ def LegalizeSoftPlus : Pat<(TF_SoftplusOp F32Tensor:$arg0), def LegalizeSqueeze : Pat<(TF_SqueezeOp $arg, $squeeze_dims), (TFL_SqueezeOp $arg, $squeeze_dims)>; def LegalizeTanh : Pat<(TF_TanhOp $arg), (TFL_TanhOp $arg)>; + +def LegalizeTransposeInt64 : Pat< + (TF_TransposeOp $arg, (ConstantOp Int64ElementsAttr:$perm)), + (TFL_TransposeOp $arg, (CreateCastToInt32 $perm))>; + def LegalizeTranspose : Pat<(TF_TransposeOp $arg, $perm), (TFL_TransposeOp $arg, $perm)>; + def LegalizeWhere : Pat<(TF_WhereOp $arg), (TFL_WhereOp $arg)>; def LegalizeZerosLike : Pat<(TF_ZerosLikeOp $arg), (TFL_ZerosLikeOp $arg)>; diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index 297b1459fc5..13c7a08a094 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -30,6 +30,7 @@ limitations under the License. #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Threading.h" #include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project #include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project @@ -48,6 +49,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/utils/constant_utils.h" #include "tensorflow/compiler/mlir/lite/utils/validators.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" @@ -65,7 +67,6 @@ namespace TFL { // The actual LegalizeTF Pass. namespace { -using xla::Status; using xla::StatusOr; constexpr char kUnidirectionalSequenceLstm[] = "tf.UnidirectionalSequenceLstm"; @@ -74,6 +75,10 @@ constexpr char kTfLiteInputIndices[] = "_tflite_input_indices"; // Legalize operations in functions. class LegalizeTF : public PassWrapper { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + public: LegalizeTF() = default; LegalizeTF(const LegalizeTF&) {} @@ -112,6 +117,17 @@ bool HasSameStaticShapes(Operation* op) { return true; } +// Util that casts 'val' to Int32 by adding a cast Op. 
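One note on the legalize_tf.cc change above: LegalizeTF now overrides getDependentDialects, since the lowered IR can contain quantization ops whose dialect must be registered before the pass pipeline runs (hence the new QuantOps.h include). The template argument of registry.insert is not legible in this patch text; the sketch below treats the registered dialect as an assumption inferred from that include:

    #include "mlir/Dialect/Quant/QuantOps.h"
    #include "mlir/Pass/Pass.h"

    // Sketch only: the dialect registered here is an assumption, not taken from
    // the patch, which does not show the template argument.
    struct LegalizeTfSketchPass
        : public mlir::PassWrapper<LegalizeTfSketchPass, mlir::FunctionPass> {
      void getDependentDialects(mlir::DialectRegistry& registry) const override {
        registry.insert<mlir::quant::QuantizationDialect>();
      }
      void runOnFunction() override {}  // lowering patterns omitted in this sketch
    };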
+Value CreateCastToInt32(Attribute val, Location loc, + PatternRewriter& rewriter) { + auto shape = val.getType().dyn_cast().getShape(); + IntegerType new_ele_type = rewriter.getIntegerType(32); + ShapedType new_type = RankedTensorType::get(shape, new_ele_type); + return rewriter.create(loc, new_type, + rewriter.create(loc, val), + rewriter.getBoolAttr(false)); +} + #include "tensorflow/compiler/mlir/lite/transforms/generated_legalize_tf.inc" #define DECL_CONVERT_OP(tf_op) \ @@ -154,9 +170,8 @@ LogicalResult ConvertTFRandomUniformOp::matchAndRewrite( tensorflow::random::PhiloxRandom, float> Distribution; - tensorflow::random::PhiloxRandom generator( - random_uniform_op.seed().getSExtValue(), - random_uniform_op.seed2().getSExtValue()); + tensorflow::random::PhiloxRandom generator(random_uniform_op.seed(), + random_uniform_op.seed2()); Distribution dist; size_t num_elements = 0; if (auto output_type = @@ -227,26 +242,47 @@ LogicalResult ConvertTFConcatV2Op::matchAndRewrite( return success(); } -// The following is effectively: -// def : Pat< -// (TF_MatMulOp $a, $b, ConstBoolAttrFalse:$transpose_a, -// ConstBoolAttrTrue:$transpose_b), -// (TFL_FullyConnectedOp:$__0 $a, $b, -// NoInput.pattern, TFL_AF_None, TFL_FCWO_Default, ConstBoolAttrFalse)>; LogicalResult ConvertTFMatMulOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_matmul_op = cast(op); - if (tf_matmul_op.transpose_a()) return failure(); - if (!tf_matmul_op.transpose_b()) return failure(); + auto lhs = op->getOperand(0); + auto rhs = op->getOperand(1); + auto transpose = [&](Value input) -> std::pair { + RankedTensorType type = + input.getType().dyn_cast_or_null(); + if (!type || type.getRank() != 2) return {failure(), nullptr}; + + auto permute_attr = DenseIntElementsAttr::get( + RankedTensorType::get({2}, rewriter.getI32Type()), {1, 0}); + auto permute = rewriter.create( + op->getLoc(), permute_attr.getType(), permute_attr); + llvm::SmallVector new_shape{type.getShape()[1], + type.getShape()[0]}; + auto output = rewriter.create( + op->getLoc(), RankedTensorType::get(new_shape, type.getElementType()), + input, permute); + return {success(), output}; + }; + + // TODO(jpienaar): Remove once handled via dialect conversion. + if (tf_matmul_op.transpose_a()) { + LogicalResult result = success(); + std::tie(result, lhs) = transpose(lhs); + if (failed(result)) return failure(); + } + if (!tf_matmul_op.transpose_b()) { + LogicalResult result = success(); + std::tie(result, rhs) = transpose(rhs); + if (failed(result)) return failure(); + } Type output_type = tf_matmul_op.getResult().getType(); - // TODO(jpienaar): Follow up post shuffle discussion. auto no_input = rewriter.create( op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr()); auto fc_op = rewriter.create( - op->getLoc(), ArrayRef{output_type}, op->getOperand(0), - op->getOperand(1), no_input, rewriter.getStringAttr("NONE"), - rewriter.getStringAttr("DEFAULT"), rewriter.getBoolAttr(false)); + op->getLoc(), ArrayRef{output_type}, lhs, rhs, no_input, + rewriter.getStringAttr("NONE"), rewriter.getStringAttr("DEFAULT"), + rewriter.getBoolAttr(false)); rewriter.replaceOp(op, {fc_op.getResult(0)}); return success(); } @@ -259,7 +295,7 @@ LogicalResult ConvertTFPackOp::matchAndRewrite( auto output_type = tf_pack_op.output().getType(); auto values_count = rewriter.getI32IntegerAttr(tf_pack_op.N()); // Axis can be negative.
- auto axis = rewriter.getI32IntegerAttr(tf_pack_op.axis().getSExtValue()); + auto axis = rewriter.getI32IntegerAttr(tf_pack_op.axis()); rewriter.replaceOpWithNewOp(op, output_type, values, values_count, axis); @@ -356,27 +392,22 @@ LogicalResult ConvertTFStridedSliceOp::matchAndRewrite( op, tf_strided_slice_op.output().getType(), tf_strided_slice_op.input(), tf_strided_slice_op.begin(), tf_strided_slice_op.end(), tf_strided_slice_op.strides(), - rewriter.getI32IntegerAttr( - tf_strided_slice_op.begin_mask().getSExtValue()), - rewriter.getI32IntegerAttr( - tf_strided_slice_op.end_mask().getSExtValue()), - rewriter.getI32IntegerAttr( - tf_strided_slice_op.ellipsis_mask().getSExtValue()), - rewriter.getI32IntegerAttr( - tf_strided_slice_op.new_axis_mask().getSExtValue()), - rewriter.getI32IntegerAttr( - tf_strided_slice_op.shrink_axis_mask().getSExtValue())); + rewriter.getI32IntegerAttr(tf_strided_slice_op.begin_mask()), + rewriter.getI32IntegerAttr(tf_strided_slice_op.end_mask()), + rewriter.getI32IntegerAttr(tf_strided_slice_op.ellipsis_mask()), + rewriter.getI32IntegerAttr(tf_strided_slice_op.new_axis_mask()), + rewriter.getI32IntegerAttr(tf_strided_slice_op.shrink_axis_mask())); return success(); } int num_input_dims = ranked_input_type.getRank(); // Pad `begin` array with zero values and update the `begin_mask`. SmallVector begin_pad_val(num_input_dims, 0); - int begin_mask = tf_strided_slice_op.begin_mask().getSExtValue(); + int begin_mask = tf_strided_slice_op.begin_mask(); Value padded_begin = PadStridedSliceAttributeArray( op, rewriter, tf_strided_slice_op.begin(), begin_pad_val, &begin_mask); // Pad `end` array with `input_shape` and update the `end_mask`. - int end_mask = tf_strided_slice_op.end_mask().getSExtValue(); + int end_mask = tf_strided_slice_op.end_mask(); auto input_shape = ranked_input_type.getShape(); SmallVector end_pad_val(input_shape.begin(), input_shape.end()); Value padded_end = PadStridedSliceAttributeArray( @@ -390,12 +421,9 @@ LogicalResult ConvertTFStridedSliceOp::matchAndRewrite( padded_begin, padded_end, padded_strides, rewriter.getI32IntegerAttr(begin_mask), rewriter.getI32IntegerAttr(end_mask), - rewriter.getI32IntegerAttr( - tf_strided_slice_op.ellipsis_mask().getSExtValue()), - rewriter.getI32IntegerAttr( - tf_strided_slice_op.new_axis_mask().getSExtValue()), - rewriter.getI32IntegerAttr( - tf_strided_slice_op.shrink_axis_mask().getSExtValue())); + rewriter.getI32IntegerAttr(tf_strided_slice_op.ellipsis_mask()), + rewriter.getI32IntegerAttr(tf_strided_slice_op.new_axis_mask()), + rewriter.getI32IntegerAttr(tf_strided_slice_op.shrink_axis_mask())); return success(); } @@ -406,7 +434,7 @@ LogicalResult ConvertTFUnpackOp::matchAndRewrite( auto input = tf_unpack_op.value(); auto num = rewriter.getI32IntegerAttr(tf_unpack_op.num()); // Axis can be negative. - auto axis = rewriter.getI32IntegerAttr(tf_unpack_op.axis().getSExtValue()); + auto axis = rewriter.getI32IntegerAttr(tf_unpack_op.axis()); rewriter.replaceOpWithNewOp(op, tf_unpack_op.output().getTypes(), input, num, axis); @@ -637,7 +665,7 @@ void LegalizeTF::runOnFunction() { auto func = getFunction(); // Add the generated patterns to the list. 
- populateWithGenerated(context, &patterns); + populateWithGenerated(context, patterns); patterns .insert> { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + void RunOnFunction(FuncOp func); void runOnOperation() override { @@ -60,8 +64,8 @@ void RunOnWhile(TF::WhileOp while_op) { // Mark old function as private so that it can be DCE'd if not called. func.setVisibility(SymbolTable::Visibility::Private); }; - create_region_with_call(while_op.cond_func(), new_op.cond()); - create_region_with_call(while_op.body_func(), new_op.body()); + create_region_with_call(while_op.cond_function(), new_op.cond()); + create_region_with_call(while_op.body_function(), new_op.body()); op->replaceAllUsesWith(new_op.getResults()); op->erase(); diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc index edddc7751ab..c0a7ea9337b 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc @@ -714,7 +714,7 @@ struct ConvertTensorListStack RankedTensorType shape_type = RankedTensorType::get({-1}, rewriter.getIntegerType(32)); auto new_shape = rewriter.create(loc, shape_type, input); - SmallVector output_shape = {op.num_elements().getSExtValue()}; + SmallVector output_shape(/*Size=*/1, op.num_elements()); for (const auto &dim : dense_elem_attr.getIntValues()) output_shape.push_back(dim.getSExtValue()); RankedTensorType result_type = @@ -749,7 +749,7 @@ Type VariantToUnrankedTensorType(Type type, Value value) { // Changes the function type of `cond_func` and `body_func` for the given While // op. LogicalResult UpdateFunctionTypes(TF::WhileOp op) { - for (FuncOp func : {op.cond_func(), op.body_func()}) { + for (FuncOp func : {op.cond_function(), op.body_function()}) { if (!func) continue; FunctionType func_type = func.getType(); @@ -892,7 +892,7 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction( target.addLegalOp(); OwningRewritePatternList patterns; - populateWithGenerated(context, &patterns); + populateWithGenerated(context, patterns); patterns.insert { LogicalResult matchAndRewrite(TFL::MulOp mul_op, PatternRewriter &rewriter) const override { + // If we are broadcasting on the lhs then don't fold the multiply as it + // would increase the amount of compute done by the fully connected op. + if (mul_op.lhs().getType() != mul_op.getType()) return failure(); + // Mul. DenseElementsAttr cst; Value constant_val = mul_op.rhs(); @@ -794,7 +798,7 @@ void Optimize::runOnFunction() { // Potentially the binary ops might be fused together, like hard_swish, thus // we explore these potentially first and then fuse the binary ops with the // following ops in a second pattern match. - TFL::populateWithGenerated(ctx, &patterns); + TFL::populateWithGenerated(ctx, patterns); patterns.insert, FuseFullyConnectedAndReluX, diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc index 2311ae0668c..f1ea837446b 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc @@ -83,8 +83,8 @@ class FoldIfOp : public OpRewritePattern { if (!llvm::hasSingleElement(parent_op)) return failure(); // Find the then and else branch functions. 
- FuncOp then_func = op.then_func(); - FuncOp else_func = op.else_func(); + FuncOp then_func = op.then_function(); + FuncOp else_func = op.else_function(); // If the If has no uses and its functions are side-effect free, then // remove. diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 559d22dcf47..653c33ea9df 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -21,8 +21,13 @@ include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" include "tensorflow/compiler/mlir/lite/utils/utils.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" +// Checks if the param passed is a F32 ElementsAttr. def F32ElementsAttr : ElementsAttrBase< - CPred<"$_self.cast().getType().getElementType().isF32()">, "float constant tensor">; + CPred<"$_self.isa() && $_self.cast().getType().getElementType().isF32()">, + "float constant tensor">; + +// Checks if the param passed is of NoneType. +def IsNoneType : Constraint()">>; def ExtractSingleElementAsFloat : NativeCodeCall< "ExtractSingleElementAsFloat($_self.cast())">; @@ -52,15 +57,31 @@ multiclass FuseActFnIntoConvOpPat { [(HasOneUse $conv_out)]>; } +multiclass FuseActFnIntoPoolOpPat { + def FuseActivationFuncWithAvgPool#ActFnOp#ActFnAttr : Pat< + (ActFnOp (TFL_AveragePool2DOp:$pool_out $input, $filter_height, + $filter_width, $padding, $stride_h, $stride_w, TFL_AF_None)), + (TFL_AveragePool2DOp $input, $filter_height, $filter_width, $padding, + $stride_h, $stride_w, ActFnAttr), + [(HasOneUse $pool_out)]>; + def FuseActivationFuncWithMaxPool#ActFnOp#ActFnAttr : Pat< + (ActFnOp (TFL_MaxPool2DOp:$pool_out $input, $padding, $stride_w, $stride_h, + $filter_width, $filter_height, TFL_AF_None)), + (TFL_MaxPool2DOp $input, $padding, $stride_w, $stride_h, + $filter_width, $filter_height, ActFnAttr), + [(HasOneUse $pool_out)]>; +} + // TODO(hinsu): Also fuse ops corresponding to SIGN_BIT fused // activation functions. // Currently we're not fusing tanh, sigmoid, hard_swish and other activations // those cannot be simply translated into clamping. 
foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu], [TFL_Relu6Op, TFL_AF_Relu6], - [TFL_Relu1Op, TFL_AF_Relu1]] in + [TFL_Relu1Op, TFL_AF_Relu1]] in { defm : FuseActFnIntoConvOpPat; - + defm : FuseActFnIntoPoolOpPat; +} class CanFuseConvOrDepthwiseConv : Constraint< CPred<"TFL::CanFuseConvOrDepthwiseConv($0, $1, " # is_depthwise # ")">>; @@ -93,6 +114,29 @@ multiclass FuseBinaryOpToPrecedingAffine { $multiplier), [(CanFuseConvOrDepthwiseConv<"true"> $filter, $value), (HasOneUse $output)]>; + def FuseBinaryOpWithTransposeConv#binaryOp : Pat< + (binaryOp (TFL_TransposeConvOp:$output $output_shape, $weights, $inputs, + (ConstantOp F32ElementsAttr:$bias), $padding, + $stride_h, $stride_w), + (ConstantOp F32ElementsAttr:$value), TFL_AF_None), + (TFL_TransposeConvOp $output_shape, $weights, $inputs, + (binaryOp (ConstantOp $bias), + (ConstantOp $value), TFL_AF_None), + $padding, $stride_h, $stride_w), + [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value), + (HasOneUse $output)]>; + // Fuse for TransposeConv with no bias + def FuseBinaryOpWithTransposeConvNoneBias#binaryOp : Pat< + (binaryOp (TFL_TransposeConvOp:$output $output_shape, $weights, $inputs, + (ConstantOp $bias), $padding, + $stride_h, $stride_w), + (ConstantOp F32ElementsAttr:$value), TFL_AF_None), + (TFL_TransposeConvOp $output_shape, $weights, $inputs, + (ConstantOp $value), + $padding, $stride_h, $stride_w), + [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value), + (IsNoneType $bias), + (HasOneUse $output)]>; } foreach binaryOp = [TFL_AddOp, TFL_SubOp] in defm : FuseBinaryOpToPrecedingAffine; @@ -146,6 +190,39 @@ multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d { $h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w), [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value), (HasOneUse $conv_output)]>; + def FuseMulOrDivWithTransposeConv#BinaryOp : Pat< + (BinaryOp (TFL_TransposeConvOp:$output $output_shape, + (ConstantOp F32ElementsAttr:$weights), $input, + (ConstantOp F32ElementsAttr:$bias), + $padding, $stride_h, $stride_w), + (ConstantOp $value), TFL_AF_None), + (TFL_TransposeConvOp $output_shape, + (BinaryOp (ConstantOp $weights), + (ConstantOp (ExpandTo4DForConv $value)), + TFL_AF_None), + $input, + (BinaryOp (ConstantOp $bias), + (ConstantOp $value), + TFL_AF_None), + $padding, $stride_h, $stride_w), + [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value), + (HasOneUse $output)]>; + def FuseMulOrDivWithTransposeConvWithNoneBias#BinaryOp : Pat< + (BinaryOp (TFL_TransposeConvOp:$output $output_shape, + (ConstantOp F32ElementsAttr:$weights), $input, + (ConstantOp $bias), + $padding, $stride_h, $stride_w), + (ConstantOp $value), TFL_AF_None), + (TFL_TransposeConvOp $output_shape, + (BinaryOp (ConstantOp $weights), + (ConstantOp (ExpandTo4DForConv $value)), + TFL_AF_None), + $input, + (ConstantOp $bias), + $padding, $stride_h, $stride_w), + [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value), + (IsNoneType $bias), + (HasOneUse $output)]>; } foreach BinaryOp = [TFL_DivOp, TFL_MulOp] in @@ -420,9 +497,9 @@ def ConvertExpandDimsToReshape : Pat< [(AnyStaticShapeTensor $expand_dims_op)]>; class FloatValueEquals : Constraint().getNumElements() == 1 &&" - "$0.isa() &&" - "*$0.cast().getValues().begin() == " # val>>; + "$0.isa() && " + "llvm::all_of($0.cast().getFloatValues(), " + "[](const APFloat& f) { return f.isExactlyValue(" # val # "); })">>; // ReLU patterns def MatchReluPattern : Pat< @@ -552,3 +629,37 @@ foreach ReduceOp = [TFL_ReduceMaxOp, TFL_ReduceMinOp, TFL_ReduceProdOp, (HasOneUse $reduce)]>; } + +def 
IsSame : Constraint>; +def HasTwoUse : Constraint>; +def AxesIsLastDimension : Constraint().getNumElements() == 1 && " + "$0.cast().getValue({0}) == " + "$1.getType().cast().getRank() - 1">>; + +// Convert exp(x)/sum(exp(x)) into softmax. +def OptimizeToSoftmax : Pat< + (TFL_DivOp (TFL_ExpOp:$exp $input), + (TFL_SumOp:$sum $sum_input, (ConstantOp I32ElementsAttr: $axes), + ConstBoolAttrTrue), TFL_AF_None), + (TFL_SoftmaxOp $input, ConstF32Attr<"1.0">), + [(IsSame $exp, $sum_input), + (AxesIsLastDimension $axes, $sum_input), + (HasTwoUse $exp), + (HasOneUse $sum)]>; + +// Convert softmax(x-max(x)) into softmax(x) as the softmax op already deals +// with the max normalization. +def FoldNormalizationIntoSoftmax : Pat< + (TFL_SoftmaxOp + (TFL_SubOp:$sub $input, + (TFL_ReduceMaxOp:$max $max_input, (ConstantOp I32ElementsAttr: $axes), + ConstBoolAttrTrue), + TFL_AF_None), + $beta), + (TFL_SoftmaxOp $input, $beta), + [(IsSame $input, $max_input), + (AxesIsLastDimension $axes, $max_input), + (HasOneUse $sub), + (HasOneUse $max)]>; diff --git a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc index 1c6550bc902..ca30b2f1fcf 100644 --- a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc @@ -144,7 +144,7 @@ void PostQuantizePass::runOnFunction() { OwningRewritePatternList patterns; auto func = getFunction(); auto* ctx = func.getContext(); - TFL::populateWithGenerated(ctx, &patterns); + TFL::populateWithGenerated(ctx, patterns); patterns.insert>(ctx); applyPatternsAndFoldGreedily(func, patterns); diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc index 0efd7187e16..172ce59ddd4 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc @@ -42,6 +42,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h" +#include "tensorflow/compiler/mlir/lite/utils/nms_utils.h" #include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -59,6 +60,7 @@ namespace { constexpr char kTFAPIImplements[] = "tf.api_implements"; constexpr char kTFTextAPIPrefix[] = "tftext:"; +constexpr char kCustomSSDPostprocessing[] = "TFLite_Detection_PostProcess"; constexpr char kTfNMSPadded[] = "non_max_suppression_padded_v2"; using mlir::TF::FuncAttr; @@ -99,59 +101,6 @@ class ConvertEmbeddedLookupFunc { FuncOp func_; }; -// Abstracts the conversion of the padded NMS composite function. 
-class ConvertNMSPaddedFunc { - public: - explicit ConvertNMSPaddedFunc(FuncOp func) : func_(func) {} - - void RewriteFunc() { - func_.setAttr(kTFImplements, - StringAttr::get(kTfNMSPadded, func_.getContext())); - Value boxes = func_.getArgument(0); - Value scores = func_.getArgument(1); - Value max_output_size = func_.getArgument(2); - Value iou_threshold = func_.getArgument(3); - Value score_threshold = func_.getArgument(4); - auto output_type0 = func_.getType().getResult(0); - auto output_type1 = func_.getType().getResult(1); - - OpBuilder builder(func_.getBody()); - auto op = builder.create( - func_.getLoc(), output_type0, output_type1, boxes, scores, - max_output_size, iou_threshold, score_threshold); - - builder.create(func_.getLoc(), op.getResults()); - } - - LogicalResult VerifySignature() { - // Verify high-level function signature. - // Relevant argument characteristics are checked by the TFL op definition. - if (func_.getNumArguments() < 5) { - return func_.emitError() - << "Invalid number of arguments to " - "non_max_suppression_padded_v2 (need atleast 5): " - << func_.getNumArguments(); - } - if (func_.getType().getNumResults() != 2) { - return func_.emitError() << "Invalid number of results from " - "non_max_suppression_padded_v2 (need 2): " - << func_.getType().getNumResults(); - } - // The TFLite fused op does not support batching yet. - // TODO(b/158709815): Add support for batches with padded NMS. - auto boxes_type = - func_.getArgument(0).getType().dyn_cast(); - if (!boxes_type.hasRank() || boxes_type.getRank() != 2) { - return func_.emitError() << "TFLite does not support batched input for " - "non_max_suppression_padded"; - } - return success(); - } - - private: - FuncOp func_; -}; - // This pass uses mechanisms listed in RFC: // https://github.com/tensorflow/community/pull/113 // It prepares composite functions that are attributed to indicate @@ -161,6 +110,10 @@ class ConvertNMSPaddedFunc { class PrepareCompositeFunctionsPass : public PassWrapper> { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + public: explicit PrepareCompositeFunctionsPass() {} @@ -219,6 +172,12 @@ void PrepareCompositeFunctionsPass::ConvertTFImplementsWithAttributes( if (failed(ConvertTFTextAPI(func, api_name, attr))) { return signalPassFailure(); } + } else if (api_name == kCustomSSDPostprocessing) { + ConvertSSDPostProcessFunc convert_ssd_postprocess(func, attr); + if (failed(convert_ssd_postprocess.VerifySignature()) || + failed(convert_ssd_postprocess.RewriteFunc())) { + return signalPassFailure(); + } } } diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td index f5b252773f6..5cfdb4b982d 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td @@ -40,64 +40,6 @@ def : Pat< (TF_MulOp $t, (TF_MulOp:$mul (TF_RsqrtOp (TF_AddOp $v, (TF_ConstOp $variance_epsilon))), $gamma)), (TF_SubOp $beta, (TF_MulOp $m, $mul)))>; -// Converts tf.FusedBatchNorm & tf.FusedBatchNormV3 into a sequence of more primitive arithmetic -// operations. Specifically, performs the following calculation: -// -// (x - mean) * scale / sqrt(variance + epsilon) + offset -// -// Let multiplier = scale / sqrt(variance + epsilon), -// to compute -// (x - mean) * scale / sqrt(variance + epsilon) + offset, -// is then to compute -// (x * multiplier) + (offset - mean * multiplier). 
-def : Pattern< - (TF_FusedBatchNormOp:$root - $x, $scale, $offset, $mean, $variance, - F32Attr:$epsilon, $exponential_avg_factor, - $data_format, FalseBoolAttr:$is_training), - [(TF_AddOp - (TF_MulOp - $x, - (TF_MulOp:$multiplier - $scale, - (TF_RsqrtOp - (TF_AddOp $variance, - (TF_ConstOp $epsilon))))), - (TF_SubOp $offset, (TF_MulOp $mean, $multiplier))), - // We already guaranteed that the last four results has no use so it does - // not matter what value we provide here for replacement. - /*batch_mean=*/(replaceWithValue $x), - /*batch_variance=*/(replaceWithValue $x), - /*reserve_space_1=*/(replaceWithValue $x), - /*reserve_space_2=*/(replaceWithValue $x)], - [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2), - (HasNoUseOf:$root__3), (HasNoUseOf:$root__4)]>; - -def : Pattern< - (TF_FusedBatchNormV3Op:$root - $x, $scale, $offset, $mean, $variance, - F32Attr:$epsilon, $exponential_avg_factor, - $data_format, FalseBoolAttr:$is_training), - [(TF_AddOp - (TF_MulOp - $x, - (TF_MulOp:$multiplier - $scale, - (TF_RsqrtOp - (TF_AddOp $variance, - (TF_ConstOp $epsilon))))), - (TF_SubOp $offset, (TF_MulOp $mean, $multiplier))), - // We already guaranteed that the last five results have no use so it does - // not matter what value we provide here for replacement. - /*batch_mean=*/(replaceWithValue $x), - /*batch_variance=*/(replaceWithValue $x), - /*reserve_space_1=*/(replaceWithValue $x), - /*reserve_space_2=*/(replaceWithValue $x), - /*reserve_space_3=*/(replaceWithValue $x)], - [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2), - (HasNoUseOf:$root__3), (HasNoUseOf:$root__4), - (HasNoUseOf:$root__5)]>; - class TFi32 : ConstantAttr(v)>; // Matmul without transpose on b to matmul with explicit transpose op and diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc index 07b7aacd95d..783f21fce21 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc @@ -69,6 +69,11 @@ namespace { // training quantization simpler. class PrepareQuantizePass : public PassWrapper { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + public: // Constructor used by the PassRegistration and enforce uint8 quantization. // This is only used by test. diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index c521ca0ed53..c4f30c22be3 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -80,9 +80,11 @@ namespace { // Prepare TF operations in functions for subsequent legalization. 
class PrepareTFPass : public PassWrapper { public: - explicit PrepareTFPass() : unfold_batch_matmul_(true) {} - explicit PrepareTFPass(bool unfold_batch_matmul) - : unfold_batch_matmul_(unfold_batch_matmul) {} + PrepareTFPass() = default; + PrepareTFPass(const PrepareTFPass &) {} + explicit PrepareTFPass(bool unfold_batch_matmul) { + unfold_batch_matmul_ = unfold_batch_matmul; + } void runOnFunction() override; void getDependentDialects(DialectRegistry ®istry) const override { @@ -91,7 +93,10 @@ class PrepareTFPass : public PassWrapper { } private: - bool unfold_batch_matmul_; + Option unfold_batch_matmul_{ + *this, "tfl-unfold-batch-matmul", + llvm::cl::desc("Unfold BatchMatMul into individual MatMul ops."), + llvm::cl::init(true)}; }; template @@ -210,9 +215,8 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp } // Use the min/max from the operands and the num_bits and narrow_range // attribute to create the quantization parameter for the new quantize op. - rewriter.setInsertionPointAfter(tf_op); - IntegerAttr num_bits = - rewriter.getI64IntegerAttr(tf_op.num_bits().getSExtValue()); + rewriter.setInsertionPointAfter(tf_op.getOperation()); + IntegerAttr num_bits = rewriter.getI64IntegerAttr(tf_op.num_bits()); BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.narrow_range()); Type res_type = tf_op.getType(); TypeAttr qtype = quant::GetQuantizedTypeAttr( @@ -533,8 +537,8 @@ struct ConvertTFStridedSlice : public RewritePattern { loc, new_output_type, original_input, shape); // Replace the original strided_slice. - llvm::APInt new_begin_mask = strided_slice_op.begin_mask(); - llvm::APInt new_end_mask = strided_slice_op.end_mask(); + uint64_t new_begin_mask = strided_slice_op.begin_mask(); + uint64_t new_end_mask = strided_slice_op.end_mask(); // Since we expand the dims, we need to apply them to the begin_mask & // end_mask. new_begin_mask |= strided_slice_op.new_axis_mask(); @@ -597,8 +601,8 @@ struct ConvertTFStridedSlice : public RewritePattern { const int ellipsis_filled_dim_size = input_size - begin_shape[0] + 1; - int64_t begin_mask = strided_slice_op.begin_mask().getSExtValue(); - int64_t end_mask = strided_slice_op.end_mask().getSExtValue(); + int64_t begin_mask = strided_slice_op.begin_mask(); + int64_t end_mask = strided_slice_op.end_mask(); int64_t new_begin_mask = 0; int64_t new_end_mask = 0; @@ -634,13 +638,16 @@ struct ConvertTFStridedSlice : public RewritePattern { ++index; // After the ellipsis. - for (; index < begin_shape[0]; ++index) { + for (; index < begin_shape[0];) { padded_begin.push_back(begin_dense_elem_attr.getValue(index)); padded_end.push_back(end_dense_elem_attr.getValue(index)); padded_stride.push_back(stride_dense_elem_attr.getValue(index)); if ((begin_mask >> index) & 1) new_begin_mask |= (1 << new_index); if ((end_mask >> index) & 1) new_end_mask |= (1 << new_index); + + ++index; + ++new_index; } auto attribute_type = rewriter.getIntegerType(64); @@ -676,16 +683,16 @@ struct ConvertTFStridedSlice : public RewritePattern { // TODO(renjieliu): Consider expand the transformation for shrink mask as // well. - if (strided_slice_op.shrink_axis_mask().getZExtValue()) return failure(); + if (strided_slice_op.shrink_axis_mask()) return failure(); // Handle new axis mask. - uint64_t new_axis_mask = strided_slice_op.new_axis_mask().getZExtValue(); + uint64_t new_axis_mask = strided_slice_op.new_axis_mask(); if (new_axis_mask != 0) { return RewriteNewAxisMask(strided_slice_op, new_axis_mask, rewriter); } // Handle ellipsis mask. 
- uint64_t ellipsis_mask = strided_slice_op.ellipsis_mask().getZExtValue(); + uint64_t ellipsis_mask = strided_slice_op.ellipsis_mask(); if (ellipsis_mask != 0) { return RewriteEllipsisMask(strided_slice_op, ellipsis_mask, rewriter); } @@ -733,6 +740,278 @@ struct ConvertTFBroadcastTo : public RewritePattern { } }; +// The below pattern is equivalent to the DRR rule below. +// The checks are dependent on generated values, so we can't add +// the checks on intermediate values; ideally we should find equivalent +// checks that guarantee the resultant ops are valid. +// The extra conditions are the broadcasting conditions. +// +// The pattern lowers FusedBatchNormV3 to arithmetic ops. +// Specifically, it performs the following calculation: +// +// (x - mean) * scale / sqrt(variance + epsilon) + offset +// +// Let multiplier = scale / sqrt(variance + epsilon), +// to compute +// (x - mean) * scale / sqrt(variance + epsilon) + offset, +// is then to compute +// (x * multiplier) + (offset - mean * multiplier). +// +// def : Pattern< +// (TF_FusedBatchNormV3Op:$root +// $x, $scale, $offset, $mean, $variance, +// F32Attr:$epsilon, $exponential_avg_factor, +// $data_format, FalseBoolAttr:$is_training), +// [(TF_AddOp +// (TF_MulOp +// $x, +// (TF_MulOp:$multiplier +// $scale, +// (TF_RsqrtOp +// (TF_AddOp $variance, +// (TF_ConstOp $epsilon))))), +// (TF_SubOp $offset, (TF_MulOp $mean, $multiplier))), +// // We already guaranteed that the last five results have no use so it does +// // not matter what value we provide here for replacement. +// /*batch_mean=*/(replaceWithValue $x), +// /*batch_variance=*/(replaceWithValue $x), +// /*reserve_space_1=*/(replaceWithValue $x), +// /*reserve_space_2=*/(replaceWithValue $x), +// /*reserve_space_3=*/(replaceWithValue $x)], +// [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2), +// (HasNoUseOf:$root__3), (HasNoUseOf:$root__4), +// (HasNoUseOf:$root__5), (AreBroadcastableTypes $multiplier, $x)]>; + +struct FusedBatchNormV3Pat : public ::mlir::RewritePattern { + explicit FusedBatchNormV3Pat(::mlir::MLIRContext *context) + : ::mlir::RewritePattern( + "tf.FusedBatchNormV3", + {"tf.Add", "tf.Const", "tf.Mul", "tf.Rsqrt", "tf.Sub"}, 1, + context) {} + + ::mlir::LogicalResult matchAndRewrite( + ::mlir::Operation *fused_batch_norm, + ::mlir::PatternRewriter &rewriter) const override { + // Variables for capturing values and attributes used for creating ops + Operation::operand_range mean(fused_batch_norm->getOperands()); + ::mlir::FloatAttr exponential_avg_factor; + ::mlir::StringAttr data_format; + ::mlir::TF::FusedBatchNormV3Op root; + Operation::operand_range offset(fused_batch_norm->getOperands()); + Operation::operand_range x(fused_batch_norm->getOperands()); + Operation::operand_range scale(fused_batch_norm->getOperands()); + Operation::operand_range variance(fused_batch_norm->getOperands()); + ::mlir::FloatAttr epsilon; + ::mlir::BoolAttr is_training; + + // Match + auto fused_batch_norm_op = + dyn_cast_or_null<::mlir::TF::FusedBatchNormV3Op>(fused_batch_norm); + root = fused_batch_norm_op; + x = fused_batch_norm_op.getODSOperands(0); + scale = fused_batch_norm_op.getODSOperands(1); + offset = fused_batch_norm_op.getODSOperands(2); + mean = fused_batch_norm_op.getODSOperands(3); + variance = fused_batch_norm_op.getODSOperands(4); + { + epsilon = fused_batch_norm_op.getAttrOfType<::mlir::FloatAttr>("epsilon"); + if (!epsilon) + epsilon = rewriter.getFloatAttr(rewriter.getF32Type(), 0.0001f); + + if (!(((epsilon.isa<::mlir::FloatAttr>())) &&
((epsilon.cast<::mlir::FloatAttr>().getType().isF32())))) { + return rewriter.notifyMatchFailure( + fused_batch_norm_op, [&](::mlir::Diagnostic &diag) { + diag << "op 'tf.FusedBatchNormV3' attribute 'epsilon' failed to " + "satisfy constraint: 32-bit float attribute"; + }); + } + } + { + exponential_avg_factor = + fused_batch_norm_op.getAttrOfType<::mlir::FloatAttr>( + "exponential_avg_factor"); + if (!exponential_avg_factor) + exponential_avg_factor = + rewriter.getFloatAttr(rewriter.getF32Type(), 1.0f); + } + { + data_format = + fused_batch_norm_op.getAttrOfType<::mlir::StringAttr>("data_format"); + if (!data_format) data_format = rewriter.getStringAttr("NHWC"); + } + { + is_training = + fused_batch_norm_op.getAttrOfType<::mlir::BoolAttr>("is_training"); + if (!is_training) is_training = rewriter.getBoolAttr(true); + + if (!((!is_training.getValue()))) { + return rewriter.notifyMatchFailure( + fused_batch_norm_op, [&](::mlir::Diagnostic &diag) { + diag << "op 'tf.FusedBatchNormV3' attribute 'is_training' failed " + "to " + "satisfy constraint: FalseBoolAttr"; + }); + } + } + + if (!(((*root.getODSResults(1).begin()).use_empty()))) { + return rewriter.notifyMatchFailure( + fused_batch_norm_op, [&](::mlir::Diagnostic &diag) { + diag << "entities '' failed to satisfy constraint: has no use"; + }); + } + + if (!(((*root.getODSResults(2).begin()).use_empty()))) { + return rewriter.notifyMatchFailure( + fused_batch_norm_op, [&](::mlir::Diagnostic &diag) { + diag << "entities '' failed to satisfy constraint: has no use"; + }); + } + + if (!(((*root.getODSResults(3).begin()).use_empty()))) { + return rewriter.notifyMatchFailure( + fused_batch_norm_op, [&](::mlir::Diagnostic &diag) { + diag << "entities '' failed to satisfy constraint: has no use"; + }); + } + + if (!(((*root.getODSResults(4).begin()).use_empty()))) { + return rewriter.notifyMatchFailure( + fused_batch_norm_op, [&](::mlir::Diagnostic &diag) { + diag << "entities '' failed to satisfy constraint: has no use"; + }); + } + + if (!(((*root.getODSResults(5).begin()).use_empty()))) { + return rewriter.notifyMatchFailure( + fused_batch_norm_op, [&](::mlir::Diagnostic &diag) { + diag << "entities '' failed to satisfy constraint: has no use"; + }); + } + // Rewrite + auto odsLoc = rewriter.getFusedLoc({fused_batch_norm->getLoc()}); + ::llvm::SmallVector<::mlir::Value, 4> replace_values; + ::mlir::TF::ConstOp epsilon_const_op; + { + epsilon_const_op = + rewriter.create<::mlir::TF::ConstOp>(odsLoc, + /*value=*/epsilon); + } + ::mlir::TF::AddOp add_op_1; + { + ::mlir::Value tblgen_value_0 = (*variance.begin()); + ::mlir::Value tblgen_value_1 = + (*epsilon_const_op.getODSResults(0).begin()); + add_op_1 = rewriter.create<::mlir::TF::AddOp>(odsLoc, + /*x=*/tblgen_value_0, + /*y=*/tblgen_value_1); + // We need to make sure the Add operands are broadcastable. + if (mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(add_op_1) + .value == LogicalResult::Failure) { + return failure(); + } + } + ::mlir::TF::RsqrtOp rsqrt_op; + { + ::mlir::SmallVector<::mlir::Value, 4> tblgen_values; + ::mlir::SmallVector<::mlir::NamedAttribute, 4> tblgen_attrs; + tblgen_values.push_back((*add_op_1.getODSResults(0).begin())); + rsqrt_op = rewriter.create<::mlir::TF::RsqrtOp>(odsLoc, tblgen_values, + tblgen_attrs); + } + ::mlir::TF::MulOp multiplier; + { + ::mlir::Value tblgen_value_0 = (*scale.begin()); + ::mlir::Value tblgen_value_1 = (*rsqrt_op.getODSResults(0).begin()); + // We need to make sure the Add operands are broadcastable. 
+ multiplier = rewriter.create<::mlir::TF::MulOp>(odsLoc, + /*x=*/tblgen_value_0, + /*y=*/tblgen_value_1); + if (mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(multiplier) + .value == LogicalResult::Failure) { + return failure(); + } + } + ::mlir::TF::MulOp mul_op_1; + { + ::mlir::Value tblgen_value_0 = (*x.begin()); + ::mlir::Value tblgen_value_1 = (*multiplier.getODSResults(0).begin()); + mul_op_1 = rewriter.create<::mlir::TF::MulOp>(odsLoc, + /*x=*/tblgen_value_0, + /*y=*/tblgen_value_1); + // We need to make sure the Mul operands are broadcastable. + if (mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(mul_op_1) + .value == LogicalResult::Failure) { + return failure(); + } + } + ::mlir::TF::MulOp mul_op_2; + { + ::mlir::Value tblgen_value_0 = (*mean.begin()); + ::mlir::Value tblgen_value_1 = (*multiplier.getODSResults(0).begin()); + mul_op_2 = rewriter.create<::mlir::TF::MulOp>(odsLoc, + /*x=*/tblgen_value_0, + /*y=*/tblgen_value_1); + if (mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(mul_op_2) + .value == LogicalResult::Failure) { + return failure(); + } + } + ::mlir::TF::SubOp sub_op; + { + ::mlir::Value tblgen_value_0 = (*offset.begin()); + ::mlir::Value tblgen_value_1 = (*mul_op_2.getODSResults(0).begin()); + sub_op = rewriter.create<::mlir::TF::SubOp>(odsLoc, + /*x=*/tblgen_value_0, + /*y=*/tblgen_value_1); + if (mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(sub_op).value == + LogicalResult::Failure) { + return failure(); + } + } + ::mlir::TF::AddOp add_op_2; + { + ::mlir::SmallVector<::mlir::Value, 4> tblgen_values; + ::mlir::SmallVector<::mlir::NamedAttribute, 4> tblgen_attrs; + tblgen_values.push_back((*mul_op_1.getODSResults(0).begin())); + tblgen_values.push_back((*sub_op.getODSResults(0).begin())); + ::mlir::SmallVector<::mlir::Type, 4> tblgen_types; + for (auto v : fused_batch_norm_op.getODSResults(0)) { + tblgen_types.push_back(v.getType()); + } + add_op_2 = rewriter.create<::mlir::TF::AddOp>( + odsLoc, tblgen_types, tblgen_values, tblgen_attrs); + // We need to make sure the Add operands are broadcastable. 
+ if (mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(add_op_2) + .value == LogicalResult::Failure) { + return failure(); + } + } + for (auto v : + ::llvm::SmallVector<::mlir::Value, 4>{add_op_2.getODSResults(0)}) { + replace_values.push_back(v); + } + for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) { + replace_values.push_back(v); + } + for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) { + replace_values.push_back(v); + } + for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) { + replace_values.push_back(v); + } + for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) { + replace_values.push_back(v); + } + for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) { + replace_values.push_back(v); + } + rewriter.replaceOp(fused_batch_norm, replace_values); + return success(); + }; +}; + #include "tensorflow/compiler/mlir/lite/transforms/generated_prepare_tf.inc" // Returns success if all the operations in the `op`'s regions including `op` @@ -758,10 +1037,13 @@ LogicalResult ConvertTf2XlaOps(FuncOp func, MLIRContext *context) { target.addLegalOp(); target.addLegalOp(); target.addIllegalOp(); + target.addIllegalOp(); OwningRewritePatternList patterns; mhlo::PopulateLegalizeTfWithTf2XlaPatterns("XLA_CPU_JIT", patterns); + mhlo::PopulateLegalizeTfPatterns(context, &patterns); TF::PopulateLegalizeHloToTfPatterns(&patterns, context); + mhlo::GatherOp::getCanonicalizationPatterns(patterns, context); return applyPartialConversion(func, target, patterns); } @@ -892,9 +1174,10 @@ void PrepareTFPass::runOnFunction() { // This pattern will try to identify and optimize for dilated convolution. // e.g. Patterns like "SpaceToBatchND -> Conv2D -> BatchToSpaceND" will be // replaced with a single Conv op with dilation parameter. - patterns.insert, + patterns.insert, FusedBatchNormV3Pat, ConvertTFDilatedConvOp>(ctx); - TFL::populateWithGenerated(ctx, &patterns); + + TFL::populateWithGenerated(ctx, patterns); // TODO(karimnosseir): Split to separate pass probably after // deciding on long term plan for this optimization. // This will allow optimizing any TF_Mul->TF_Conv in the graph @@ -905,7 +1188,7 @@ void PrepareTFPass::runOnFunction() { // Load the generated pattern again, so new quantization pass-through // will be applied. patterns.clear(); - TFL::populateWithGenerated(ctx, &patterns); + TFL::populateWithGenerated(ctx, patterns); if (unfold_batch_matmul_) { patterns.insert, TF::ConvertTFBatchMatMulOp>(ctx); diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize.cc b/tensorflow/compiler/mlir/lite/transforms/quantize.cc index ba25b5c897c..529e57780c3 100644 --- a/tensorflow/compiler/mlir/lite/transforms/quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/quantize.cc @@ -84,7 +84,7 @@ void QuantizePass::runOnFunction() { OwningRewritePatternList patterns; auto func = getFunction(); auto* ctx = func.getContext(); - TFL::populateWithGenerated(ctx, &patterns); + TFL::populateWithGenerated(ctx, patterns); patterns.insert( ctx, enable_numeric_verify, error_tolerance, enable_single_layer_verify); applyPatternsAndFoldGreedily(func, patterns); diff --git a/tensorflow/compiler/mlir/lite/utils/nms_utils.cc b/tensorflow/compiler/mlir/lite/utils/nms_utils.cc new file mode 100644 index 00000000000..e462d4f38b0 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/utils/nms_utils.cc @@ -0,0 +1,174 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/utils/nms_utils.h" + +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" + +namespace mlir { +namespace TFL { + +namespace { + +// TODO(b/162842801): Consolidate all util definitions of kTFImplements. +constexpr char kTFImplements[] = "tf._implements"; +constexpr char kCustomSSDPostprocessing[] = "TFLite_Detection_PostProcess"; +constexpr char kTfNMSPadded[] = "non_max_suppression_padded_v2"; + +inline OpaqueElementsAttr CustomOption(OpBuilder* builder, + const std::string& content) { + ShapedType type = RankedTensorType::get( + {static_cast(content.size())}, builder->getIntegerType(8)); + return OpaqueElementsAttr::get(builder->getContext()->getLoadedDialect("tfl"), + type, + StringRef(content.data(), content.size())); +} + +} // namespace + +void ConvertNMSPaddedFunc::RewriteFunc() { + func_.setAttr(kTFImplements, + StringAttr::get(kTfNMSPadded, func_.getContext())); + Value boxes = func_.getArgument(0); + Value scores = func_.getArgument(1); + Value max_output_size = func_.getArgument(2); + Value iou_threshold = func_.getArgument(3); + Value score_threshold = func_.getArgument(4); + auto output_type0 = func_.getType().getResult(0); + auto output_type1 = func_.getType().getResult(1); + + OpBuilder builder(func_.getBody()); + auto op = builder.create( + func_.getLoc(), output_type0, output_type1, boxes, scores, + max_output_size, iou_threshold, score_threshold); + + builder.create(func_.getLoc(), op.getResults()); +} + +LogicalResult ConvertNMSPaddedFunc::VerifySignature() { + // Verify high-level function signature. + // Relevant argument characteristics are checked by the TFL op definition. + if (func_.getNumArguments() < 5) { + return func_.emitError() + << "Invalid number of arguments to " + "non_max_suppression_padded_v2 (need at least 5): " + << func_.getNumArguments(); + } + if (func_.getType().getNumResults() != 2) { + return func_.emitError() << "Invalid number of results from " + "non_max_suppression_padded_v2 (need 2): " + << func_.getType().getNumResults(); + } + // The TFLite fused op does not support batching yet. + // TODO(b/158709815): Add support for batches with padded NMS.
+ auto boxes_type = func_.getArgument(0).getType().dyn_cast(); + if (!boxes_type.hasRank() || boxes_type.getRank() != 2) { + return func_.emitError() << "TFLite does not support batched input for " + "non_max_suppression_padded"; + } + return success(); +} + +LogicalResult ConvertSSDPostProcessFunc::RewriteFunc() { + func_.eraseBody(); + func_.addEntryBlock(); + func_.setAttr(kTFImplements, + StringAttr::get(kCustomSSDPostprocessing, func_.getContext())); + + OpBuilder builder(func_.getBody()); + std::string custom_option_buffer; + if (failed(CreateNMSCustomOptions(func_, attr_.GetAttrs(), + custom_option_buffer))) { + return failure(); + } + auto op = builder.create( + func_.getLoc(), func_.getType().getResults(), func_.getArguments(), + kCustomSSDPostprocessing, CustomOption(&builder, custom_option_buffer)); + builder.create(func_.getLoc(), op.getResults()); + + return success(); +} + +LogicalResult ConvertSSDPostProcessFunc::CreateNMSCustomOptions( + FuncOp func, DictionaryAttr attrs, std::string& custom_option_buffer) { + flexbuffers::Builder fbb; + size_t start_map = fbb.StartMap(); + + if (failed(AddIntAttr(func, attrs, "max_detections", &fbb)) || + failed(AddIntAttr(func, attrs, "max_classes_per_detection", &fbb)) || + failed(AddIntAttr(func, attrs, "num_classes", &fbb)) || + failed(AddFloatAttr(func, attrs, "nms_score_threshold", &fbb)) || + failed(AddFloatAttr(func, attrs, "nms_iou_threshold", &fbb)) || + failed(AddFloatAttr(func, attrs, "y_scale", &fbb)) || + failed(AddFloatAttr(func, attrs, "x_scale", &fbb)) || + failed(AddFloatAttr(func, attrs, "h_scale", &fbb)) || + failed(AddFloatAttr(func, attrs, "w_scale", &fbb))) + return failure(); + auto use_regular_nms = + attrs.get("use_regular_nms").dyn_cast_or_null(); + if (!use_regular_nms) { + return func.emitError() + << "use_regular_nms attribute is not set or not a bool"; + } + fbb.Int("use_regular_nms", use_regular_nms.getValue()); + + fbb.EndMap(start_map); + fbb.Finish(); + custom_option_buffer.assign(fbb.GetBuffer().begin(), fbb.GetBuffer().end()); + return success(); +} + +LogicalResult ConvertSSDPostProcessFunc::AddIntAttr( + FuncOp func, DictionaryAttr attrs, const std::string& attribute, + flexbuffers::Builder* builder) { + auto int_attr = attrs.get(attribute).dyn_cast_or_null(); + if (!int_attr) { + return func.emitError() + << attribute.c_str() << " attribute is not set or not an integer"; + } + builder->Int(attribute.c_str(), int_attr.getInt()); + return success(); +} + +LogicalResult ConvertSSDPostProcessFunc::AddFloatAttr( + FuncOp func, DictionaryAttr attrs, const std::string& attribute, + flexbuffers::Builder* builder) { + auto float_attr = attrs.get(attribute).dyn_cast_or_null(); + if (!float_attr) { + return func.emitError() + << attribute.c_str() << " attribute is not set or not a float"; + } + builder->Float(attribute.c_str(), float_attr.getValue().convertToFloat()); + return success(); +} + +LogicalResult ConvertSSDPostProcessFunc::VerifySignature() { + // Verify high-level function signature. 
+ if (func_.getNumArguments() != 3) { + return func_.emitError() + << "Invalid number of arguments to " << kCustomSSDPostprocessing + << ": " << func_.getNumArguments(); + } + if (func_.getType().getNumResults() != 4) { + return func_.emitError() + << "Invalid number of results from " << kCustomSSDPostprocessing + << ": " << func_.getType().getNumResults(); + } + return success(); +} + +} // namespace TFL +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/utils/nms_utils.h b/tensorflow/compiler/mlir/lite/utils/nms_utils.h new file mode 100644 index 00000000000..6a9035e0c81 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/utils/nms_utils.h @@ -0,0 +1,76 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This header file defines common utils used by TFLite transformation +// passes to work with NMS ops in TFLite. + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_NMS_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_NMS_UTILS_H_ + +#include + +#include "flatbuffers/flexbuffers.h" // from @flatbuffers +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" + +namespace mlir { +namespace TFL { + +// Abstracts the conversion of the padded NMS composite function. +class ConvertNMSPaddedFunc { + public: + explicit ConvertNMSPaddedFunc(FuncOp func) : func_(func) {} + + void RewriteFunc(); + + LogicalResult VerifySignature(); + + private: + FuncOp func_; +}; + +// Abstracts the conversion of the SSD post-processing composite function to +// TFLite. +class ConvertSSDPostProcessFunc { + public: + explicit ConvertSSDPostProcessFunc(FuncOp func, mlir::TF::FuncAttr attr) + : func_(func), attr_(attr) {} + + LogicalResult RewriteFunc(); + + LogicalResult VerifySignature(); + + private: + LogicalResult CreateNMSCustomOptions(FuncOp func, DictionaryAttr attrs, + std::string& custom_option_buffer); + + LogicalResult AddIntAttr(FuncOp func, DictionaryAttr attrs, + const std::string& attribute, + flexbuffers::Builder* builder); + + LogicalResult AddFloatAttr(FuncOp func, DictionaryAttr attrs, + const std::string& attribute, + flexbuffers::Builder* builder); + + FuncOp func_; + mlir::TF::FuncAttr attr_; +}; + +} // end namespace TFL +} // end namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_NMS_UTILS_H_ diff --git a/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc b/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc index bce0ed4a33d..6b605741355 100644 --- a/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc +++ b/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc @@ -28,6 +28,7 @@ limitations under the License.
#include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "tensorflow/compiler/mlir/utils/name_utils.h" static inline absl::string_view StringRefToView(llvm::StringRef ref) { return absl::string_view(ref.data(), ref.size()); @@ -103,62 +104,16 @@ int OpOrArgNameMapper::InitOpName(OpOrVal op_or_val, llvm::StringRef name) { bool OpOrArgNameMapper::IsUnique(llvm::StringRef name) { return true; } -namespace { -// Derives name from location. -std::string GetNameFromLoc(mlir::Location loc) { - llvm::SmallVector loc_names; - llvm::SmallVector locs; - locs.push_back(loc); - bool names_is_nonempty = false; - - while (!locs.empty()) { - mlir::Location curr_loc = locs.pop_back_val(); - - if (auto name_loc = curr_loc.dyn_cast()) { - // Add name in NameLoc. For NameLoc we also account for names due to ops - // in functions where the op's name is first. - auto name = name_loc.getName().strref().split('@').first; - loc_names.push_back(name); - if (!name.empty()) names_is_nonempty = true; - continue; - } else if (auto call_loc = curr_loc.dyn_cast()) { - // Add name if CallSiteLoc's callee has a NameLoc (as should be the - // case if imported with DebugInfo). - if (auto name_loc = call_loc.getCallee().dyn_cast()) { - auto name = name_loc.getName().strref().split('@').first; - loc_names.push_back(name); - if (!name.empty()) names_is_nonempty = true; - continue; - } - } else if (auto fused_loc = curr_loc.dyn_cast()) { - // Push all locations in FusedLoc in reverse order, so locations are - // visited based on order in FusedLoc. - auto reversed_fused_locs = llvm::reverse(fused_loc.getLocations()); - locs.append(reversed_fused_locs.begin(), reversed_fused_locs.end()); - continue; - } - - // Location is not a supported, so an empty StringRef is added. - loc_names.push_back(llvm::StringRef()); - } - - if (names_is_nonempty) - return llvm::join(loc_names.begin(), loc_names.end(), ";"); - - return ""; -} -} // anonymous namespace - std::string OpOrArgLocNameMapper::GetName(OpOrVal op_or_val) { if (auto* op = op_or_val.dyn_cast()) { - auto name_from_loc = GetNameFromLoc(op->getLoc()); + auto name_from_loc = mlir::GetNameFromLoc(op->getLoc()); if (!name_from_loc.empty()) return name_from_loc; // If the location is none of the expected types, then simply use name // generated using the op type. return std::string(op->getName().getStringRef()); } auto val = op_or_val.dyn_cast(); - auto name_from_loc = GetNameFromLoc(val.getLoc()); + auto name_from_loc = mlir::GetNameFromLoc(val.getLoc()); if (!name_from_loc.empty()) return name_from_loc; // If the location is none of the expected types, then simply use name // generated using the op type. 
Follow TF convention and append the result diff --git a/tensorflow/compiler/mlir/python/BUILD b/tensorflow/compiler/mlir/python/BUILD index 5bbfba773a3..502695acd40 100644 --- a/tensorflow/compiler/mlir/python/BUILD +++ b/tensorflow/compiler/mlir/python/BUILD @@ -1,3 +1,6 @@ +load("//tensorflow:tensorflow.bzl", "filegroup") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + package( default_visibility = ["//visibility:public"], licenses = ["notice"], # Apache 2.0 @@ -10,6 +13,7 @@ cc_library( deps = [ "//tensorflow/c:tf_status", "//tensorflow/c:tf_status_helper", + "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow:tf_saved_model_passes", @@ -35,6 +39,9 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", ], alwayslink = 1, ) diff --git a/tensorflow/compiler/mlir/python/mlir.cc b/tensorflow/compiler/mlir/python/mlir.cc index f1f6c43d3b3..066726593a7 100644 --- a/tensorflow/compiler/mlir/python/mlir.cc +++ b/tensorflow/compiler/mlir/python/mlir.cc @@ -16,19 +16,53 @@ limitations under the License. #include #include "llvm/Support/raw_ostream.h" +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/InitAllPasses.h" // from @llvm-project #include "mlir/Parser.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/op.h" namespace tensorflow { +namespace { + +// Runs pass pipeline `pass_pipeline` on `module` if `pass_pipeline` is not +// empty. 
+std::string RunPassPipelineOnModule(mlir::ModuleOp module, + const std::string &pass_pipeline, + TF_Status *status) { + if (!pass_pipeline.empty()) { + mlir::PassManager pm(module.getContext()); + std::string error; + llvm::raw_string_ostream error_stream(error); + if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) { + TF_SetStatus(status, TF_INVALID_ARGUMENT, + ("Invalid pass_pipeline: " + error_stream.str()).c_str()); + return "// error"; + } + + mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext()); + if (failed(pm.run(module))) { + Set_TF_Status_from_Status(status, statusHandler.ConsumeStatus()); + return "// error"; + } + } + return MlirModuleToString(module); +} + +} // anonymous namespace + std::string ImportGraphDef(const std::string &proto, const std::string &pass_pipeline, TF_Status *status) { @@ -41,31 +75,49 @@ std::string ImportGraphDef(const std::string &proto, GraphDebugInfo debug_info; GraphImportConfig specs; mlir::MLIRContext context; - context.loadAllGloballyRegisteredDialects(); auto module = ConvertGraphdefToMlir(graphdef, debug_info, specs, &context); if (!module.ok()) { Set_TF_Status_from_Status(status, module.status()); return "// error"; } - // Run the pass_pipeline on the module if not empty. - if (!pass_pipeline.empty()) { - mlir::PassManager pm(&context); - std::string error; - llvm::raw_string_ostream error_stream(error); - if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) { - TF_SetStatus(status, TF_INVALID_ARGUMENT, - ("Invalid pass_pipeline: " + error_stream.str()).c_str()); - return "// error"; - } + return RunPassPipelineOnModule(module->get(), pass_pipeline, status); +} - mlir::StatusScopedDiagnosticHandler statusHandler(&context); - if (failed(pm.run(*module.ValueOrDie()))) { - Set_TF_Status_from_Status(status, statusHandler.ConsumeStatus()); - return "// error"; - } +std::string ImportFunction(const std::string &functiondef_proto, + const std::string &functiondef_library_proto, + const std::string &pass_pipeline, + TF_Status *status) { + FunctionDef functiondef; + auto s = tensorflow::LoadProtoFromBuffer(functiondef_proto, &functiondef); + if (!s.ok()) { + Set_TF_Status_from_Status(status, s); + return "// error"; } - return MlirModuleToString(*module.ConsumeValueOrDie()); + + FunctionDefLibrary fdef_lib; + s = tensorflow::LoadProtoFromBuffer(functiondef_library_proto, &fdef_lib); + if (!s.ok()) { + Set_TF_Status_from_Status(status, s); + return "// error"; + } + + FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib); + s = flib_def.AddFunctionDef(functiondef); + if (!s.ok()) { + Set_TF_Status_from_Status(status, s); + return "// error"; + } + + const std::string &function_name = functiondef.signature().name(); + mlir::MLIRContext context; + auto module = ConvertFunctionToMlir(function_name, flib_def, &context); + if (!module.ok()) { + Set_TF_Status_from_Status(status, module.status()); + return "// error"; + } + + return RunPassPipelineOnModule(module->get(), pass_pipeline, status); } std::string ExperimentalConvertSavedModelToMlir( @@ -86,7 +138,6 @@ std::string ExperimentalConvertSavedModelToMlir( std::vector exported_names = absl::StrSplit(exported_names_str, ',', absl::SkipEmpty()); mlir::MLIRContext context; - context.loadAllGloballyRegisteredDialects(); auto module_or = ConvertSavedModelToMlir( &bundle, &context, absl::Span(exported_names)); if (!module_or.status().ok()) { @@ -117,7 +168,6 @@ std::string ExperimentalConvertSavedModelV1ToMlir( // Convert the SavedModelBundle to 
an MLIR module. mlir::MLIRContext context; - context.loadAllGloballyRegisteredDialects(); auto module_or = ConvertSavedModelV1ToMlir(bundle, {}, &context, upgrade_legacy); if (!module_or.status().ok()) { @@ -153,6 +203,7 @@ std::string ExperimentalRunPassPipeline(const std::string &mlir_txt, bool show_debug_info, TF_Status *status) { mlir::MLIRContext context; + mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry()); mlir::OwningModuleRef module; { mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context); @@ -167,6 +218,7 @@ std::string ExperimentalRunPassPipeline(const std::string &mlir_txt, mlir::PassManager pm(&context); std::string error; llvm::raw_string_ostream error_stream(error); + mlir::registerAllPasses(); if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) { TF_SetStatus(status, TF_INVALID_ARGUMENT, ("Invalid pass_pipeline: " + error_stream.str()).c_str()); diff --git a/tensorflow/compiler/mlir/python/mlir.h b/tensorflow/compiler/mlir/python/mlir.h index e68ac28124b..6133068a5e8 100644 --- a/tensorflow/compiler/mlir/python/mlir.h +++ b/tensorflow/compiler/mlir/python/mlir.h @@ -25,13 +25,23 @@ limitations under the License. namespace tensorflow { // Simple wrapper to support tf.mlir.experimental.convert_graph_def. -// Load a .pbptx, convert to MLIR, and (optionally) optimize the module before -// returning it as a string. +// Load a GraphDef (binary or textual proto format), convert to MLIR, and +// (optionally) optimize the module before returning it as a string. // This is an early experimental API, ideally we should return a wrapper object // around a Python binding to the MLIR module. std::string ImportGraphDef(const std::string &proto, const std::string &pass_pipeline, TF_Status *status); +// Simple wrapper to support tf.mlir.experimental.convert_function. +// Load FunctionDef and FunctionDefLibrary (binary or textual proto format), +// convert to MLIR, and (optionally) optimize the module before returning it as +// a string. +// This is an early experimental API, ideally we should return a wrapper object +// around a Python binding to the MLIR module. +std::string ImportFunction(const std::string &functiondef_proto, + const std::string &functiondef_library_proto, + const std::string &pass_pipeline, TF_Status *status); + // Load a SavedModel and return a textual MLIR string corresponding to it. 
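The wrappers declared above report failures through the `TF_Status` out-parameter and return the module as a string (`// error` on failure). A minimal sketch of a C++ caller of the new `ImportFunction` entry point; the helper name is an illustrative assumption, and the empty pass pipeline simply returns the module exactly as imported:

```cpp
#include <string>

#include "tensorflow/c/tf_status.h"
#include "tensorflow/compiler/mlir/python/mlir.h"

// Convert a serialized FunctionDef (plus its FunctionDefLibrary) to textual
// MLIR. An empty pass_pipeline skips the optional optimization step.
std::string FunctionDefToMlir(const std::string &functiondef_proto,
                              const std::string &fdef_library_proto) {
  TF_Status *status = TF_NewStatus();
  std::string mlir_txt = tensorflow::ImportFunction(
      functiondef_proto, fdef_library_proto, /*pass_pipeline=*/"", status);
  if (TF_GetCode(status) != TF_OK) mlir_txt = TF_Message(status);
  TF_DeleteStatus(status);
  return mlir_txt;
}
```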
// // Args: diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD b/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD index 5e21dddd444..47bff366311 100644 --- a/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD @@ -20,6 +20,7 @@ tf_python_pybind_extension( "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/python:pybind11_lib", "//tensorflow/python:pybind11_status", + "@llvm-project//llvm:FileCheckLib", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", @@ -36,6 +37,7 @@ tf_python_pybind_extension( deps = [ "//tensorflow/python:pybind11_lib", "//tensorflow/python:pybind11_status", + "@llvm-project//llvm:FileCheckLib", "@llvm-project//llvm:Support", "@pybind11", ], diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/basic_classes.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/basic_classes.cc index 25adb44fe1d..5ae638851f4 100644 --- a/tensorflow/compiler/mlir/python/mlir_wrapper/basic_classes.cc +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/basic_classes.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "llvm/Support/FileCheck.h" +#include "llvm/FileCheck/FileCheck.h" #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/filecheck_wrapper.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/filecheck_wrapper.cc index 8a841856b72..051952ebaba 100644 --- a/tensorflow/compiler/mlir/python/mlir_wrapper/filecheck_wrapper.cc +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/filecheck_wrapper.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "llvm/Support/FileCheck.h" +#include "llvm/FileCheck/FileCheck.h" #include "llvm/Support/SourceMgr.h" #include "pybind11/pybind11.h" #include "pybind11/stl.h" diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc index 4152b576e71..6cd49cf368d 100644 --- a/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc @@ -22,23 +22,25 @@ limitations under the License. 
#include "mlir/Parser.h" // from @llvm-project #include "pybind11/pybind11.h" #include "pybind11/stl.h" +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/python/lib/core/pybind11_lib.h" #include "tensorflow/python/lib/core/pybind11_status.h" PYBIND11_MODULE(mlir_wrapper, m) { - m.def("registerDialects", []() { - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerDialect(); + m.def("preloadTensorFlowDialects", [](mlir::MLIRContext &context) { + mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry()); + context.getDialectRegistry().loadAll(&context); }); + m.def("verify", [](std::string input) { llvm::SourceMgr SM = llvm::SourceMgr(); SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(input), llvm::SMLoc()); mlir::MLIRContext ctx; - ctx.loadAllGloballyRegisteredDialects(); + mlir::RegisterAllTensorFlowDialects(ctx.getDialectRegistry()); + ctx.getDialectRegistry().loadAll(&ctx); auto module = mlir::parseSourceFile(SM, &ctx); if (!module) { return false; diff --git a/tensorflow/compiler/mlir/runlit.cfg.py b/tensorflow/compiler/mlir/runlit.cfg.py index f9870183b88..17410b4e5b2 100644 --- a/tensorflow/compiler/mlir/runlit.cfg.py +++ b/tensorflow/compiler/mlir/runlit.cfg.py @@ -73,8 +73,9 @@ tool_names = [ 'mlir-opt', 'mlir-hlo-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate', 'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate', 'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile', - 'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt', 'hlo_to_llvm_ir', - 'kernel-gen-opt', 'xla-thunks-opt', 'tfjs-opt' + 'json_to_flatbuffer', 'xla-gpu-opt', 'xla-mlir-gpu-opt', 'xla-opt', + 'hlo_to_llvm_ir', 'kernel-gen-opt', 'tf_to_kernel', 'tf_to_gpu_binary', + 'xla-thunks-opt', 'tfjs-opt' ] tools = [ToolSubst(s, unresolved='ignore') for s in tool_names] llvm_config.add_tool_substitutions(tools, tool_dirs) diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index f9cdc40a901..1c740731acd 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -1,5 +1,11 @@ +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "filegroup") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "get_compatible_with_cloud") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//third_party/mlir:tblgen.bzl", "gentbl") -load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_gen_op_wrapper_py", "tf_native_cc_binary") +load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_gen_op_wrapper_py") package( default_visibility = [":friends"], @@ -13,6 +19,7 @@ package_group( "//learning/brain/experimental/dtensor/...", "//learning/brain/experimental/tfrt/...", "//learning/pathways/data_parallel/tf2xla/...", + "//platforms/xla/sparse_core/...", "//tensorflow/compiler/...", "//tensorflow/lite/experimental/tf_runtime/...", "//tensorflow/python/...", @@ -33,6 +40,7 @@ filegroup( "ir/tf_op_base.td", "ir/tf_op_interfaces.td", "ir/tf_ops.td", + "ir/tfrt_ops.td", "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td", "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", @@ -43,6 +51,7 @@ filegroup( gentbl( name = "tensorflow_op_interfaces_inc_gen", + compatible_with = get_compatible_with_cloud(), tbl_outs = [ ( 
"-gen-op-interface-decls", @@ -63,6 +72,7 @@ gentbl( gentbl( name = "tensorflow_struct_doc_gen", + compatible_with = get_compatible_with_cloud(), tbl_outs = [ ( "-gen-dialect-doc", @@ -100,6 +110,8 @@ cc_library( deps = [ ":tensorflow_op_interfaces_inc_gen", ":tensorflow_structs", + "//tensorflow/core:framework", + "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", ], @@ -107,6 +119,7 @@ cc_library( gentbl( name = "tensorflow_all_ops_inc_gen", + compatible_with = get_compatible_with_cloud(), tbl_outs = [ ( "-gen-op-decls", @@ -124,6 +137,26 @@ gentbl( ], ) +gentbl( + name = "tensorflow_tfrt_ops_inc_gen", + compatible_with = get_compatible_with_cloud(), + tbl_outs = [ + ( + "-gen-op-decls", + "ir/tfrt_ops.h.inc", + ), + ( + "-gen-op-defs", + "ir/tfrt_ops.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "ir/tfrt_ops.td", + td_srcs = [ + ":tensorflow_ops_td_files", + ], +) + # We only shard tf_op on name for build performance reasons. tf_ops_category_list = [ { @@ -139,6 +172,7 @@ tf_ops_category_list = [ [[ gentbl( name = "tensorflow_" + target["name"] + "_inc_gen", + compatible_with = get_compatible_with_cloud(), tbl_outs = [ ( "-gen-op-decls -op-include-regex='" + target["include"] + "'", @@ -159,6 +193,7 @@ tf_ops_category_list = [ gentbl( name = "tensorflow_remaining_ops_inc_gen", + compatible_with = get_compatible_with_cloud(), tbl_outs = [ ( "-gen-op-decls -op-exclude-regex='" + "|".join([target["include"] for target in tf_ops_category_list]) + "' ", @@ -178,6 +213,7 @@ gentbl( gentbl( name = "tf_saved_model_inc_gen", + compatible_with = get_compatible_with_cloud(), tbl_outs = [ ( "-gen-op-decls", @@ -204,6 +240,7 @@ gentbl( gentbl( name = "tensorflow_executor_inc_gen", + compatible_with = get_compatible_with_cloud(), tbl_outs = [ ( "-gen-op-decls", @@ -230,6 +267,7 @@ gentbl( gentbl( name = "tensorflow_device_ops_inc_gen", + compatible_with = get_compatible_with_cloud(), tbl_outs = [ ( "-gen-op-decls ", @@ -255,6 +293,7 @@ gentbl( gentbl( name = "tensorflow_canonicalize_inc_gen", + compatible_with = get_compatible_with_cloud(), tbl_outs = [ ( "-gen-rewriters", @@ -270,6 +309,7 @@ gentbl( gentbl( name = "hlo_legalize_tf_inc_gen", + compatible_with = get_compatible_with_cloud(), tbl_outs = [ ("-gen-rewriters", "transforms/generated_legalize_hlo.inc"), ], @@ -343,6 +383,7 @@ cc_library( name = "tensorflow_" + target["name"], srcs = [ "ir/tf_ops.h", + "ir/tfrt_ops.h", "ir/tf_remaining_ops.h", "ir/tf_" + target["name"] + ".cc", "ir/tf_" + target["name"] + ".cc.inc", @@ -352,6 +393,7 @@ cc_library( textual_hdrs = [ "ir/tf_all_ops.h.inc", "ir/tf_ops_helpers.inc", + "ir/tfrt_ops.h.inc", "ir/tf_remaining_ops.h.inc", ] + ["ir/tf_" + target["name"] + ".h.inc" for target in tf_ops_category_list], deps = [ @@ -386,6 +428,7 @@ cc_library( "ir/tf_ops.h", "ir/tf_remaining_ops.h", "ir/tf_remaining_ops.cc", + "ir/tfrt_ops.h", ] + ["ir/tf_" + target["name"] + ".h" for target in tf_ops_category_list], hdrs = [ ], @@ -393,6 +436,49 @@ cc_library( "ir/tf_all_ops.h.inc", "ir/tf_ops_helpers.inc", "ir/tf_remaining_ops.h.inc", + "ir/tfrt_ops.h.inc", + ] + ["ir/tf_" + target["name"] + ".h.inc" for target in tf_ops_category_list], + deps = [ + ":tensorflow_attributes", + ":tensorflow_canonicalize_inc_gen", + ":tensorflow_op_interfaces", + ":tensorflow_op_interfaces_inc_gen", + ":tensorflow_remaining_ops_inc_gen", + ":tensorflow_side_effects", + ":tensorflow_structs", + ":tensorflow_tfrt_ops_inc_gen", + ":tensorflow_traits", + ":tensorflow_types", + 
"//tensorflow/core:framework", + "//tensorflow/core:lib", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:DerivedAttributeOpInterface", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:LoopLikeInterface", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:SideEffects", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + ], +) + +cc_library( + name = "tensorflow_tfrt_ops", + srcs = [ + "ir/tf_ops.h", + "ir/tfrt_ops.h", + "ir/tfrt_ops.cc", + "ir/tf_remaining_ops.h", + ] + ["ir/tf_" + target["name"] + ".h" for target in tf_ops_category_list], + hdrs = [ + ], + textual_hdrs = [ + "ir/tf_all_ops.h.inc", + "ir/tf_ops_helpers.inc", + "ir/tfrt_ops.h.inc", + "ir/tf_remaining_ops.h.inc", ] + ["ir/tf_" + target["name"] + ".h.inc" for target in tf_ops_category_list], deps = [ ":tensorflow_attributes", @@ -402,6 +488,7 @@ cc_library( ":tensorflow_remaining_ops_inc_gen", ":tensorflow_side_effects", ":tensorflow_structs", + ":tensorflow_tfrt_ops_inc_gen", ":tensorflow_traits", ":tensorflow_types", "//tensorflow/core:framework", @@ -428,9 +515,11 @@ cc_library( textual_hdrs = [ "ir/tf_all_ops.h.inc", "ir/tf_remaining_ops.h", + "ir/tfrt_ops.h", ] + ["ir/tf_" + target["name"] + ".h" for target in tf_ops_category_list], deps = [ ":tensorflow_all_ops_inc_gen", + ":tensorflow_tfrt_ops_inc_gen", ":tensorflow_remaining_ops_inc_gen", ":tensorflow_attributes", ":tensorflow_canonicalize_inc_gen", @@ -441,6 +530,7 @@ cc_library( ":tensorflow_traits", ":tensorflow_types", ":tensorflow_remaining_ops", + ":tensorflow_tfrt_ops", "@llvm-project//llvm:Support", "@llvm-project//mlir:DerivedAttributeOpInterface", "@llvm-project//mlir:Dialect", @@ -538,6 +628,7 @@ cc_library( ":tensorflow_ops", ":tensorflow_side_effects", ":tensorflow_structs", + ":tensorflow_tfrt_ops_inc_gen", ":tensorflow_traits", ":tensorflow_types", ":tf_saved_model_inc_gen", @@ -567,6 +658,7 @@ cc_library( gentbl( name = "decompose_resource_ops_inc_gen", + compatible_with = get_compatible_with_cloud(), tbl_outs = [ ( "-gen-rewriters", @@ -599,6 +691,7 @@ cc_library( gentbl( name = "tf_data_optimization_inc_gen", + compatible_with = get_compatible_with_cloud(), tbl_outs = [ ( "-gen-rewriters", @@ -719,14 +812,13 @@ cc_library( ], deps = [ ":tensorflow", + ":tensorflow_op_interfaces", ":tensorflow_types", - "//tensorflow/compiler/tf2xla:resource_operation_table", - "//tensorflow/core:framework", - "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", ], ) @@ -741,6 +833,7 @@ cc_library( "transforms/cluster_formation.cc", "transforms/cluster_outlining.cc", "transforms/collection_ops_util.cc", + "transforms/contraction_fusion.cc", "transforms/decompose_resource_ops_pass.cc", "transforms/device_index_selector.cc", "transforms/einsum.cc", @@ -772,6 +865,8 @@ cc_library( "transforms/replicate_to_island.cc", "transforms/resource_device_inference.cc", "transforms/resource_op_lifting.cc", + "transforms/resource_op_lifting_cleanup.cc", + "transforms/resource_op_lifting_cleanup.h", "transforms/rewrite_tpu_embedding_ops.cc", "transforms/shape_inference.cc", "transforms/shape_inference_pass.cc", @@ -785,7 +880,9 @@ cc_library( "transforms/test_visitor_util.cc", "transforms/tf_data_optimization_pass.cc", "transforms/tf_device_assignment.cc", + 
"transforms/tpu_cluster_cleanup_attributes.cc", "transforms/tpu_cluster_formation.cc", + "transforms/tpu_colocate_composite_resource_ops.cc", "transforms/tpu_dynamic_layout_pass.cc", "transforms/tpu_dynamic_padding_mapper.cc", "transforms/tpu_extract_head_tail_outside_compilation.cc", @@ -794,6 +891,8 @@ cc_library( "transforms/tpu_identity_pruning.cc", "transforms/tpu_merge_variables_with_execute.cc", "transforms/tpu_outside_compilation_cluster.cc", + "transforms/tpu_parallel_execute_sink_resource_write.cc", + "transforms/tpu_resource_read_for_write.cc", "transforms/tpu_rewrite_pass.cc", "transforms/tpu_sharding_identification_pass.cc", "transforms/tpu_space_to_depth_pass.cc", @@ -804,7 +903,6 @@ cc_library( "translate/tf_functional_to_executor.cc", ], hdrs = [ - "transforms/batchmatmul_to_einsum.h", "transforms/bridge.h", "transforms/collection_ops_util.h", "transforms/einsum.h", @@ -812,6 +910,9 @@ cc_library( "transforms/shape_inference.h", ], includes = ["include"], + textual_hdrs = [ + "ir/tf_ops_helpers.inc", + ], deps = [ ":attribute_utils", ":bridge_logger", @@ -820,10 +921,13 @@ cc_library( ":decompose_resource_ops", ":decompose_resource_ops_inc_gen", ":device_util", + ":dump_mlir_util", ":error_util", ":export_tf_dialect_op", ":lower_tf_lib", ":mangling_util", + ":serialize_mlir_module_utils", + ":shape_inference_utils", ":tensorflow", ":tensorflow_analysis", ":tensorflow_optimize_inc_gen", @@ -854,6 +958,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", "@llvm-project//mlir:StandardOps", @@ -898,6 +1003,7 @@ cc_library( srcs = ["transforms/graph_optimization_pass.cc"], hdrs = ["transforms/graph_optimization_pass.h"], deps = [ + ":dump_mlir_util", ":error_util", ":tensorflow_passes", "//tensorflow/compiler/mlir:mlir_graph_optimization_pass", @@ -920,15 +1026,14 @@ cc_library( alwayslink = 1, ) -# Library with TensorFlow dialect static initialization. 
cc_library( - name = "tensorflow_dialect_registration", - srcs = ["ir/dialect_registration.cc"], + name = "upgrade_graph", + srcs = ["translate/upgrade_graph.cc"], + hdrs = ["translate/upgrade_graph.h"], deps = [ - ":tensorflow", - "@llvm-project//mlir:Shape", + "//tensorflow/core:framework", + "//tensorflow/core:graph", ], - alwayslink = 1, ) cc_library( @@ -942,8 +1047,10 @@ cc_library( "translate/import_model.h", ], deps = [ + ":convert_attr", ":convert_tensor", ":convert_type", + ":dump_mlir_util", ":error_util", ":export_tf_dialect_op", ":export_utils", @@ -955,11 +1062,13 @@ cc_library( ":tensorflow_types", ":tf_saved_model_passes", ":translate_utils", + ":upgrade_graph", "//tensorflow/cc/saved_model:bundle_v2", "//tensorflow/cc/saved_model:constants", "//tensorflow/cc/saved_model:loader_lite", "//tensorflow/cc/saved_model:loader_util", "//tensorflow/compiler/jit:shape_inference_helpers", + "//tensorflow/compiler/mlir:name_utils", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/tf2xla:functionalize_control_flow", "//tensorflow/compiler/xla:status_macros", @@ -1064,7 +1173,6 @@ cc_library( cc_library( name = "export_tf_dialect_op", srcs = [ - "translate/derived_attr_populator.inc", "translate/export_tf_dialect_op.cc", ], hdrs = [ @@ -1074,13 +1182,16 @@ cc_library( ":convert_type", ":export_utils", ":tensorflow", + "//tensorflow/compiler/mlir:string_container_utils", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/stream_executor/lib", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", + "@llvm-project//mlir:DerivedAttributeOpInterface", "@llvm-project//mlir:IR", ], ) @@ -1154,6 +1265,24 @@ cc_library( ], ) +cc_library( + name = "convert_attr", + srcs = ["utils/convert_attr.cc"], + hdrs = ["utils/convert_attr.h"], + visibility = [ + "//visibility:public", + ], + deps = [ + ":convert_tensor", + ":convert_type", + ":tensorflow_attributes", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:errors", + "//tensorflow/stream_executor/lib", + "@llvm-project//mlir:IR", + ], +) + cc_library( name = "convert_type", srcs = ["utils/convert_type.cc"], @@ -1286,6 +1415,7 @@ cc_library( ":decode_constant_pass", ":eval_util", ":tensorflow", + ":tensorflow_traits", ":tensorflow_types", "//tensorflow/c:tf_status", "//tensorflow/c/eager:c_api", @@ -1304,9 +1434,8 @@ cc_library( cc_library( name = "tf_dialect_lib", deps = [ - ":tensorflow_dialect_registration", ":tf_dialect_passes", - "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", ], ) @@ -1317,6 +1446,7 @@ cc_library( deps = [ ":convert_graphdef", ":mlir_roundtrip_flags", + ":tensorflow", "//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", @@ -1437,37 +1567,6 @@ tf_cc_test( ], ) -tf_native_cc_binary( - name = "derived_attr_populator_gen", - srcs = [ - "translate/derived_attr_populator_gen.cc", - ], - deps = [ - "@llvm-project//llvm:Support", - "@llvm-project//llvm:TableGen", - "@llvm-project//mlir:TableGen", - ], -) - -gentbl( - name = "derived_attr_populator_inc", - tbl_outs = [ - ("", "translate/derived_attr_populator.inc"), - ], - tblgen = ":derived_attr_populator_gen", - td_file = "ir/tf_ops.td", - td_srcs = [ - "@llvm-project//mlir:include/mlir/IR/OpBase.td", - 
"@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td", - "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", - "@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td", - "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td", - "ir/tf_generated_ops.td", - "ir/tf_op_base.td", - "ir/tf_op_interfaces.td", - ], -) - filegroup( name = "tensorflow_optimize_td_files", srcs = [ @@ -1477,6 +1576,7 @@ filegroup( gentbl( name = "tensorflow_optimize_inc_gen", + compatible_with = get_compatible_with_cloud(), tbl_outs = [ ( "-gen-rewriters", @@ -1494,19 +1594,20 @@ gentbl( COMPILE_MLIR_UTIL_DEPS = [ ":bridge_logger", ":convert_graphdef", + ":convert_tensor", ":convert_type", ":dump_mlir_util", ":error_util", ":mlir_roundtrip_flags", + ":serialize_mlir_module_utils", ":tensorflow", - ":tensorflow_dialect_registration", ":tensorflow_types", ":tensorflow_passes", ":translate_utils", "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:variant", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", - "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Shape", "@llvm-project//mlir:StandardOps", @@ -1528,9 +1629,9 @@ COMPILE_MLIR_UTIL_DEPS = [ "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/stream_executor/lib", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/service:hlo", - ":convert_tensor", ] # Prefer to link 'compile_mlir_util' library that also links necessary @@ -1557,27 +1658,61 @@ cc_library( ], ) -tf_cc_test( - name = "compile_mlir_util_test", - size = "small", - srcs = ["utils/compile_mlir_util_test.cc"], +cc_library( + name = "compile_mlir_util_pass", + srcs = ["utils/compile_mlir_util_pass.cc"], deps = [ ":compile_mlir_util", - "//tensorflow/cc:function_ops", - "//tensorflow/cc:resource_variable_ops", - "//tensorflow/cc:scope", - "//tensorflow/compiler/jit", - "//tensorflow/compiler/tf2xla:common", - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:xla_data_proto_cc", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - "//tensorflow/stream_executor/lib", + "@llvm-project//mlir:Pass", ], + alwayslink = 1, +) + +cc_library( + name = "serialize_mlir_module_utils", + srcs = ["utils/serialize_mlir_module_utils.cc"], + hdrs = ["utils/serialize_mlir_module_utils.h"], + deps = [ + ":error_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:status", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + ], +) + +cc_library( + name = "tf_xla_mlir_translate", + srcs = ["utils/tf_xla_mlir_translate.cc"], + deps = [ + ":compile_mlir_util", + ":mlir_roundtrip_flags", + ":serialize_mlir_module_utils", + ":tensorflow", + ":translate_cl_options", + "//tensorflow/compiler/mlir:string_container_utils", + "//tensorflow/compiler/mlir/xla:translate_cl_options", + "//tensorflow/compiler/tf2xla:xla_argument", + "//tensorflow/compiler/tf2xla:xla_helpers", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_proto_cc", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:errors", + 
"//tensorflow/core/platform:status", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Translation", + ], + alwayslink = 1, ) cc_library( @@ -1607,6 +1742,7 @@ cc_library( tf_gen_op_wrapper_py( name = "gen_mlir_passthrough_op_py", out = "gen_mlir_passthrough_op.py", + compatible_with = [], deps = [":mlir_passthrough_op"], ) @@ -1616,6 +1752,7 @@ tf_gen_op_wrapper_py( # without linking any of the other tensorflow passes. gentbl( name = "lower_tf_inc_gen", + compatible_with = get_compatible_with_cloud(), tbl_outs = [ ( "-gen-rewriters", @@ -1728,6 +1865,7 @@ cc_library( "//tensorflow/core/platform:logging", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", ], ) @@ -1754,14 +1892,13 @@ cc_library( ":convert_graphdef", ":error_util", ":tensorflow", - ":tensorflow_dialect_registration", ":tensorflow_passes", "//tensorflow/core:framework", "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core/platform:logging", "@llvm-project//llvm:Support", - "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", @@ -1777,6 +1914,7 @@ tf_cc_test( "//tensorflow/core:framework", "//tensorflow/core:graph", "//tensorflow/core:lib", + "//tensorflow/core:ops", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/platform:test", @@ -1837,3 +1975,28 @@ cc_library( "@llvm-project//mlir:IR", ], ) + +cc_library( + name = "shape_inference_utils", + srcs = ["utils/shape_inference_utils.cc"], + hdrs = ["utils/shape_inference_utils.h"], + deps = [ + ":convert_tensor", + ":convert_type", + ":export_tf_dialect_op", + ":export_utils", + ":tensorflow", + ":tensorflow_attributes", + ":tensorflow_types", + "//tensorflow/compiler/mlir:array_container_utils", + "//tensorflow/core:framework", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:types", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:DerivedAttributeOpInterface", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:Support", + ], +) diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc index 8ec7513f81f..cdc9e33e368 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc @@ -18,20 +18,17 @@ limitations under the License. 
#include #include -#include "absl/strings/str_cat.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/Optional.h" #include "llvm/ADT/SCCIterator.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" #include "mlir/Analysis/CallGraph.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project @@ -42,9 +39,8 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/tf2xla/resource_operation_table.h" -#include "tensorflow/core/framework/resource_mgr.h" namespace mlir { namespace TF { @@ -231,48 +227,16 @@ BacktrackAnalysisInfo::BacktrackAnalysisInfo( backtracked_values_.push_back(backtrack_analysis.BacktrackValue(result)); } -namespace { - -//===----------------------------------------------------------------------===// -// ResourceAliasAnalysisInfo helper functions. -//===----------------------------------------------------------------------===// - -constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id"; - -// Returns if a VarHandleOp is anonymous, which means it always creates a new -// variable. -bool IsResourceHandleAnonymous(VarHandleOp handle) { - return handle.shared_name() == tensorflow::ResourceHandle::ANONYMOUS_NAME; -} - -// Returns a string unique identifier for a non-anonymous VarHandleOp. -std::string GetVarHandleStringId(VarHandleOp handle) { - auto device = handle.getAttrOfType("device"); - return absl::StrCat(handle.container().str(), "/", handle.shared_name().str(), - "/", device ? device.getValue().str() : std::string("")); -} - -// Finds a unique ID for a VarHandleOp's output. If it is anonymous, always -// creates a new ID; otherwise, tries to reuse the existing ID for the -// referenced variable if it exists, or creates a new one if not. -int64_t GetOrCreateIdForVarHandle(VarHandleOp handle, int64_t* next_id, - llvm::StringMap* name_id_map) { - // Always create a new ID for anonymous handle. - if (IsResourceHandleAnonymous(handle)) return (*next_id)++; - - auto name = GetVarHandleStringId(handle); - auto emplace_res = name_id_map->try_emplace(name, *next_id); - // New ID created, increment next_id. - if (emplace_res.second) ++(*next_id); - return emplace_res.first->second; -} - -} // namespace - //===----------------------------------------------------------------------===// // ResourceAliasAnalysisInfo //===----------------------------------------------------------------------===// +namespace { + +constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id"; + +} // namespace + constexpr int64_t ResourceAliasAnalysisInfo::kUnknownResourceId; // Constructs the analysis info by analyzing the given function. 
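The helper removed above assigned each `VarHandleOp` a resource id by treating anonymous handles as always-fresh and keying named handles on container, shared name, and device; that bookkeeping now sits behind `GetResourceHandleValueAndId` in the hunk that follows. A simplified, self-contained restatement of the id rule, with the MLIR types stripped out and hypothetical names:

```cpp
#include <cstdint>
#include <string>
#include <unordered_map>

// Anonymous handles always receive a fresh id; named handles are keyed on
// "container/shared_name/device" and reuse the id of the first handle seen
// with the same key, so aliasing handles end up with equal ids.
int64_t GetOrCreateResourceId(
    const std::string &key, bool is_anonymous, int64_t *next_id,
    std::unordered_map<std::string, int64_t> *key_to_id) {
  if (is_anonymous) return (*next_id)++;
  auto emplace_res = key_to_id->try_emplace(key, *next_id);
  if (emplace_res.second) ++(*next_id);  // first occurrence of this key
  return emplace_res.first->second;
}
```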
@@ -338,60 +302,33 @@ ResourceAliasAnalysisInfo::ResourceAliasAnalysisInfo( } }); - llvm::StringMap var_handle_name_id_map; + llvm::SmallDenseMap resource_handle_id_map; func_op.walk([&](Operation* op) { - if (auto var_handle = dyn_cast(op)) { - AddValueUniqueIDMapping( - var_handle.resource(), - GetOrCreateIdForVarHandle(var_handle, &next_unique_id, - &var_handle_name_id_map)); + if (auto resource_alloc = dyn_cast(op)) { + ResourceHandleValueAndId resource = + resource_alloc.GetResourceHandleValueAndId(resource_handle_id_map, + next_unique_id); + AddValueUniqueIDMapping(resource.value, resource.id); } else if (llvm::isa(op)) { for (auto result : filter_resources(op->getResults())) PropagateInputToOutput(op->getOperand(result.getResultNumber()), result); } else if (auto while_op = dyn_cast(op)) { AnalyzeWhileLoop(while_op, backtrack_analysis.GetAnalysisForFunc( - while_op.body_func())); + while_op.body_function())); } else if (auto while_region = dyn_cast(op)) { AnalyzeWhileLoop(while_region, backtrack_analysis.GetAnalysisForRegion( while_region.body())); + } else if (auto case_op = dyn_cast(op)) { + llvm::SmallVector functions; + case_op.get_branch_functions(functions); + AnalyzeFunctionalCaseOrIfOp(case_op, functions, backtrack_analysis); } else if (auto if_op = dyn_cast(op)) { - const auto& then_info = - backtrack_analysis.GetAnalysisForFunc(if_op.then_func()); - const auto& else_info = - backtrack_analysis.GetAnalysisForFunc(if_op.else_func()); - // If a result is a passthrough of both branches' inputs, merge the - // resource IDs of corresponding operands for the two inputs. - for (auto result : filter_resources(if_op.getResults())) { - auto passthrough_then_arg = then_info.GetArg(result.getResultNumber()); - auto passthrough_else_arg = else_info.GetArg(result.getResultNumber()); - if (passthrough_then_arg && passthrough_else_arg) { - Value then_operand = if_op.input()[passthrough_then_arg.getValue()]; - Value else_operand = if_op.input()[passthrough_else_arg.getValue()]; - PropagateInputToOutput(then_operand, result); - PropagateInputToOutput(else_operand, result); - } else { - AddValueUniqueIDMapping(result, kUnknownResourceId); - } - } - } else if (auto if_region = dyn_cast(op)) { - const auto& then_info = - backtrack_analysis.GetAnalysisForRegion(if_region.then_branch()); - const auto& else_info = - backtrack_analysis.GetAnalysisForRegion(if_region.else_branch()); - for (auto result : filter_resources(if_region.getResults())) { - Value then_result = then_info.GetValue(result.getResultNumber()); - Value else_result = else_info.GetValue(result.getResultNumber()); - // For IfRegion, the walk would have visited the else and then regions - // before visiting the IfRegion op. Backtracking of the then and else - // results will either give a value computed within these regions, - // or a region capture. If its a region capture, computed before this - // IfRegion, it will have been visited earlier and a mapping would - // exist for that value. If its computed within the region, then again - // a mapping would exist. 
- PropagateInputToOutput(then_result, result); - PropagateInputToOutput(else_result, result); - } + AnalyzeFunctionalCaseOrIfOp( + if_op, {if_op.then_function(), if_op.else_function()}, + backtrack_analysis); + } else if (llvm::isa(op)) { + AnalyzeRegionCaseOrIfOp(op, backtrack_analysis); } else if (auto call = dyn_cast(op)) { FuncOp func = dyn_cast(call.resolveCallable()); if (!func) { @@ -501,6 +438,59 @@ void ResourceAliasAnalysisInfo::AnalyzeWhileLoop( } } +template +void ResourceAliasAnalysisInfo::AnalyzeFunctionalCaseOrIfOp( + CaseOrIfOp case_or_if_op, llvm::ArrayRef functions, + const BacktrackAnalysis& backtrack_analysis) { + llvm::SmallVector infos; + infos.reserve(functions.size()); + for (FuncOp func : functions) + infos.push_back(&backtrack_analysis.GetAnalysisForFunc(func)); + + // If a result is a passthrough of all branches' inputs, merge the resource + // IDs of corresponding operands for all the inputs. + for (auto result : filter_resources(case_or_if_op.getResults())) { + llvm::SmallVector, 2> passthrough_args; + passthrough_args.reserve(functions.size()); + for (const auto* info : infos) + passthrough_args.emplace_back(info->GetArg(result.getResultNumber())); + + const bool all_passthrough_args_known = llvm::all_of( + passthrough_args, [](const llvm::Optional& passthrough_arg) { + return passthrough_arg.hasValue(); + }); + if (all_passthrough_args_known) { + for (const auto& passthrough_arg : passthrough_args) { + Value operand = case_or_if_op.input()[passthrough_arg.getValue()]; + PropagateInputToOutput(operand, result); + } + } else { + AddValueUniqueIDMapping(result, kUnknownResourceId); + } + } +} + +void ResourceAliasAnalysisInfo::AnalyzeRegionCaseOrIfOp( + Operation* case_or_if_op, const BacktrackAnalysis& backtrack_analysis) { + llvm::SmallVector infos; + infos.reserve(case_or_if_op->getNumRegions()); + for (Region& region : case_or_if_op->getRegions()) + infos.push_back(&backtrack_analysis.GetAnalysisForRegion(region)); + + // For region Case/If, the walk would have visited all branch regions before + // visiting the Case/If op. Backtracking of each region results will either + // give a value computed within these regions, or a region capture. If it is a + // region capture computed before this Case/If, it will have been visited + // earlier and a mapping would exist for that value. If it is computed within + // the region, then again a mapping would exist. + for (auto result : filter_resources(case_or_if_op->getResults())) { + for (const auto* info : infos) { + Value region_result = info->GetValue(result.getResultNumber()); + PropagateInputToOutput(region_result, result); + } + } +} + bool ResourceAliasAnalysisInfo::IsUnknownResource(Value resource) const { auto it = resource_value_to_ids_.find(resource); assert(it != resource_value_to_ids_.end() && !it->getSecond().empty()); diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h b/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h index 46bb57c942d..5575767dcc4 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h +++ b/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h @@ -20,6 +20,7 @@ limitations under the License. 
#include #include +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallSet.h" @@ -77,6 +78,16 @@ class ResourceAliasAnalysisInfo { void AnalyzeWhileLoop(Operation* while_op, const BacktrackAnalysisInfo& body_info); + // Analyzes tf.Case/tf.If ops to compute resourceID's. + template + void AnalyzeFunctionalCaseOrIfOp(CaseOrIfOp case_or_if_op, + llvm::ArrayRef functions, + const BacktrackAnalysis& backtrack_analysis); + + // Analyzes tf.CaseRegion/tf.IfRegion ops to compute resourceID's. + void AnalyzeRegionCaseOrIfOp(Operation* case_or_if_op, + const BacktrackAnalysis& backtrack_analysis); + // Maps each resource-type value to a set of unique IDs that it could alias. llvm::SmallDenseMap, 8> resource_value_to_ids_; diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc index c78a7e403c4..4d2c237e9a0 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc @@ -18,30 +18,31 @@ limitations under the License. #include #include -#include "absl/strings/str_cat.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" -#include "tensorflow/compiler/tf2xla/resource_operation_table.h" -#include "tensorflow/core/framework/resource_mgr.h" namespace mlir { namespace TF { @@ -80,38 +81,125 @@ llvm::SmallDenseSet FindAccessedResources( return resources; } -// Returns an XlaResourceOpInfo (or nullptr if it does not exist) that specifies -// the resource access type of the op. It tells whether the op is read only, -// etc. -// -// TODO(yuanzx): Define this information in a different place. Currently we use -// tensorflow/compiler/tf2xla/resource_operation_table.h. -const tensorflow::XlaResourceOpInfo* GetResourceInfoForOp(Operation* op) { - if (op->getName().getDialect() != - TF::TensorFlowDialect::getDialectNamespace()) { - return nullptr; +// Helper struct defining what memory effects are present for a resource. 
+struct SideEffects { + bool alloc = false; + bool free = false; + bool read = false; + bool write = false; + + bool IsAllocOnly() const { return alloc && !free && !read && !write; } + bool IsReadOnly() const { return !alloc && !free && read && !write; } +}; + +using ResourceSideEffectsByValue = llvm::SmallDenseMap; + +// Collects memory side effects for an operation by value (operands and +// results). +ResourceSideEffectsByValue GetResourceInfoForOp(Operation* op) { + ResourceSideEffectsByValue resource_info; + + auto interface = dyn_cast(op); + if (!interface) return resource_info; + + llvm::SmallVector effects; + interface.getEffects(effects); + + for (auto& effect : effects) { + // TODO(lyandy): Support effects with no value defined. + if (!effect.getValue()) return ResourceSideEffectsByValue(); + auto it = resource_info.try_emplace(effect.getValue()); + auto& side_effect = it.first->getSecond(); + auto* resource_effect = effect.getEffect(); + if (isa(resource_effect)) { + side_effect.alloc = true; + } else if (isa(resource_effect)) { + side_effect.free = true; + } else if (isa(resource_effect)) { + side_effect.read = true; + } else if (isa(resource_effect)) { + side_effect.write = true; + } else { + return ResourceSideEffectsByValue(); + } } - return tensorflow::GetResourceOpInfoForOp( - op->getName().getStringRef().split('.').second.str()); + + return resource_info; } -// Returns whether `op` accesses resources and it is known to be read-only. -bool OpIsReadOnly(Operation* op) { - auto resource_op_info = GetResourceInfoForOp(op); - return resource_op_info && - resource_op_info->kind() == tensorflow::XlaResourceOpKind::kRead; +// Checks if a value is a result of `op`. +bool IsOperationResult(Operation* op, Value value) { + return value.getDefiningOp() == op; +} + +// Checks if an operation's resource operands are read only. Operation results +// are ignored. +bool IsResourceOpReadOnly(Operation* op, + const ResourceSideEffectsByValue& resource_op_info) { + if (resource_op_info.empty()) return false; + + for (const auto& resource_info : resource_op_info) { + Value value = resource_info.getFirst(); + if (IsOperationResult(op, value)) continue; + const SideEffects& side_effects = resource_info.getSecond(); + if (!side_effects.IsReadOnly()) return false; + } + + return true; +} + +// Checks if an operation's resource results are alloc only and no side effects +// are present for its operands. +bool IsResourceOpAllocOnly(Operation* op, + const ResourceSideEffectsByValue& resource_op_info) { + if (resource_op_info.empty()) return false; + + for (const auto& resource_info : resource_op_info) { + // Operand with side effect. + Value value = resource_info.getFirst(); + if (!IsOperationResult(op, value)) return false; + const SideEffects& side_effects = resource_info.getSecond(); + if (!side_effects.IsAllocOnly()) return false; + } + + return true; } // Returns if `op` is a resource declaration. bool OpIsDeclaration(Operation* op, const ResourceAliasAnalysis::Info& alias_analysis) { - // TODO(yuanzx): Add other types of resources. - return llvm::isa(op) || - (llvm::isa(op) && - !FindAccessedResources(op, alias_analysis).empty()); + return llvm::isa(op) && + !FindAccessedResources(op, alias_analysis).empty(); } -// Returns if `op` is know to not have any side effect. +// A vector of resource variable id's with their associated resource value. +using ResourceIdsByValue = + llvm::SmallVector*>, 4>; + +// Collects resource id's by resource value. 
If operation resource side effects +// are unknown or a resource is unknown, an empty optional is returned. +llvm::Optional GetResourceIdsByValue( + Operation* op, const ResourceAliasAnalysis::Info& alias_analysis, + const ResourceSideEffectsByValue& resource_op_info) { + ResourceIdsByValue resource_ids_by_value; + if (resource_op_info.empty()) return llvm::None; + + auto collect_ids = [&](ValueRange values) { + for (auto value : filter_resources(values)) { + if (alias_analysis.IsUnknownResource(value)) return false; + const auto& ids = alias_analysis.GetResourceUniqueIds(value); + resource_ids_by_value.push_back({value, &ids}); + } + return true; + }; + + if (collect_ids(op->getOperands()) && collect_ids(op->getResults())) + return resource_ids_by_value; + else + return llvm::None; +} + +// Returns true if `op` is known to not have any side effect. bool OpIsKnownToHaveNoSideEffect(Operation* op) { // Note: Identity op is really side-effect free, but it is not marked as such // in the TF dialect (see comments in definition of Identity op in tf_ops.td) @@ -253,17 +341,17 @@ void SideEffectAnalysisInfo::AnalyzeRegion( if (OpIsDeclaration(&op, alias_analysis)) continue; auto resource_op_info = GetResourceInfoForOp(&op); - if (!resource_op_info && OpIsKnownToHaveNoSideEffect(&op)) continue; + if (resource_op_info.empty() && OpIsKnownToHaveNoSideEffect(&op)) + continue; - llvm::SmallDenseSet resources = - resource_op_info ? FindAccessedResources(&op, alias_analysis) - : UnknownResourceSet(); - assert(!resources.empty()); - const bool is_unknown = resources.count(kUnknownResourceId) > 0; - const bool read_only = OpIsReadOnly(&op); + if (IsResourceOpAllocOnly(&op, resource_op_info)) continue; + + auto resource_ids_by_value = + GetResourceIdsByValue(&op, alias_analysis, resource_op_info); + const bool read_only = IsResourceOpReadOnly(&op, resource_op_info); bool indirectly_tracked_unknown_access = false; // First add edges from known resources. - if (is_unknown) { + if (!resource_ids_by_value.hasValue()) { for (auto& entry : per_resource_access_info_) { if (entry.getFirst() == kUnknownResourceId) continue; AddPredecessorsForAccess(entry.getFirst(), &op, read_only); @@ -272,20 +360,43 @@ void SideEffectAnalysisInfo::AnalyzeRegion( read_only); } } else { - for (int64_t resource : resources) { - AddPredecessorsForAccess(resource, &op, read_only); + // Collect all resource id's and whether their side effect is read only. 
+ llvm::SmallDenseMap read_only_by_resource_id; + for (const auto& resource_ids : *resource_ids_by_value) { + const bool is_result = resource_ids.first.getDefiningOp() == &op; + auto value_resource_info = resource_op_info.find(resource_ids.first); + bool resource_read_only = false; + if (value_resource_info != resource_op_info.end()) { + if (is_result && value_resource_info->getSecond().IsAllocOnly()) + continue; + resource_read_only = value_resource_info->getSecond().IsReadOnly(); + } + + for (const auto& id : *resource_ids.second) { + auto it = + read_only_by_resource_id.try_emplace(id, resource_read_only); + if (!it.second && !resource_read_only) + it.first->getSecond() = resource_read_only; + } + } + + for (const auto& resource : read_only_by_resource_id) { + const auto& resource_id = resource.getFirst(); + const auto& resource_read_only = resource.getSecond(); + AddPredecessorsForAccess(resource_id, &op, resource_read_only); indirectly_tracked_unknown_access |= - unknown_access_indirectly_tracked_by_resource(resource, - read_only); + unknown_access_indirectly_tracked_by_resource(resource_id, + resource_read_only); // Update access info for known resources. - TrackAccess(resource, &op, read_only); + TrackAccess(resource_id, &op, resource_read_only); } } + // If not indirectly tracked, add edges from the unknown resource. if (!indirectly_tracked_unknown_access) { AddPredecessorsForAccess(kUnknownResourceId, &op, read_only); } - if (is_unknown) { + if (!resource_ids_by_value.hasValue()) { // Update access info for unknown resource. TrackAccess(kUnknownResourceId, &op, read_only); } diff --git a/tensorflow/compiler/mlir/tensorflow/c/BUILD b/tensorflow/compiler/mlir/tensorflow/c/BUILD index 243f4b5139f..64c56cf8aa9 100644 --- a/tensorflow/compiler/mlir/tensorflow/c/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/c/BUILD @@ -1,8 +1,8 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", "tf_copts", "tf_cuda_library", - "tfe_xla_copts", ) package( @@ -20,7 +20,7 @@ tf_cuda_library( srcs = [ "c_api_unified_experimental_mlir.cc", ], - copts = tf_copts() + tfe_xla_copts(), + copts = tf_copts(), deps = [ "//tensorflow/c:c_api", "//tensorflow/c:tensor_interface", @@ -35,6 +35,7 @@ tf_cuda_library( "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", "//tensorflow/compiler/mlir/tensorflow:convert_type", + "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/core:framework", diff --git a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc index c62d62a2d3d..32c51f2e2bd 100644 --- a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc +++ b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc @@ -43,6 +43,7 @@ limitations under the License. #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/c/tf_status_internal.h" +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -50,6 +51,7 @@ limitations under the License. 
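The side_effect_analysis hunks above drop the tf2xla resource-op table and derive alloc/free/read/write information from MLIR's `MemoryEffectOpInterface` instead. A standalone sketch of that query pattern, simplified to a single read-only check rather than the per-value classification the pass performs:

```cpp
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h"  // from @llvm-project

// Returns true if every memory effect declared by `op` is a read; ops that do
// not implement the interface are treated conservatively.
bool OpOnlyReadsMemory(mlir::Operation *op) {
  auto interface = llvm::dyn_cast<mlir::MemoryEffectOpInterface>(op);
  if (!interface) return false;
  llvm::SmallVector<mlir::MemoryEffects::EffectInstance, 4> effects;
  interface.getEffects(effects);
  if (effects.empty()) return false;
  for (auto &effect : effects)
    if (!llvm::isa<mlir::MemoryEffects::Read>(effect.getEffect()))
      return false;
  return true;
}
```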
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/types.pb.h" @@ -74,15 +76,9 @@ using tensorflow::tracing::TracingTensorHandle; namespace { -static void RegisterDialects() { - static bool init_once = []() { - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerDialect(); - return true; - }(); - (void)init_once; +void RegisterDialects(mlir::MLIRContext& ctx) { + mlir::RegisterAllTensorFlowDialects(ctx.getDialectRegistry()); + ctx.getDialectRegistry().loadAll(&ctx); } Status ConvertDataTypeToTensor(tensorflow::DataType dtype, Builder builder, @@ -239,6 +235,7 @@ class MlirFunctionContext : public TracingContext { : TracingContext(kMlir), context_(std::make_unique()), builder_(context_.get()) { + RegisterDialects(*context_); // TODO(aminim) figure out the location story here module_ = ModuleOp::create(builder_.getUnknownLoc()); func_ = FuncOp::create(builder_.getUnknownLoc(), name, @@ -456,7 +453,8 @@ Status MlirAbstractOp::SetAttrFloat(const char* attr_name, float value) { return Unimplemented("SetAttrFloat has not been implemented yet."); } Status MlirAbstractOp::SetAttrBool(const char* attr_name, bool value) { - return Unimplemented("SetAttrBool has not been implemented yet."); + attrs_[attr_name] = BoolAttr::get(value, context_); + return Status::OK(); } Status MlirAbstractOp::SetAttrShape(const char* attr_name, const int64_t* dims, const int num_dims) { @@ -514,6 +512,7 @@ Status MlirFunction::GetFunctionDef(tensorflow::FunctionDef** f) { return Status::OK(); } PassManager pm(func_.getContext()); + ::tensorflow::applyTensorflowAndCLOptions(pm); pm.addNestedPass(CreateFunctionalToExecutorDialectConversionPass()); pm.addPass(CreateBreakUpIslandsPass()); @@ -656,9 +655,8 @@ Status MlirFunctionContext::Finalize(OutputList* outputs, } builder_.create(func_.getLoc(), ret_operands); - auto arg_types = llvm::to_vector<8>(body.getArgumentTypes()); - auto result_types = - llvm::to_vector<8>(body.getTerminator()->getOperandTypes()); + auto arg_types = body.getArgumentTypes(); + auto result_types = body.getTerminator()->getOperandTypes(); func_.setType(FunctionType::get(arg_types, result_types, func_.getContext())); *f = new MlirFunction(std::move(context_), std::move(module_), func_); return Status::OK(); @@ -666,7 +664,6 @@ Status MlirFunctionContext::Finalize(OutputList* outputs, extern "C" { TracingContext* MlirTracingFactory(const char* fn_name, TF_Status* s) { - RegisterDialects(); return new MlirFunctionContext(fn_name); } } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc b/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc deleted file mode 100644 index 45985cea583..00000000000 --- a/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" - -namespace mlir { - -// Static initialization for TF dialect registration. -static DialectRegistration tf_ops; -static DialectRegistration - tf_executor_dialect; -static DialectRegistration - tf_device_dialect; -static DialectRegistration - tf_saved_model_dialect; -static DialectRegistration shape_dialect; - -} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.cc index 40cc2c99c27..746b34a018a 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.cc @@ -15,8 +15,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" -#include "mlir/IR/Attributes.h" // from @llvm-project - namespace mlir { namespace TF { @@ -79,6 +77,14 @@ ShapeAttr ShapeAttr::get(mlir::MLIRContext* context, return Base::get(context, ArrayRef(), /*unranked=*/true); } +// Get or create a shape attribute. +ShapeAttr ShapeAttr::get(mlir::MLIRContext* context, ShapedType shaped_type) { + if (shaped_type.hasRank()) + return Base::get(context, shaped_type.getShape(), /*unranked=*/false); + + return Base::get(context, ArrayRef(), /*unranked=*/true); +} + llvm::Optional> ShapeAttr::getValue() const { if (hasRank()) return getShape(); return llvm::None; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h index 5a18b77ab5c..0927aefff68 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h @@ -20,6 +20,8 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project namespace mlir { namespace TF { @@ -43,6 +45,9 @@ class ShapeAttr : public Attribute::AttrBase> shape); + // Get or create a shape attribute from a ShapedType type. 
+ static ShapeAttr get(mlir::MLIRContext* context, ShapedType shaped_type); + llvm::Optional> getValue() const; bool hasRank() const; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc index 5345000b4bd..3a2e8095139 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc @@ -269,8 +269,6 @@ ParseResult SetReplicateOpOperands( replicated_inputs, llvm::ArrayRef packed_inputs, llvm::ArrayRef region_arg_types, int32_t* n) { - if (replicated_inputs.empty() && packed_inputs.empty()) return success(); - for (const auto& attr : state->attributes) if (attr.first.strref() == "n") if (auto n_attr = attr.second.dyn_cast()) @@ -279,6 +277,8 @@ ParseResult SetReplicateOpOperands( if (*n < 2) return parser->emitError(loc) << "expects 'n' to be at least 2, got " << *n; + if (replicated_inputs.empty() && packed_inputs.empty()) return success(); + for (auto replicated_input_and_idx : llvm::enumerate(replicated_inputs)) { const int32_t idx = replicated_input_and_idx.index(); const auto& replicated_input = replicated_input_and_idx.value(); @@ -369,7 +369,7 @@ void Print(ReplicateOp op, OpAsmPrinter* p) { // [%a, ...] as %block_arg0: type // packed_input // %b as %block_arg1: type - const int32_t n = op.n().getSExtValue(); + const int32_t n = op.n(); const int32_t num_replicated_inputs = (*op.operand_segment_sizes().int_value_begin()).getSExtValue(); const int32_t num_replicated_block_args = num_replicated_inputs / n; @@ -413,7 +413,7 @@ LogicalResult VerifyCompatibleTypes(Type a, Type b) { } LogicalResult Verify(ReplicateOp op) { - int32_t n = op.n().getSExtValue(); + int32_t n = op.n(); // Check number of devices, if set, matches `n`. 
if (op.devices().hasValue()) { @@ -504,13 +504,12 @@ LogicalResult Verify(ReplicateOp op) { return success(); } -template void BuildReplicateOp( Builder* builder, OperationState* state, int n, const llvm::SmallDenseMap>& devices, - llvm::ArrayRef> replicated_inputs, - llvm::ArrayRef packed_inputs, ResultsTy replica_output_types) { + llvm::ArrayRef> replicated_inputs, + ValueRange packed_inputs, TypeRange replica_output_types) { DCHECK_GE(n, 2); state->addAttribute("n", builder->getI32IntegerAttr(n)); @@ -538,7 +537,7 @@ void BuildReplicateOp( block.addArgument(replicated_input.second); } - for (auto& packed_input : packed_inputs) { + for (auto packed_input : packed_inputs) { state->addOperands(packed_input); block.addArgument(packed_input.getType()); } @@ -560,20 +559,8 @@ void ReplicateOp::build( OpBuilder& builder, OperationState& state, int n, const llvm::SmallDenseMap>& devices, - llvm::ArrayRef, Type>> replicated_inputs, - llvm::ArrayRef packed_inputs, - llvm::ArrayRef replica_output_types) { - BuildReplicateOp(&builder, &state, n, devices, replicated_inputs, - packed_inputs, replica_output_types); -} - -void ReplicateOp::build( - OpBuilder& builder, OperationState& state, int n, - const llvm::SmallDenseMap>& - devices, - llvm::ArrayRef> replicated_inputs, - llvm::ArrayRef packed_inputs, - Operation::result_type_range replica_output_types) { + llvm::ArrayRef> replicated_inputs, + ValueRange packed_inputs, TypeRange replica_output_types) { BuildReplicateOp(&builder, &state, n, devices, replicated_inputs, packed_inputs, replica_output_types); } @@ -670,12 +657,12 @@ void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList& results, results.insert(context); } +} // namespace tf_device +} // namespace mlir + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc.inc" - -} // namespace tf_device -} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.h index 688c8ca5715..5b1d9711875 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.h @@ -41,11 +41,11 @@ class TensorFlowDeviceDialect : public Dialect { explicit TensorFlowDeviceDialect(MLIRContext* context); }; +} // namespace tf_device +} // namespace mlir + // Declares the operations for this dialect using the generated header. #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h.inc" -} // namespace tf_device -} // namespace mlir - #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DEVICE_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td index d94a37d9b02..65de4ea306f 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td @@ -36,7 +36,7 @@ def TfDevice_Dialect : Dialect { XlaRun. 
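[Editor's note, not part of the patch] The change above collapses the two `ReplicateOp::build` overloads into a single builder taking `ValueRange` packed inputs and `TypeRange` result types. A hedged caller sketch, under the assumption that the replicated-input element type is `std::pair<mlir::ValueRange, mlir::Type>` and the devices map is `llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>` (the template arguments are not fully legible in this copy of the patch); the function and value names are hypothetical.

```c++
#include <utility>

#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Builders.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"

// Creates a 2-replica tf_device.replicate with one replicated input, no packed
// inputs, and no explicit device assignment (sketch only).
mlir::tf_device::ReplicateOp CreateReplicate(mlir::OpBuilder& builder,
                                             mlir::Location loc,
                                             mlir::Value replica0,
                                             mlir::Value replica1) {
  llvm::SmallVector<mlir::Value, 2> replica_values{replica0, replica1};
  std::pair<mlir::ValueRange, mlir::Type> replicated_input(
      replica_values, replica0.getType());
  llvm::SmallVector<mlir::Type, 1> output_types{replica0.getType()};
  llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<llvm::StringRef, 4>>
      devices;  // empty: no per-replica device list
  return builder.create<mlir::tf_device::ReplicateOp>(
      loc, /*n=*/2, devices, llvm::makeArrayRef(replicated_input),
      /*packed_inputs=*/mlir::ValueRange(), output_types);
}
```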
}]; - let cppNamespace = "tf_device"; + let cppNamespace = "::mlir::tf_device"; } //===----------------------------------------------------------------------===// @@ -295,14 +295,8 @@ For example: let builders = [ OpBuilder<"OpBuilder& builder, OperationState& state, int n, " "const llvm::SmallDenseMap>& devices, " - "llvm::ArrayRef, Type>> replicated_inputs, " - "llvm::ArrayRef packed_inputs, " - "llvm::ArrayRef replica_output_types">, - OpBuilder<"OpBuilder& builder, OperationState& state, int n, " - "const llvm::SmallDenseMap>& devices, " - "llvm::ArrayRef> replicated_inputs, " - "llvm::ArrayRef packed_inputs, " - "Operation::result_type_range replica_output_types"> + "llvm::ArrayRef> replicated_inputs, " + "ValueRange packed_inputs, TypeRange replica_output_types">, ]; let parser = [{ return Parse$cppClass(&parser, &result); }]; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc index ea9ae5d9477..f2d0a548420 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc @@ -250,33 +250,6 @@ ParseResult ParseGraphOp(OpAsmParser &parser, OperationState &result) { // tf_executor.fetch //===----------------------------------------------------------------------===// -namespace { - -void Print(FetchOp fetch, OpAsmPrinter &p) { - p << fetch.getOperationName(); - if (fetch.getNumOperands() > 0) { - p << ' '; - p.printOperands(fetch.operand_begin(), fetch.operand_end()); - p << " : "; - interleaveComma(fetch.getOperandTypes(), p); - } - p.printOptionalAttrDict(fetch.getAttrs()); -} - -ParseResult ParseFetchOp(OpAsmParser &parser, OperationState &result) { - SmallVector opInfo; - SmallVector types; - llvm::SMLoc loc = parser.getCurrentLocation(); - return failure(parser.parseOperandList(opInfo) || - (!opInfo.empty() && parser.parseColonTypeList(types)) || - parser.resolveOperands(opInfo, types, loc, result.operands) || - parser.parseOptionalAttrDict(result.attributes) - - ); -} - -} // anonymous namespace - //===----------------------------------------------------------------------===// // tf_executor.island //===----------------------------------------------------------------------===// @@ -411,31 +384,6 @@ ParseResult ParseIslandOp(OpAsmParser &parser, OperationState &result) { // tf_executor.yield //===----------------------------------------------------------------------===// -namespace { - -void Print(YieldOp yield, OpAsmPrinter &p) { - p << yield.getOperationName(); - if (yield.getNumOperands() > 0) { - p << ' '; - p.printOperands(yield.operand_begin(), yield.operand_end()); - p << " : "; - interleaveComma(yield.getOperandTypes(), p); - } - p.printOptionalAttrDict(yield.getAttrs()); -} - -ParseResult ParseYieldOp(OpAsmParser &parser, OperationState &result) { - SmallVector op_info; - SmallVector types; - llvm::SMLoc loc = parser.getCurrentLocation(); - return failure(parser.parseOperandList(op_info) || - (!op_info.empty() && parser.parseColonTypeList(types)) || - parser.resolveOperands(op_info, types, loc, result.operands) || - parser.parseOptionalAttrDict(result.attributes)); -} - -} // anonymous namespace - //===----------------------------------------------------------------------===// // tf_executor.Switch //===----------------------------------------------------------------------===// @@ -848,23 +796,6 @@ LogicalResult Verify(NextIterationSourceOp source) { return success(); } -void Print(NextIterationSourceOp next_iteration, OpAsmPrinter &p) { - p 
<< next_iteration.getOperationName() << " : " << next_iteration.getType(0); - p.printOptionalAttrDict(next_iteration.getAttrs()); -} - -ParseResult ParseNextIterationSourceOp(OpAsmParser &parser, - OperationState &result) { - SmallVector types; - if (parser.parseColonTypeList(types)) return failure(); - - MLIRContext *context = parser.getBuilder().getContext(); - Type token_type = TokenType::get(context); - Type control_type = ControlType::get(context); - result.addTypes({types.front(), token_type, control_type}); - return parser.parseOptionalAttrDict(result.attributes); -} - } // anonymous namespace //===----------------------------------------------------------------------===// @@ -891,36 +822,6 @@ LogicalResult Verify(NextIterationSinkOp sink) { return success(); } -void Print(NextIterationSinkOp next_iteration, OpAsmPrinter &p) { - p << next_iteration.getOperationName() << " ["; - p.printOperand(next_iteration.getOperand(0)); - p << "] "; - p.printOperands(llvm::drop_begin(next_iteration.getOperands(), 1)); - p << " : " << next_iteration.getOperand(1).getType(); - p.printOptionalAttrDict(next_iteration.getAttrs()); -} - -ParseResult ParseNextIterationSinkOp(OpAsmParser &parser, - OperationState &result) { - SmallVector op_infos; - llvm::SMLoc loc = parser.getCurrentLocation(); - - // First type is always the token consumed from the NextIteration.source - Type token_type = TokenType::get(parser.getBuilder().getContext()); - SmallVector types = {token_type}; - - if (parser.parseOperandList(op_infos, 1, OpAsmParser::Delimiter::Square) || - parser.parseOperandList(op_infos) || parser.parseColonTypeList(types)) - return failure(); - - Type control_type = ControlType::get(parser.getBuilder().getContext()); - types.append(op_infos.size() - 2, control_type); - if (parser.resolveOperands(op_infos, types, loc, result.operands)) - return failure(); - - return parser.parseOptionalAttrDict(result.attributes); -} - } // anonymous namespace //===----------------------------------------------------------------------===// @@ -959,32 +860,6 @@ ParseResult ParseExitOp(OpAsmParser &parser, OperationState &result) { // tf_executor.ControlTrigger //===----------------------------------------------------------------------===// -namespace { - -void Print(ControlTriggerOp trigger, OpAsmPrinter &p) { - p << trigger.getOperationName() << ' '; - p.printOperands(trigger.getOperands()); - p.printOptionalAttrDict(trigger.getAttrs()); -} - -ParseResult ParseControlTriggerOp(OpAsmParser &parser, OperationState &result) { - SmallVector op_infos; - SmallVector types; - llvm::SMLoc loc = parser.getCurrentLocation(); - - if (parser.parseOperandList(op_infos)) return failure(); - Type control_type = ControlType::get(parser.getBuilder().getContext()); - types.append(op_infos.size(), control_type); - if (parser.resolveOperands(op_infos, types, loc, result.operands)) - return failure(); - - // Single control as the only output - result.types.push_back(control_type); - return parser.parseOptionalAttrDict(result.attributes); -} - -} // anonymous namespace - //===----------------------------------------------------------------------===// // tf_executor.LoopCond //===----------------------------------------------------------------------===// @@ -1246,12 +1121,12 @@ LogicalResult IslandOp::fold(llvm::ArrayRef operands, return success(); } +} // namespace tf_executor +} // namespace mlir + //===----------------------------------------------------------------------===// // TableGen'd op method definitions 
//===----------------------------------------------------------------------===// #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc.inc" - -} // namespace tf_executor -} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h index 60036ddc9f8..2bc13556b4b 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h @@ -57,11 +57,11 @@ class TokenType : public Type::TypeBase { using Base::Base; }; +} // namespace tf_executor +} // namespace mlir + // Declares the operations for this dialect using the generated header. #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h.inc" -} // namespace tf_executor -} // namespace mlir - #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_EXECUTOR_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td index 3081018b8da..713ddc44cba 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td @@ -43,14 +43,16 @@ def TfExecutor_Dialect : Dialect { value). }]; - let cppNamespace = "tf_executor"; + let cppNamespace = "::mlir::tf_executor"; } // Control type. -def TfeControlType : Type()">, "control">; +def TfeControlType : Type()">, "control">, + BuildableType<"$_builder.getType()">; // Token type. -def TfeTokenType : Type()">, "token">; +def TfeTokenType : Type()">, "token">, + BuildableType<"$_builder.getType()">; // TODO(hinsu): Define and use TensorType instead of AnyType for data operands // and results. For example, MergeOp output type. @@ -148,7 +150,11 @@ def TfExecutor_FetchOp : TfExecutor_Op<"fetch", }]> ]; + let assemblyFormat = "($fetches^ `:` type($fetches))? attr-dict"; + let verifier = ?; + let printer = ?; + let parser = ?; } def TfExecutor_IslandOp : TfExecutor_Op<"island", @@ -229,7 +235,11 @@ def TfExecutor_YieldOp : TfExecutor_Op<"yield", }]> ]; + let assemblyFormat = "($fetches^ `:` type($fetches))? attr-dict"; + let verifier = ?; + let printer = ?; + let parser = ?; } def TfExecutor_SwitchOp : TfExecutor_Op<"Switch", @@ -466,6 +476,10 @@ def TfExecutor_NextIterationSourceOp : TfExecutor_Op<"NextIteration.Source", } }]; + let assemblyFormat = "`:` type($output) attr-dict"; + + let printer = ?; + let parser = ?; } @@ -527,6 +541,11 @@ def TfExecutor_NextIterationSinkOp : TfExecutor_Op<"NextIteration.Sink", result.attributes.append(attributes.begin(), attributes.end()); }]> ]; + + let assemblyFormat = " `[` $token `]` $input (`,` $controlInputs^)? `:` type($input) attr-dict"; + + let printer = ?; + let parser = ?; } def TfExecutor_ExitOp : TfExecutor_Op<"Exit", @@ -552,7 +571,7 @@ def TfExecutor_ExitOp : TfExecutor_Op<"Exit", .Attr("T: type") For example: - %1:2 = tf_executor.Exit %0#0 {T: "tfdtype$DT_INT32"} : tensor<*xi32> + %1:2 = tf_executor.Exit %0#0 : tensor<*xi32> {T: "tfdtype$DT_INT32"} Note: Additional result corresponds to the control output. 
}]; @@ -607,6 +626,11 @@ def TfExecutor_ControlTriggerOp : TfExecutor_Op<"ControlTrigger", result.attributes.append(attributes.begin(), attributes.end()); }]> ]; + + let assemblyFormat = "$controlInputs attr-dict"; + + let printer = ?; + let parser = ?; } def TfExecutor_LoopCondOp : TfExecutor_Op<"LoopCond", diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 283e3326029..aa1b7bb81a9 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -16,11 +16,12 @@ limitations under the License. // This is the operation definition file for TensorFlow. // // This file contains TensorFlow ops whose definitions are programmatically -// generated from the TensorFlow codebase. The generated fields for an op -// includes name, summary, description, traits, arguments, results, derived -// attributes. Therefore, modifications to these fields will **not** be -// respected upon subsequent refreshes. However, additional fields after those -// fields will be retained. +// generated from the api-def-files in the following folder: +// tensorflow/core/api_def/base_api +// The generated fields for an op include name, summary, description, traits, +// arguments, results, derived attributes. Therefore, modifications to these +// fields will **not** be respected upon subsequent refreshes. However, +// additional fields after those fields will be retained. // // If you absolutely need to modify the generated fields of an op, move the // definition to `tf_ops.td` and perform the modification there. @@ -28,6 +29,7 @@ limitations under the License. // Ops in this file are sorted alphabetically. include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td" +include "mlir/Interfaces/InferTypeOpInterface.td" def TF_AbsOp : TF_Op<"Abs", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes the absolute value of a tensor."; @@ -39,11 +41,11 @@ an output element, this operation computes \\(y = |x|\\). }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8]>:$x + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8]>:$x ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8]>:$y + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8]>:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -52,12 +54,18 @@ an output element, this operation computes \\(y = |x|\\). def TF_AcosOp : TF_Op<"Acos", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes acos of x element-wise."; + let description = [{ +Provided an input tensor, the `tf.math.acos` operation returns the inverse cosine of each element of the tensor. If `y = tf.math.cos(x)` then, `x = tf.math.acos(y)`. + + Input range is `[-1, 1]` and the output has a range of `[0, pi]`. 
+ }]; + let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$x + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8]>:$x ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$y + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8]>:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -87,29 +95,6 @@ tf.math.acosh(x) ==> [nan nan 0. 0.62236255 5.9914584 9.903487 inf] TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_AddOp : TF_Op<"Add", [NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic, TF_SameOperandsAndResultElementTypeResolveRef]>, - WithBroadcastableBinOpBuilder { - let summary = "Returns x + y element-wise."; - - let description = [{ -*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting -[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) - }]; - - let arguments = (ins - TF_NumberOrStrTensor:$x, - TF_NumberOrStrTensor:$y - ); - - let results = (outs - TF_NumberOrStrTensor:$z - ); - - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; - - let hasCanonicalizer = 1; -} - def TF_AddNOp : TF_Op<"AddN", [Commutative, NoSideEffect]> { let summary = "Add all input tensors element wise."; @@ -123,11 +108,11 @@ Inputs must be of same size and shape. }]; let arguments = (ins - Variadic>:$inputs + Variadic>:$inputs ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8, TF_Variant]>:$sum + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8, TF_Variant]>:$sum ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -136,31 +121,6 @@ Inputs must be of same size and shape. let hasFolder = 1; } -def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary, TF_LayoutAgnostic, TF_SameOperandsAndResultElementTypeResolveRef]>, - WithBroadcastableBinOpBuilder { - let summary = "Returns x + y element-wise."; - - let description = [{ -*NOTE*: `Add` supports broadcasting. `AddN` does not. 
More about broadcasting -[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) - }]; - - let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint8]>:$x, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint8]>:$y - ); - - let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint8]>:$z - ); - - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; - - let hasCanonicalizer = 1; - - let hasFolder = 1; -} - def TF_AdjustContrastv2Op : TF_Op<"AdjustContrastv2", [NoSideEffect]> { let summary = "Adjust the contrast of one or more images."; @@ -177,12 +137,12 @@ channel and then adjusts each component of each pixel to }]; let arguments = (ins - TensorOf<[F16, F32]>:$images, - F32Tensor:$contrast_factor + TensorOf<[TF_Float16, TF_Float32]>:$images, + TF_Float32Tensor:$contrast_factor ); let results = (outs - TensorOf<[F16, F32]>:$output + TensorOf<[TF_Float16, TF_Float32]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -201,12 +161,12 @@ and then remapped back to RGB colorspace. }]; let arguments = (ins - TensorOf<[F16, F32]>:$images, - F32Tensor:$delta + TensorOf<[TF_Float16, TF_Float32]>:$images, + TF_Float32Tensor:$delta ); let results = (outs - TensorOf<[F16, F32]>:$output + TensorOf<[TF_Float16, TF_Float32]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -225,12 +185,12 @@ values, and then remapped back to RGB colorspace. }]; let arguments = (ins - TensorOf<[F16, F32]>:$images, - F32Tensor:$scale + TensorOf<[TF_Float16, TF_Float32]>:$images, + TF_Float32Tensor:$scale ); let results = (outs - TensorOf<[F16, F32]>:$output + TensorOf<[TF_Float16, TF_Float32]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -249,14 +209,14 @@ retained with length 1. }]; let arguments = (ins - I1Tensor:$input, + TF_BoolTensor:$input, TF_I32OrI64Tensor:$reduction_indices, DefaultValuedAttr:$keep_dims ); let results = (outs - I1Tensor:$output + TF_BoolTensor:$output ); TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; @@ -264,7 +224,7 @@ retained with length 1. 
let verifier = [{ return Verify(*this); }]; } -def TF_AllToAllOp : TF_Op<"AllToAll", [NoSideEffect]> { +def TF_AllToAllOp : TF_Op<"AllToAll", [NoSideEffect, TF_NoConstantFold]> { let summary = "An Op to exchange data across TPU replicas."; let description = [{ @@ -287,8 +247,8 @@ replica 1's output: `[[B], [D]]` }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, - I32Tensor:$group_assignment, + TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, + TF_Int32Tensor:$group_assignment, I64Attr:$concat_dimension, I64Attr:$split_dimension, @@ -296,7 +256,7 @@ replica 1's output: `[[B], [D]]` ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -337,6 +297,88 @@ Equivalent to np.angle. TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>; } +def TF_AnonymousIteratorOp : TF_Op<"AnonymousIterator", []> { + let summary = "A container for an iterator resource."; + + let arguments = (ins + Confined]>:$output_types, + Confined]>:$output_shapes + ); + + let results = (outs + Res:$handle + ); +} + +def TF_AnonymousIteratorV2Op : TF_Op<"AnonymousIteratorV2", []> { + let summary = "A container for an iterator resource."; + + let arguments = (ins + Confined]>:$output_types, + Confined]>:$output_shapes + ); + + let results = (outs + Res:$handle, + TF_VariantTensor:$deleter + ); +} + +def TF_AnonymousMemoryCacheOp : TF_Op<"AnonymousMemoryCache", []> { + let summary = ""; + + let arguments = (ins); + + let results = (outs + Res:$handle, + TF_VariantTensor:$deleter + ); +} + +def TF_AnonymousMultiDeviceIteratorOp : TF_Op<"AnonymousMultiDeviceIterator", []> { + let summary = "A container for a multi device iterator resource."; + + let arguments = (ins + Confined]>:$devices, + Confined]>:$output_types, + Confined]>:$output_shapes + ); + + let results = (outs + Res:$handle, + TF_VariantTensor:$deleter + ); +} + +def TF_AnonymousRandomSeedGeneratorOp : TF_Op<"AnonymousRandomSeedGenerator", []> { + let summary = ""; + + let arguments = (ins + TF_Int64Tensor:$seed, + TF_Int64Tensor:$seed2 + ); + + let results = (outs + Res:$handle, + TF_VariantTensor:$deleter + ); +} + +def TF_AnonymousSeedGeneratorOp : TF_Op<"AnonymousSeedGenerator", []> { + let summary = ""; + + let arguments = (ins + TF_Int64Tensor:$seed, + TF_Int64Tensor:$seed2, + TF_BoolTensor:$reshuffle + ); + + let results = (outs + Res:$handle, + TF_VariantTensor:$deleter + ); +} + def TF_AnyOp : TF_Op<"Any", [NoSideEffect]> { let summary = [{ Computes the "logical or" of elements across dimensions of a tensor. @@ -350,14 +392,14 @@ retained with length 1. 
}]; let arguments = (ins - I1Tensor:$input, + TF_BoolTensor:$input, TF_I32OrI64Tensor:$reduction_indices, DefaultValuedAttr:$keep_dims ); let results = (outs - I1Tensor:$output + TF_BoolTensor:$output ); TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; @@ -369,14 +411,14 @@ def TF_ApproximateEqualOp : TF_Op<"ApproximateEqual", [Commutative, NoSideEffect let summary = "Returns the truth value of abs(x-y) < tolerance element-wise."; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y, DefaultValuedAttr:$tolerance ); let results = (outs - I1Tensor:$z + TF_BoolTensor:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -402,7 +444,7 @@ Usage: }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, + TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, TF_I32OrI64Tensor:$dimension ); @@ -435,7 +477,7 @@ Usage: }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, + TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, TF_I32OrI64Tensor:$dimension ); @@ -467,7 +509,7 @@ array([b'3.14', b'2.72'], dtype=object) }]; let arguments = (ins - TensorOf<[F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$input, + TensorOf<[TF_Bool, TF_Complex128, TF_Complex64, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8]>:$input, DefaultValuedAttr:$precision, DefaultValuedAttr:$scientific, @@ -505,11 +547,11 @@ tf.math.asin(y) # [1.047, 0.785] = x }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$x + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8]>:$x ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$y + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8]>:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -549,7 +591,7 @@ If `condition` evaluates to false, print the list of tensors in `data`. }]; let arguments = (ins - I1Tensor:$condition, + TF_BoolTensor:$condition, Variadic:$data, DefaultValuedAttr:$summarize @@ -571,7 +613,7 @@ see the incremented value or a subsequent newer one. 
}]; let arguments = (ins - TF_ResourceTensor:$resource, + Arg:$resource, TF_Tensor:$value ); @@ -589,7 +631,7 @@ see the decremented value or a subsequent newer one. }]; let arguments = (ins - TF_ResourceTensor:$resource, + Arg:$resource, TF_Tensor:$value ); @@ -607,7 +649,7 @@ this value or a subsequent newer value of the variable. }]; let arguments = (ins - TF_ResourceTensor:$resource, + Arg:$resource, TF_Tensor:$value ); @@ -638,11 +680,11 @@ tf.math.atan(y) # [1.047, 0.785] = x }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$x + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8]>:$x ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$y + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8]>:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -663,12 +705,12 @@ where \(r = \sqrt(x^2 + y^2) \). }]; let arguments = (ins - TF_FpTensor:$y, - TF_FpTensor:$x + TF_FloatTensor:$y, + TF_FloatTensor:$x ); let results = (outs - TF_FpTensor:$z + TF_FloatTensor:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -710,7 +752,7 @@ window in `value`. }]; let arguments = (ins - TF_FpTensor:$value, + TF_FloatTensor:$value, Confined]>:$ksize, Confined]>:$strides, @@ -719,7 +761,7 @@ window in `value`. ); let results = (outs - TF_FpTensor:$output + TF_FloatTensor:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -734,7 +776,7 @@ Each entry in `output` is the mean of the corresponding size `ksize` window in }]; let arguments = (ins - TF_FpTensor:$input, + TF_FloatTensor:$input, Confined]>:$ksize, Confined]>:$strides, @@ -743,7 +785,7 @@ Each entry in `output` is the mean of the corresponding size `ksize` window in ); let results = (outs - TF_FpTensor:$output + TF_FloatTensor:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -753,8 +795,8 @@ def TF_AvgPool3DGradOp : TF_Op<"AvgPool3DGrad", [NoSideEffect]> { let summary = "Computes gradients of average pooling function."; let arguments = (ins - I32Tensor:$orig_input_shape, - TF_FpTensor:$grad, + TF_Int32Tensor:$orig_input_shape, + TF_FloatTensor:$grad, Confined]>:$ksize, Confined]>:$strides, @@ -763,7 +805,7 @@ def TF_AvgPool3DGradOp : TF_Op<"AvgPool3DGrad", [NoSideEffect]> { ); let results = (outs - TF_FpTensor:$output + TF_FloatTensor:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; @@ -773,8 +815,8 @@ def TF_AvgPoolGradOp : TF_Op<"AvgPoolGrad", [NoSideEffect]> { let summary = "Computes gradients of the average pooling function."; let arguments = (ins - I32Tensor:$orig_input_shape, - TF_FpTensor:$grad, + TF_Int32Tensor:$orig_input_shape, + TF_FloatTensor:$grad, Confined]>:$ksize, Confined]>:$strides, @@ -783,7 +825,7 @@ def TF_AvgPoolGradOp : TF_Op<"AvgPoolGrad", [NoSideEffect]> { ); let results = (outs - TF_FpTensor:$output + TF_FloatTensor:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; @@ -814,15 +856,15 @@ It is computed as: }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x, - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$x, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, 
TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$y, DefaultValuedAttr:$adj_x, DefaultValuedAttr:$adj_y ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$output + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -859,15 +901,15 @@ about broadcasting }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Complex128, TF_Complex64]>:$x, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Complex128, TF_Complex64]>:$y, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64]>:$x, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64]>:$y, DefaultValuedAttr:$adj_x, DefaultValuedAttr:$adj_y ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Complex128, TF_Complex64]>:$output + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -887,18 +929,18 @@ This op is deprecated. Prefer `tf.nn.batch_normalization`. }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$t, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$m, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$v, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$gamma, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$t, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$m, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$v, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$gamma, F32Attr:$variance_epsilon, BoolAttr:$scale_after_normalization ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$result + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, 
TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$result ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -999,7 +1041,7 @@ beta function. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_BiasAddOp : TF_Op<"BiasAdd", [NoSideEffect]> { +def TF_BiasAddOp : TF_Op<"BiasAdd", [NoSideEffect, TF_ContractionFusableInterface]> { let summary = "Adds `bias` to `value`."; let description = [{ @@ -1008,18 +1050,23 @@ Broadcasting is supported, so `value` may have any number of dimensions. }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$value, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$bias, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$value, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$bias, DefaultValuedAttr:$data_format ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + let extraClassDeclaration = [{ + // TF_ContractionFusableInterface: + Optional GetContractionFusion(); + }]; + let verifier = [{ return Verify(*this); }]; @@ -1037,13 +1084,13 @@ the feature dimension is the third-to-last. }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$out_backprop, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$out_backprop, DefaultValuedAttr:$data_format ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -1064,12 +1111,12 @@ Broadcasting is supported, so `value` may have any number of dimensions. 
}]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$value, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$bias + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$value, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$bias ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -1373,13 +1420,13 @@ then the output will be }]; let arguments = (ins - TensorOf<[F32, F64, I32, I64]>:$input, + TensorOf<[TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$input, F32ArrayAttr:$boundaries ); let results = (outs - I32Tensor:$output + TF_Int32Tensor:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -1408,11 +1455,11 @@ def TF_CeilOp : TF_Op<"Ceil", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Returns element-wise smallest integer not less than x."; let arguments = (ins - TF_FpTensor:$x + TF_FloatTensor:$x ); let results = (outs - TF_FpTensor:$y + TF_FloatTensor:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -1427,13 +1474,13 @@ that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is. }]; let arguments = (ins - TF_FpTensor:$tensor, + TF_FloatTensor:$tensor, StrAttr:$message ); let results = (outs - TF_FpTensor:$output + TF_FloatTensor:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -1461,11 +1508,11 @@ case it might be faster to use the CPU. }]; let arguments = (ins - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$input + TensorOf<[TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64]>:$input ); let results = (outs - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$output + TensorOf<[TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -1482,13 +1529,13 @@ greater than `clip_value_max` are set to `clip_value_max`. 
}]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$t, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$clip_value_min, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$clip_value_max + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$t, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$clip_value_min, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$clip_value_max ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -1507,7 +1554,7 @@ def TF_CollectiveBcastRecvOp : TF_Op<"CollectiveBcastRecv", []> { ); let results = (outs - TensorOf<[F16, F32, F64, I1, I32, I64]>:$data + TensorOf<[TF_Bool, TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$data ); TF_DerivedResultTypeAttr T = TF_DerivedResultTypeAttr<0>; @@ -1517,7 +1564,7 @@ def TF_CollectiveBcastSendOp : TF_Op<"CollectiveBcastSend", []> { let summary = "Broadcasts a tensor value to one or more other devices."; let arguments = (ins - TensorOf<[F16, F32, F64, I1, I32, I64]>:$input, + TensorOf<[TF_Bool, TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$input, I64Attr:$group_size, I64Attr:$group_key, @@ -1528,7 +1575,7 @@ def TF_CollectiveBcastSendOp : TF_Op<"CollectiveBcastSend", []> { ); let results = (outs - TensorOf<[F16, F32, F64, I1, I32, I64]>:$data + TensorOf<[TF_Bool, TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$data ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -1540,7 +1587,7 @@ Mutually accumulates multiple tensors of identical type and shape. }]; let arguments = (ins - TensorOf<[F16, F32, F64, I32, I64]>:$input, + TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$input, I64Attr:$group_size, I64Attr:$group_key, @@ -1551,7 +1598,7 @@ Mutually accumulates multiple tensors of identical type and shape. ); let results = (outs - TensorOf<[F16, F32, F64, I32, I64]>:$data + TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$data ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -1563,7 +1610,7 @@ Mutually reduces multiple tensors of identical type and shape. }]; let arguments = (ins - TensorOf<[F16, F32, F64, I32, I64]>:$input, + TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$input, I64Attr:$group_size, I64Attr:$group_key, @@ -1577,7 +1624,7 @@ Mutually reduces multiple tensors of identical type and shape. 
); let results = (outs - TensorOf<[F16, F32, F64, I32, I64]>:$data + TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$data ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -1589,10 +1636,10 @@ Mutually reduces multiple tensors of identical type and shape. }]; let arguments = (ins - TensorOf<[F16, F32, F64, I32, I64]>:$input, - I32Tensor:$group_size, - I32Tensor:$group_key, - I32Tensor:$instance_key, + TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$input, + TF_Int32Tensor:$group_size, + TF_Int32Tensor:$group_key, + TF_Int32Tensor:$instance_key, TF_AnyStrAttrOf<["Min", "Max", "Mul", "Add"]>:$merge_op, TF_AnyStrAttrOf<["Id", "Div"]>:$final_op, @@ -1600,7 +1647,7 @@ Mutually reduces multiple tensors of identical type and shape. ); let results = (outs - TensorOf<[F16, F32, F64, I32, I64]>:$data + TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$data ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -1665,7 +1712,7 @@ def TF_ConcatOp : TF_Op<"Concat", [NoSideEffect]> { let summary = "Concatenates tensors along one dimension."; let arguments = (ins - I32Tensor:$concat_dim, + TF_Int32Tensor:$concat_dim, Variadic:$values ); @@ -1700,12 +1747,12 @@ This is typically used by gradient computations for a concat operation. }]; let arguments = (ins - I32Tensor:$concat_dim, - Variadic:$shape + TF_Int32Tensor:$concat_dim, + Variadic:$shape ); let results = (outs - Variadic:$offset + Variadic:$offset ); TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<1>; @@ -1740,6 +1787,34 @@ def TF_ConcatV2Op : TF_Op<"ConcatV2", [NoSideEffect]> { let hasCanonicalizer = 1; } +def TF_ConfigureDistributedTPUOp : TF_Op<"ConfigureDistributedTPU", []> { + let summary = [{ +Sets up the centralized structures for a distributed TPU system. + }]; + + let arguments = (ins + StrAttr:$embedding_config, + StrAttr:$tpu_embedding_config, + DefaultValuedAttr:$is_global_init, + DefaultValuedAttr:$enable_whole_mesh_compilations, + DefaultValuedAttr:$compilation_failure_closes_chips + ); + + let results = (outs + TF_StrTensor:$topology + ); +} + +def TF_ConfigureTPUEmbeddingOp : TF_Op<"ConfigureTPUEmbedding", []> { + let summary = "Sets up TPUEmbedding in a distributed TPU system."; + + let arguments = (ins + StrAttr:$config + ); + + let results = (outs); +} + def TF_ConjOp : TF_Op<"Conj", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Returns the complex conjugate of a complex number."; @@ -1826,8 +1901,8 @@ horizontal and vertices strides, `strides = [1, stride, stride, 1]`. }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I32]>:$input, - TensorOf<[BF16, F16, F32, F64, I32]>:$filter, + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int32]>:$input, + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int32]>:$filter, I64ArrayAttr:$strides, DefaultValuedAttr:$use_cudnn_on_gpu, @@ -1838,7 +1913,7 @@ horizontal and vertices strides, `strides = [1, stride, stride, 1]`. ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I32]>:$output + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int32]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -1862,9 +1937,9 @@ Computes the gradients of convolution with respect to the filter. 
}]; let arguments = (ins - TF_FpTensor:$input, - I32Tensor:$filter_sizes, - TF_FpTensor:$out_backprop, + TF_FloatTensor:$input, + TF_Int32Tensor:$filter_sizes, + TF_FloatTensor:$out_backprop, I64ArrayAttr:$strides, DefaultValuedAttr:$use_cudnn_on_gpu, @@ -1875,7 +1950,7 @@ Computes the gradients of convolution with respect to the filter. ); let results = (outs - TF_FpTensor:$output + TF_FloatTensor:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -1895,9 +1970,9 @@ Computes the gradients of convolution with respect to the input. }]; let arguments = (ins - I32Tensor:$input_sizes, - TensorOf<[BF16, F16, F32, F64, I32]>:$filter, - TensorOf<[BF16, F16, F32, F64, I32]>:$out_backprop, + TF_Int32Tensor:$input_sizes, + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int32]>:$filter, + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int32]>:$out_backprop, I64ArrayAttr:$strides, DefaultValuedAttr:$use_cudnn_on_gpu, @@ -1908,7 +1983,7 @@ Computes the gradients of convolution with respect to the input. ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I32]>:$output + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int32]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; @@ -1940,8 +2015,8 @@ Our Conv3D implements a form of cross-correlation. }]; let arguments = (ins - TF_FpTensor:$input, - TF_FpTensor:$filter, + TF_FloatTensor:$input, + TF_FloatTensor:$filter, Confined]>:$strides, TF_AnyStrAttrOf<["SAME", "VALID"]>:$padding, @@ -1950,7 +2025,7 @@ Our Conv3D implements a form of cross-correlation. ); let results = (outs - TF_FpTensor:$output + TF_FloatTensor:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -1966,9 +2041,9 @@ Computes the gradients of 3-D convolution with respect to the filter. }]; let arguments = (ins - TF_FpTensor:$input, - I32Tensor:$filter_sizes, - TF_FpTensor:$out_backprop, + TF_FloatTensor:$input, + TF_Int32Tensor:$filter_sizes, + TF_FloatTensor:$out_backprop, Confined]>:$strides, TF_AnyStrAttrOf<["SAME", "VALID"]>:$padding, @@ -1977,7 +2052,7 @@ Computes the gradients of 3-D convolution with respect to the filter. ); let results = (outs - TF_FpTensor:$output + TF_FloatTensor:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -1990,8 +2065,8 @@ Computes the gradients of 3-D convolution with respect to the input. let arguments = (ins TF_I32OrI64Tensor:$input_sizes, - TF_FpTensor:$filter, - TF_FpTensor:$out_backprop, + TF_FloatTensor:$filter, + TF_FloatTensor:$out_backprop, Confined]>:$strides, TF_AnyStrAttrOf<["SAME", "VALID"]>:$padding, @@ -2000,7 +2075,7 @@ Computes the gradients of 3-D convolution with respect to the input. ); let results = (outs - TF_FpTensor:$output + TF_FloatTensor:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; @@ -2079,7 +2154,7 @@ of corresponding 3-element vectors is cross-multiplied independently. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_CrossReplicaSumOp : TF_Op<"CrossReplicaSum", [NoSideEffect, TF_AllTypesMatch<["input", "output"]>]> { +def TF_CrossReplicaSumOp : TF_Op<"CrossReplicaSum", [NoSideEffect, TF_AllTypesMatch<["input", "output"]>, TF_NoConstantFold]> { let summary = "An Op to sum inputs across replicated TPU instances."; let description = [{ @@ -2092,12 +2167,12 @@ and `B, D, F, H` as group 1. 
Thus we get the outputs: }]; let arguments = (ins - TensorOf<[BF16, F16, F32, I32, TF_Uint32]>:$input, - I32Tensor:$group_assignment + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Int32, TF_Uint32]>:$input, + TF_Int32Tensor:$group_assignment ); let results = (outs - TensorOf<[BF16, F16, F32, I32, TF_Uint32]>:$output + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Int32, TF_Uint32]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -2140,7 +2215,7 @@ tf.cumprod([a, b, c], exclusive=True, reverse=True) # => [b * c, c, 1] }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x, TF_I32OrI64Tensor:$axis, DefaultValuedAttr:$exclusive, @@ -2148,7 +2223,7 @@ tf.cumprod([a, b, c], exclusive=True, reverse=True) # => [b * c, c, 1] ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$out + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$out ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -2194,7 +2269,7 @@ tf.cumsum([a, b, c], exclusive=True, reverse=True) # => [b + c, c, 0] }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x, TF_I32OrI64Tensor:$axis, DefaultValuedAttr:$exclusive, @@ -2202,7 +2277,7 @@ tf.cumsum([a, b, c], exclusive=True, reverse=True) # => [b + c, c, 0] ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$out + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$out ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -2339,7 +2414,7 @@ decoding partial jpeg image. let arguments = (ins TF_StrTensor:$contents, - I32Tensor:$crop_window, + TF_Int32Tensor:$crop_window, DefaultValuedAttr:$channels, DefaultValuedAttr:$ratio, @@ -2452,6 +2527,64 @@ is the same, though it is cleaner to use `tf.io.decode_image`. 
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; } +def TF_DeleteIteratorOp : TF_Op<"DeleteIterator", []> { + let summary = "A container for an iterator resource."; + + let arguments = (ins + Arg:$handle, + TF_VariantTensor:$deleter + ); + + let results = (outs); +} + +def TF_DeleteMemoryCacheOp : TF_Op<"DeleteMemoryCache", []> { + let summary = ""; + + let arguments = (ins + Arg:$handle, + TF_VariantTensor:$deleter + ); + + let results = (outs); +} + +def TF_DeleteMultiDeviceIteratorOp : TF_Op<"DeleteMultiDeviceIterator", []> { + let summary = "A container for an iterator resource."; + + let arguments = (ins + Arg:$multi_device_iterator, + Arg, "", [TF_DatasetIteratorRead]>:$iterators, + TF_VariantTensor:$deleter + ); + + let results = (outs); + + TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<1>; +} + +def TF_DeleteRandomSeedGeneratorOp : TF_Op<"DeleteRandomSeedGenerator", []> { + let summary = ""; + + let arguments = (ins + Arg:$handle, + TF_VariantTensor:$deleter + ); + + let results = (outs); +} + +def TF_DeleteSeedGeneratorOp : TF_Op<"DeleteSeedGenerator", []> { + let summary = ""; + + let arguments = (ins + Arg:$handle, + TF_VariantTensor:$deleter + ); + + let results = (outs); +} + def TF_DepthToSpaceOp : TF_Op<"DepthToSpace", [NoSideEffect]> { let summary = "DepthToSpace for tensors of type T."; @@ -2588,8 +2721,8 @@ horizontal and vertices strides, `strides = [1, stride, stride, 1]`. }]; let arguments = (ins - TF_FpTensor:$input, - TF_FpTensor:$filter, + TF_FloatTensor:$input, + TF_FloatTensor:$filter, I64ArrayAttr:$strides, TF_AnyStrAttrOf<["SAME", "VALID", "EXPLICIT"]>:$padding, @@ -2599,7 +2732,7 @@ horizontal and vertices strides, `strides = [1, stride, stride, 1]`. ); let results = (outs - TF_FpTensor:$output + TF_FloatTensor:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -2611,9 +2744,9 @@ Computes the gradients of depthwise convolution with respect to the filter. }]; let arguments = (ins - TF_FpTensor:$input, - I32Tensor:$filter_sizes, - TF_FpTensor:$out_backprop, + TF_FloatTensor:$input, + TF_Int32Tensor:$filter_sizes, + TF_FloatTensor:$out_backprop, I64ArrayAttr:$strides, TF_AnyStrAttrOf<["SAME", "VALID", "EXPLICIT"]>:$padding, @@ -2623,7 +2756,7 @@ Computes the gradients of depthwise convolution with respect to the filter. ); let results = (outs - TF_FpTensor:$output + TF_FloatTensor:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -2635,9 +2768,9 @@ Computes the gradients of depthwise convolution with respect to the input. }]; let arguments = (ins - I32Tensor:$input_sizes, - TF_FpTensor:$filter, - TF_FpTensor:$out_backprop, + TF_Int32Tensor:$input_sizes, + TF_FloatTensor:$filter, + TF_FloatTensor:$out_backprop, I64ArrayAttr:$strides, TF_AnyStrAttrOf<["SAME", "VALID", "EXPLICIT"]>:$padding, @@ -2647,12 +2780,42 @@ Computes the gradients of depthwise convolution with respect to the input. ); let results = (outs - TF_FpTensor:$output + TF_FloatTensor:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; } +def TF_DeserializeIteratorOp : TF_Op<"DeserializeIterator", []> { + let summary = [{ +Converts the given variant tensor to an iterator and stores it in the given resource. 
+ }]; + + let arguments = (ins + Arg:$resource_handle, + TF_VariantTensor:$serialized + ); + + let results = (outs); +} + +def TF_DestroyResourceOp : TF_Op<"DestroyResourceOp", []> { + let summary = "Deletes the resource specified by the handle."; + + let description = [{ +All subsequent operations using the resource will result in a NotFound +error status. + }]; + + let arguments = (ins + TF_ResourceTensor:$resource, + + DefaultValuedAttr:$ignore_lookup_error + ); + + let results = (outs); +} + def TF_DeviceIndexOp : TF_Op<"DeviceIndex", [NoSideEffect]> { let summary = "Return the index of device the op runs."; @@ -2668,7 +2831,7 @@ this op runs. The length of the list is returned in two cases: ); let results = (outs - I32Tensor:$index + TF_Int32Tensor:$index ); } @@ -2696,11 +2859,11 @@ tf.diag(diagonal) ==> [[1, 0, 0, 0] }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$diagonal + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$diagonal ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$output + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -2731,11 +2894,11 @@ tf.diag_part(input) ==> [1, 2, 3, 4] }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$input + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$input ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$diagonal + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$diagonal ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -2751,11 +2914,11 @@ Computes Psi, the derivative of Lgamma (the log of the absolute value of }]; let arguments = (ins - TF_FpTensor:$x + TF_FloatTensor:$x ); let results = (outs - TF_FpTensor:$y + TF_FloatTensor:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -2771,12 +2934,12 @@ def TF_DivOp : TF_Op<"Div", [NoSideEffect, ResultsBroadcastableShape, TF_SameOpe }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$x, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$y + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Uint16, TF_Uint8]>:$x, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Uint16, TF_Uint8]>:$y ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$z + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Uint16, TF_Uint8]>:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -2786,25 +2949,24 @@ def TF_DivOp : TF_Op<"Div", [NoSideEffect, ResultsBroadcastableShape, TF_SameOpe let hasFolder = 1; } -def TF_DivNoNanOp : TF_Op<"DivNoNan", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>, - WithBroadcastableBinOpBuilder { - let summary = "Returns 0 if the denominator is zero."; 
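As an illustrative aside (assuming TensorFlow 2.x), the Diag/DiagPart pair above round-trips a vector through a diagonal matrix; `tf.linalg.diag` lowers to the newer MatrixDiag family rather than the legacy Diag kernel, but for a rank-1 input the result matches the 4x4 example in the description.

```python
import tensorflow as tf

d = tf.constant([1, 2, 3, 4])
m = tf.linalg.diag(d)          # 4x4 matrix with [1, 2, 3, 4] on the main diagonal
print(m)
print(tf.linalg.diag_part(m))  # recovers [1, 2, 3, 4]
```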
+def TF_DummyMemoryCacheOp : TF_Op<"DummyMemoryCache", []> { + let summary = ""; - let description = [{ -*NOTE*: `DivNoNan` supports broadcasting. More about broadcasting -[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) - }]; - - let arguments = (ins - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$x, - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$y - ); + let arguments = (ins); let results = (outs - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$z + Res:$handle ); +} - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +def TF_DummySeedGeneratorOp : TF_Op<"DummySeedGenerator", []> { + let summary = ""; + + let arguments = (ins); + + let results = (outs + Res:$handle + ); } def TF_DynamicStitchOp : TF_Op<"DynamicStitch", [NoSideEffect, SameVariadicOperandSize]> { @@ -2878,7 +3040,7 @@ as illustrated on the following example: }]; let arguments = (ins - Variadic:$indices, + Variadic:$indices, Variadic:$data ); @@ -3006,11 +3168,11 @@ See [Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs) }]; let arguments = (ins - TF_FpTensor:$features + TF_FloatTensor:$features ); let results = (outs - TF_FpTensor:$activations + TF_FloatTensor:$activations ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -3022,12 +3184,12 @@ Computes gradients for the exponential linear (Elu) operation. }]; let arguments = (ins - TF_FpTensor:$gradients, - TF_FpTensor:$outputs + TF_FloatTensor:$gradients, + TF_FloatTensor:$outputs ); let results = (outs - TF_FpTensor:$backprops + TF_FloatTensor:$backprops ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -3041,7 +3203,7 @@ This operation creates a tensor of `shape` and `dtype`. }]; let arguments = (ins - I32Tensor:$shape, + TF_Int32Tensor:$shape, DefaultValuedAttr:$init ); @@ -3165,21 +3327,20 @@ tf.math.equal(x, y) ==> array([True, True]) }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x, - TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y, + TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x, + TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y, DefaultValuedAttr:$incompatible_shape_error ); let results = (outs - I1Tensor:$z + TF_BoolTensor:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; let builders = [ - OpBuilder<"OpBuilder& builder, OperationState& result, Value x, " - "Value y, BoolAttr incompatible_shape_error"> + OpBuilder<"Value x, Value y, BoolAttr incompatible_shape_error"> ]; let verifier = [{ @@ -3191,11 +3352,11 @@ def TF_ErfOp : TF_Op<"Erf", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes the Gauss error function of `x` element-wise."; let arguments = (ins - TF_FpTensor:$x + TF_FloatTensor:$x ); let results = (outs - TF_FpTensor:$y + TF_FloatTensor:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -3207,11 +3368,11 @@ Computes the 
complementary error function of `x` element-wise. }]; let arguments = (ins - TF_FpTensor:$x + TF_FloatTensor:$x ); let results = (outs - TF_FpTensor:$y + TF_FloatTensor:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -3221,11 +3382,11 @@ def TF_ErfinvOp : TF_Op<"Erfinv", [NoSideEffect]> { let summary = ""; let arguments = (ins - TF_FpTensor:$x + TF_FloatTensor:$x ); let results = (outs - TF_FpTensor:$y + TF_FloatTensor:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -3325,8 +3486,7 @@ size 1. TF_DerivedOperandTypeAttr Tdim = TF_DerivedOperandTypeAttr<1>; let builders = [ - OpBuilder<"OpBuilder& builder, OperationState& result, Value condition, " - "Value dim"> + OpBuilder<"Value condition, Value dim"> ]; } @@ -3366,7 +3526,7 @@ Extract `patches` from `images` and put them in the "depth" output dimension. }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$images, + TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$images, Confined]>:$ksizes, Confined]>:$strides, @@ -3375,7 +3535,7 @@ Extract `patches` from `images` and put them in the "depth" output dimension. ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$patches + TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$patches ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -3486,6 +3646,28 @@ Quantization is called fake since the output is still in floating point. }]; let arguments = (ins + TF_Float32Tensor:$inputs, + + DefaultValuedAttr:$min, + DefaultValuedAttr:$max, + DefaultValuedAttr:$num_bits, + DefaultValuedAttr:$narrow_range + ); + + let results = (outs + TF_Float32Tensor:$outputs + ); + + let verifier = [{ + return Verify(*this); + }]; +} + +def TF_FakeQuantWithMinMaxArgsGradientOp : TF_Op<"FakeQuantWithMinMaxArgsGradient", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Compute gradients for a FakeQuantWithMinMaxArgs operation."; + + let arguments = (ins + F32Tensor:$gradients, F32Tensor:$inputs, DefaultValuedAttr:$min, @@ -3495,12 +3677,8 @@ Quantization is called fake since the output is still in floating point. ); let results = (outs - F32Tensor:$outputs + F32Tensor:$backprops ); - - let verifier = [{ - return Verify(*this); - }]; } def TF_FakeQuantWithMinMaxVarsOp : TF_Op<"FakeQuantWithMinMaxVars", [NoSideEffect]> { @@ -3536,6 +3714,28 @@ values. }]; let arguments = (ins + TF_Float32Tensor:$inputs, + TF_Float32Tensor:$min, + TF_Float32Tensor:$max, + + DefaultValuedAttr:$num_bits, + DefaultValuedAttr:$narrow_range + ); + + let results = (outs + TF_Float32Tensor:$outputs + ); + + let verifier = [{ + return Verify(*this); + }]; +} + +def TF_FakeQuantWithMinMaxVarsGradientOp : TF_Op<"FakeQuantWithMinMaxVarsGradient", [NoSideEffect]> { + let summary = "Compute gradients for a FakeQuantWithMinMaxVars operation."; + + let arguments = (ins + F32Tensor:$gradients, F32Tensor:$inputs, F32Tensor:$min, F32Tensor:$max, @@ -3545,12 +3745,10 @@ values. 
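A short sketch of the FakeQuantWithMinMaxArgs op defined above, assuming a stock TensorFlow 2.x install: inputs are clamped to `[min, max]` and snapped to the `num_bits` quantization grid while staying in floating point.

```python
import tensorflow as tf

x = tf.constant([-6.5, -0.1, 0.0, 0.1, 5.9, 6.5])
y = tf.quantization.fake_quant_with_min_max_args(
    x, min=-6.0, max=6.0, num_bits=8, narrow_range=False)
print(y)  # clamped to [-6, 6] and rounded to the nearest 8-bit step
```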
); let results = (outs - F32Tensor:$outputs + F32Tensor:$backprops_wrt_input, + F32Tensor:$backprop_wrt_min, + F32Tensor:$backprop_wrt_max ); - - let verifier = [{ - return Verify(*this); - }]; } def TF_FakeQuantWithMinMaxVarsPerChannelOp : TF_Op<"FakeQuantWithMinMaxVarsPerChannel", [NoSideEffect]> { @@ -3587,16 +3785,16 @@ values. }]; let arguments = (ins - F32Tensor:$inputs, - F32Tensor:$min, - F32Tensor:$max, + TF_Float32Tensor:$inputs, + TF_Float32Tensor:$min, + TF_Float32Tensor:$max, DefaultValuedAttr:$num_bits, DefaultValuedAttr:$narrow_range ); let results = (outs - F32Tensor:$outputs + TF_Float32Tensor:$outputs ); let verifier = [{ @@ -3647,20 +3845,20 @@ fill([2, 3], 9) ==> [[9, 9, 9] let hasFolder = 1; - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &result, Value dims, Value value" - >]; + let builders = [ + OpBuilder<"Value dims, Value value"> + ]; } def TF_FloorOp : TF_Op<"Floor", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Returns element-wise largest integer not greater than x."; let arguments = (ins - TF_FpTensor:$x + TF_FloatTensor:$x ); let results = (outs - TF_FpTensor:$y + TF_FloatTensor:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -3676,12 +3874,12 @@ def TF_FloorDivOp : TF_Op<"FloorDiv", [NoSideEffect, ResultsBroadcastableShape]> }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$x, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$y + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Uint16, TF_Uint8]>:$x, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Uint16, TF_Uint8]>:$y ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$z + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Uint16, TF_Uint8]>:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -3702,12 +3900,12 @@ with a flooring divide. E.g. `floor(x / y) * y + mod(x, y) = x`. }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Uint64]>:$x, - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Uint64]>:$y + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64, TF_Uint64]>:$x, + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64, TF_Uint64]>:$y ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Uint64]>:$z + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64, TF_Uint64]>:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -3722,11 +3920,11 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors. }]; let arguments = (ins - F32Tensor:$x, - F32Tensor:$scale, - F32Tensor:$offset, - F32Tensor:$mean, - F32Tensor:$variance, + TF_Float32Tensor:$x, + TF_Float32Tensor:$scale, + TF_Float32Tensor:$offset, + TF_Float32Tensor:$mean, + TF_Float32Tensor:$variance, DefaultValuedAttr:$epsilon, DefaultValuedAttr:$exponential_avg_factor, @@ -3735,15 +3933,17 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors. 
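The FloorDiv/FloorMod descriptions above rely on the identity `floor(x / y) * y + mod(x, y) = x`; a small check under a stock TensorFlow 2.x runtime:

```python
import tensorflow as tf

x = tf.constant([7, -7, 7, -7])
y = tf.constant([3, 3, -3, -3])

q = tf.math.floordiv(x, y)  # [ 2, -3, -3,  2]
r = tf.math.floormod(x, y)  # [ 1,  2, -2, -1]
print(q * y + r)            # reconstructs x: [ 7, -7,  7, -7]
```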
); let results = (outs - F32Tensor:$y, - F32Tensor:$batch_mean, - F32Tensor:$batch_variance, - F32Tensor:$reserve_space_1, - F32Tensor:$reserve_space_2 + TF_Float32Tensor:$y, + TF_Float32Tensor:$batch_mean, + TF_Float32Tensor:$batch_variance, + TF_Float32Tensor:$reserve_space_1, + TF_Float32Tensor:$reserve_space_2 ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + let hasCanonicalizer = 1; + let verifier = [{ return Verify(*this); }]; @@ -3758,11 +3958,11 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors. }]; let arguments = (ins - F32Tensor:$y_backprop, - F32Tensor:$x, - F32Tensor:$scale, - F32Tensor:$reserve_space_1, - F32Tensor:$reserve_space_2, + TF_Float32Tensor:$y_backprop, + TF_Float32Tensor:$x, + TF_Float32Tensor:$scale, + TF_Float32Tensor:$reserve_space_1, + TF_Float32Tensor:$reserve_space_2, DefaultValuedAttr:$epsilon, DefaultValuedAttr:$data_format, @@ -3770,11 +3970,11 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors. ); let results = (outs - F32Tensor:$x_backprop, - F32Tensor:$scale_backprop, - F32Tensor:$offset_backprop, - F32Tensor:$reserve_space_3, - F32Tensor:$reserve_space_4 + TF_Float32Tensor:$x_backprop, + TF_Float32Tensor:$scale_backprop, + TF_Float32Tensor:$offset_backprop, + TF_Float32Tensor:$reserve_space_3, + TF_Float32Tensor:$reserve_space_4 ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -3789,11 +3989,11 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors. }]; let arguments = (ins - TensorOf<[BF16, F16, F32]>:$y_backprop, - TensorOf<[BF16, F16, F32]>:$x, - F32Tensor:$scale, - F32Tensor:$reserve_space_1, - F32Tensor:$reserve_space_2, + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32]>:$y_backprop, + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32]>:$x, + TF_Float32Tensor:$scale, + TF_Float32Tensor:$reserve_space_1, + TF_Float32Tensor:$reserve_space_2, DefaultValuedAttr:$epsilon, DefaultValuedAttr:$data_format, @@ -3801,11 +4001,11 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors. ); let results = (outs - TensorOf<[BF16, F16, F32]>:$x_backprop, - F32Tensor:$scale_backprop, - F32Tensor:$offset_backprop, - F32Tensor:$reserve_space_3, - F32Tensor:$reserve_space_4 + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32]>:$x_backprop, + TF_Float32Tensor:$scale_backprop, + TF_Float32Tensor:$offset_backprop, + TF_Float32Tensor:$reserve_space_3, + TF_Float32Tensor:$reserve_space_4 ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -3821,24 +4021,24 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors. 
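For orientation (assuming TensorFlow 2.x), the arithmetic behind the FusedBatchNorm family above can be written with the unfused primitive; every 1-D operand has length C, matching the NHWC channel dimension as the descriptions state.

```python
import tensorflow as tf

x = tf.random.normal([2, 4, 4, 3])  # NHWC, C = 3
scale = tf.ones([3])
offset = tf.zeros([3])
mean, variance = tf.nn.moments(x, axes=[0, 1, 2])

y = tf.nn.batch_normalization(x, mean, variance, offset, scale, variance_epsilon=1e-3)
print(y.shape)  # (2, 4, 4, 3)
```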
}]; let arguments = (ins - TensorOf<[BF16, F16, F32]>:$y_backprop, - TensorOf<[BF16, F16, F32]>:$x, - F32Tensor:$scale, - F32Tensor:$reserve_space_1, - F32Tensor:$reserve_space_2, - F32Tensor:$reserve_space_3, + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32]>:$y_backprop, + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32]>:$x, + TF_Float32Tensor:$scale, + TF_Float32Tensor:$reserve_space_1, + TF_Float32Tensor:$reserve_space_2, + TF_Float32Tensor:$reserve_space_3, DefaultValuedAttr:$epsilon, - DefaultValuedAttr:$data_format, + DefaultValuedAttr, "NHWC">:$data_format, DefaultValuedAttr:$is_training ); let results = (outs - TensorOf<[BF16, F16, F32]>:$x_backprop, - F32Tensor:$scale_backprop, - F32Tensor:$offset_backprop, - F32Tensor:$reserve_space_4, - F32Tensor:$reserve_space_5 + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32]>:$x_backprop, + TF_Float32Tensor:$scale_backprop, + TF_Float32Tensor:$offset_backprop, + TF_Float32Tensor:$reserve_space_4, + TF_Float32Tensor:$reserve_space_5 ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -3853,6 +4053,95 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors. }]; } +def TF_FusedBatchNormV2Op : TF_Op<"FusedBatchNormV2", [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> { + let summary = "Batch normalization."; + + let description = [{ +Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". +The size of 1D Tensors matches the dimension C of the 4D Tensors. + }]; + + let arguments = (ins + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32]>:$x, + TF_Float32Tensor:$scale, + TF_Float32Tensor:$offset, + TF_Float32Tensor:$mean, + TF_Float32Tensor:$variance, + + DefaultValuedAttr:$epsilon, + DefaultValuedAttr:$exponential_avg_factor, + DefaultValuedAttr:$data_format, + DefaultValuedAttr:$is_training + ); + + let results = (outs + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32]>:$y, + TF_Float32Tensor:$batch_mean, + TF_Float32Tensor:$batch_variance, + TF_Float32Tensor:$reserve_space_1, + TF_Float32Tensor:$reserve_space_2 + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>; + + let extraClassDeclaration = [{ + // TF_FoldOperandsTransposeInterface: + SmallVector GetLayoutDependentArgs() { return {0}; } + SmallVector GetLayoutDependentResults() { return {0}; } + LogicalResult FoldOperandsPermutation(ArrayRef permutation); + + // TF_LayoutSensitiveInterface: + StringRef GetOptimalLayout(const RuntimeDevices& devices); + LogicalResult UpdateDataFormat(StringRef data_format); + }]; +} + +def TF_FusedBatchNormV3Op : TF_Op<"FusedBatchNormV3", [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> { + let summary = "Batch normalization."; + + let description = [{ +Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". +The size of 1D Tensors matches the dimension C of the 4D Tensors. 
+ }]; + + let arguments = (ins + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32]>:$x, + TF_Float32Tensor:$scale, + TF_Float32Tensor:$offset, + TF_Float32Tensor:$mean, + TF_Float32Tensor:$variance, + + DefaultValuedAttr:$epsilon, + DefaultValuedAttr:$exponential_avg_factor, + DefaultValuedAttr, "NHWC">:$data_format, + DefaultValuedAttr:$is_training + ); + + let results = (outs + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32]>:$y, + TF_Float32Tensor:$batch_mean, + TF_Float32Tensor:$batch_variance, + TF_Float32Tensor:$reserve_space_1, + TF_Float32Tensor:$reserve_space_2, + TF_Float32Tensor:$reserve_space_3 + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>; + + let extraClassDeclaration = [{ + // TF_FoldOperandsTransposeInterface: + SmallVector GetLayoutDependentArgs() { return {0}; } + SmallVector GetLayoutDependentResults() { return {0}; } + LogicalResult FoldOperandsPermutation(ArrayRef permutation); + + // TF_LayoutSensitiveInterface: + StringRef GetOptimalLayout(const RuntimeDevices& devices); + LogicalResult UpdateDataFormat(StringRef data_format); + }]; +} + def TF_GatherOp : TF_Op<"Gather", [NoSideEffect]> { let summary = "Gather slices from `params` according to `indices`."; @@ -4107,7 +4396,7 @@ tf.math.greater(x, y) ==> [False, False, True] ); let results = (outs - I1Tensor:$z + TF_BoolTensor:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -4140,7 +4429,7 @@ tf.math.greater_equal(x, y) ==> [True, False, True, True] ); let results = (outs - I1Tensor:$z + TF_BoolTensor:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -4158,11 +4447,11 @@ See `rgb_to_hsv` for a description of the HSV encoding. }]; let arguments = (ins - TF_FpTensor:$images + TF_FloatTensor:$images ); let results = (outs - TF_FpTensor:$output + TF_FloatTensor:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -4186,7 +4475,7 @@ table will be immutable. ); let results = (outs - TF_ResourceTensor:$table_handle + Res:$table_handle ); } @@ -4268,7 +4557,7 @@ larger, the dimension is padded with zeros. let arguments = (ins TensorOf<[TF_Complex128, TF_Complex64]>:$input, - I32Tensor:$fft_length + TF_Int32Tensor:$fft_length ); let results = (outs @@ -4301,7 +4590,7 @@ the dimension is padded with zeros. let arguments = (ins TensorOf<[TF_Complex128, TF_Complex64]>:$input, - I32Tensor:$fft_length + TF_Int32Tensor:$fft_length ); let results = (outs @@ -4334,7 +4623,7 @@ the dimension is padded with zeros. let arguments = (ins TensorOf<[TF_Complex128, TF_Complex64]>:$input, - I32Tensor:$fft_length + TF_Int32Tensor:$fft_length ); let results = (outs @@ -4491,6 +4780,39 @@ tf.imag(input) ==> [4.75, 5.75] TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>; } +def TF_InTopKV2Op : TF_Op<"InTopKV2", [NoSideEffect]> { + let summary = "Says whether the targets are in the top `K` predictions."; + + let description = [{ +This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the +prediction for the target class is among the top `k` predictions among +all predictions for example `i`. Note that the behavior of `InTopK` differs +from the `TopK` op in its handling of ties; if multiple classes have the +same prediction value and straddle the top-`k` boundary, all of those +classes are considered to be in the top `k`. 
+ +More formally, let + + \\(predictions_i\\) be the predictions for all classes for example `i`, + \\(targets_i\\) be the target class for example `i`, + \\(out_i\\) be the output for example `i`, + +$$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$ + }]; + + let arguments = (ins + F32Tensor:$predictions, + TF_I32OrI64Tensor:$targets, + TF_I32OrI64Tensor:$k + ); + + let results = (outs + I1Tensor:$precision + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; +} + def TF_InfeedDequeueOp : TF_Op<"InfeedDequeue", []> { let summary = [{ A placeholder op for a value that will be fed into the computation. @@ -4507,33 +4829,21 @@ A placeholder op for a value that will be fed into the computation. TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; } -def TF_InitializeTableFromTextFileV2Op : TF_Op<"InitializeTableFromTextFileV2", []> { - let summary = "Initializes a table from a text file."; - - let description = [{ -It inserts one key-value pair into the table for each line of the file. -The key and value is extracted from the whole line content, elements from the -split line based on `delimiter` or the line number (starting from zero). -Where to extract the key and value from a line is specified by `key_index` and -`value_index`. - -- A value of -1 means use the line number(starting from zero), expects `int64`. -- A value of -2 means use the whole line content, expects `string`. -- A value >= 0 means use the index (starting at zero) of the split line based - on `delimiter`. +def TF_InitializeTableV2Op : TF_Op<"InitializeTableV2", []> { + let summary = [{ +Table initializer that takes two tensors for keys and values respectively. }]; let arguments = (ins - TF_ResourceTensor:$table_handle, - TF_StrTensor:$filename, - - Confined]>:$key_index, - Confined]>:$value_index, - Confined, [IntMinValue<-1>]>:$vocab_size, - DefaultValuedAttr:$delimiter + Arg:$table_handle, + TF_Tensor:$keys, + TF_Tensor:$values ); let results = (outs); + + TF_DerivedOperandTypeAttr Tval = TF_DerivedOperandTypeAttr<2>; + TF_DerivedOperandTypeAttr Tkey = TF_DerivedOperandTypeAttr<1>; } def TF_InplaceUpdateOp : TF_Op<"InplaceUpdate", [NoSideEffect]> { @@ -4548,7 +4858,7 @@ operation create / operate on a copy of `x`. let arguments = (ins TF_Tensor:$x, - I32Tensor:$i, + TF_Int32Tensor:$i, TF_Tensor:$v ); @@ -4567,11 +4877,11 @@ I.e., \\(y = 1 / x\\). 
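A minimal sketch of the InTopKV2 tie handling described above, under a stock TensorFlow 2.x runtime:

```python
import tensorflow as tf

predictions = tf.constant([[0.1, 0.9, 0.9],   # classes 1 and 2 tie at the top
                           [0.8, 0.1, 0.1]])
targets = tf.constant([2, 1])

# Ties that straddle the top-k boundary all count as "in the top k".
print(tf.math.in_top_k(targets, predictions, k=1))  # [True, False]
```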
}]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$x + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8]>:$x ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$y + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8]>:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -4690,11 +5000,11 @@ tf.math.is_finite(x) ==> [True, True, True, False, False] }]; let arguments = (ins - TF_FpTensor:$x + TF_FloatTensor:$x ); let results = (outs - I1Tensor:$y + TF_BoolTensor:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -4717,11 +5027,11 @@ tf.math.is_inf(x) ==> [False, True, False, True] }]; let arguments = (ins - TF_FpTensor:$x + TF_FloatTensor:$x ); let results = (outs - I1Tensor:$y + TF_BoolTensor:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -4744,21 +5054,68 @@ tf.math.is_nan(x) ==> [False, True, False, True, False] }]; let arguments = (ins - TF_FpTensor:$x + TF_FloatTensor:$x ); let results = (outs - I1Tensor:$y + TF_BoolTensor:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_IteratorOp : TF_Op<"Iterator", []> { + let summary = "A container for an iterator resource."; + + let arguments = (ins + StrAttr:$shared_name, + StrAttr:$container, + Confined]>:$output_types, + Confined]>:$output_shapes + ); + + let results = (outs + Res:$handle + ); +} + +def TF_IteratorFromStringHandleOp : TF_Op<"IteratorFromStringHandle", []> { + let summary = [{ +Converts the given string representing a handle to an iterator to a resource. + }]; + + let arguments = (ins + TF_StrTensor:$string_handle, + + DefaultValuedAttr:$output_types, + DefaultValuedAttr:$output_shapes + ); + + let results = (outs + Res:$resource_handle + ); +} + +def TF_IteratorFromStringHandleV2Op : TF_Op<"IteratorFromStringHandleV2", []> { + let summary = ""; + + let arguments = (ins + TF_StrTensor:$string_handle, + + DefaultValuedAttr:$output_types, + DefaultValuedAttr:$output_shapes + ); + + let results = (outs + Res:$resource_handle + ); +} + def TF_IteratorGetNextOp : TF_Op<"IteratorGetNext", []> { let summary = "Gets the next output from the given iterator ."; let arguments = (ins - TF_ResourceTensor:$iterator + Arg:$iterator ); let results = (outs @@ -4769,6 +5126,74 @@ def TF_IteratorGetNextOp : TF_Op<"IteratorGetNext", []> { TF_DerivedResultTypeListAttr output_types = TF_DerivedResultTypeListAttr<0>; } +def TF_IteratorGetNextAsOptionalOp : TF_Op<"IteratorGetNextAsOptional", []> { + let summary = [{ +Gets the next output from the given iterator as an Optional variant. + }]; + + let arguments = (ins + Arg:$iterator, + + Confined]>:$output_types, + Confined]>:$output_shapes + ); + + let results = (outs + TF_VariantTensor:$optional + ); +} + +def TF_IteratorGetNextSyncOp : TF_Op<"IteratorGetNextSync", []> { + let summary = "Gets the next output from the given iterator."; + + let description = [{ +This operation is a synchronous version IteratorGetNext. It should only be used +in situations where the iterator does not block the calling thread, or where +the calling thread is not a member of the thread pool used to execute parallel +operations (e.g. in eager mode). 
+ }]; + + let arguments = (ins + Arg:$iterator + ); + + let results = (outs + Variadic:$components + ); + + TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>; + TF_DerivedResultTypeListAttr output_types = TF_DerivedResultTypeListAttr<0>; +} + +def TF_IteratorToStringHandleOp : TF_Op<"IteratorToStringHandle", []> { + let summary = [{ +Converts the given `resource_handle` representing an iterator to a string. + }]; + + let arguments = (ins + Arg:$resource_handle + ); + + let results = (outs + TF_StrTensor:$string_handle + ); +} + +def TF_IteratorV2Op : TF_Op<"IteratorV2", []> { + let summary = ""; + + let arguments = (ins + StrAttr:$shared_name, + StrAttr:$container, + Confined]>:$output_types, + Confined]>:$output_shapes + ); + + let results = (outs + Res:$handle + ); +} + def TF_L2LossOp : TF_Op<"L2Loss", [NoSideEffect]> { let summary = "L2 Loss."; @@ -4779,11 +5204,11 @@ Computes half the L2 norm of a tensor without the `sqrt`: }]; let arguments = (ins - TF_FpTensor:$t + TF_FloatTensor:$t ); let results = (outs - TF_FpTensor:$output + TF_FloatTensor:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -4807,7 +5232,7 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag }]; let arguments = (ins - TensorOf<[BF16, F16, F32]>:$input, + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32]>:$input, DefaultValuedAttr:$depth_radius, DefaultValuedAttr:$bias, @@ -4816,7 +5241,7 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag ); let results = (outs - TensorOf<[BF16, F16, F32]>:$output + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -4826,9 +5251,9 @@ def TF_LRNGradOp : TF_Op<"LRNGrad", [NoSideEffect]> { let summary = "Gradients for Local Response Normalization."; let arguments = (ins - TensorOf<[BF16, F16, F32]>:$input_grads, - TensorOf<[BF16, F16, F32]>:$input_image, - TensorOf<[BF16, F16, F32]>:$output_image, + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32]>:$input_grads, + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32]>:$input_image, + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32]>:$output_image, DefaultValuedAttr:$depth_radius, DefaultValuedAttr:$bias, @@ -4837,28 +5262,33 @@ def TF_LRNGradOp : TF_Op<"LRNGrad", [NoSideEffect]> { ); let results = (outs - TensorOf<[BF16, F16, F32]>:$output + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_LeakyReluOp : TF_Op<"LeakyRelu", [NoSideEffect, SameOperandsAndResultType]> { +def TF_LeakyReluOp : TF_Op<"LeakyRelu", [NoSideEffect, SameOperandsAndResultType, TF_ContractionFusableInterface]> { let summary = "Computes rectified linear: `max(features, features * alpha)`."; let arguments = (ins - TF_FpTensor:$features, + TF_FloatTensor:$features, DefaultValuedAttr:$alpha ); let results = (outs - TF_FpTensor:$activations + TF_FloatTensor:$activations ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; let hasFolder = 1; + + let extraClassDeclaration = [{ + // TF_ContractionFusableInterface: + Optional GetContractionFusion(); + }]; } def TF_LeakyReluGradOp : TF_Op<"LeakyReluGrad", [NoSideEffect, SameOperandsAndResultType]> { @@ -4867,14 +5297,14 @@ Computes rectified linear gradients for a LeakyRelu operation. 
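The Iterator / MakeIterator / IteratorGetNext resource ops in this stretch of the diff are, roughly, what a plain Python loop over a tf.data pipeline lowers to in TensorFlow 2.x eager mode; a minimal sketch:

```python
import tensorflow as tf

ds = tf.data.Dataset.range(3)
it = iter(ds)            # allocates the iterator resource (MakeIterator)
print(next(it).numpy())  # 0 -- each next() issues an IteratorGetNext
print(next(it).numpy())  # 1
print(next(it).numpy())  # 2
```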
}]; let arguments = (ins - TF_FpTensor:$gradients, - TF_FpTensor:$features, + TF_FloatTensor:$gradients, + TF_FloatTensor:$features, DefaultValuedAttr:$alpha ); let results = (outs - TF_FpTensor:$backprops + TF_FloatTensor:$backprops ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -4956,7 +5386,7 @@ tf.math.less(x, y) ==> [False, True, True] ); let results = (outs - I1Tensor:$z + TF_BoolTensor:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -4989,7 +5419,7 @@ tf.math.less_equal(x, y) ==> [True, True, True] ); let results = (outs - I1Tensor:$z + TF_BoolTensor:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -5013,11 +5443,11 @@ tf.math.lgamma(x) ==> [inf, 0.5723649, 0., 2.4537368, inf, -4.6477685] }]; let arguments = (ins - TF_FpTensor:$x + TF_FloatTensor:$x ); let results = (outs - TF_FpTensor:$y + TF_FloatTensor:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -5039,13 +5469,13 @@ tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0 11.0 12.0] }]; let arguments = (ins - TF_FpTensor:$start, - TF_FpTensor:$stop, + TF_FloatTensor:$start, + TF_FloatTensor:$stop, TF_I32OrI64Tensor:$num ); let results = (outs - TF_FpTensor:$output + TF_FloatTensor:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -5157,11 +5587,11 @@ For each batch `i` and class `j` we have }]; let arguments = (ins - TF_FpTensor:$logits + TF_FloatTensor:$logits ); let results = (outs - TF_FpTensor:$logsoftmax + TF_FloatTensor:$logsoftmax ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -5177,12 +5607,12 @@ def TF_LogicalAndOp : TF_Op<"LogicalAnd", [Commutative, NoSideEffect, ResultsBro }]; let arguments = (ins - I1Tensor:$x, - I1Tensor:$y + TF_BoolTensor:$x, + TF_BoolTensor:$y ); let results = (outs - I1Tensor:$z + TF_BoolTensor:$z ); } @@ -5190,11 +5620,11 @@ def TF_LogicalNotOp : TF_Op<"LogicalNot", [NoSideEffect, SameOperandsAndResultTy let summary = "Returns the truth value of `NOT x` element-wise."; let arguments = (ins - I1Tensor:$x + TF_BoolTensor:$x ); let results = (outs - I1Tensor:$y + TF_BoolTensor:$y ); let hasCanonicalizer = 1; @@ -5210,15 +5640,31 @@ def TF_LogicalOrOp : TF_Op<"LogicalOr", [Commutative, NoSideEffect, ResultsBroad }]; let arguments = (ins - I1Tensor:$x, - I1Tensor:$y + TF_BoolTensor:$x, + TF_BoolTensor:$y ); let results = (outs - I1Tensor:$z + TF_BoolTensor:$z ); } +def TF_LookupTableExportV2Op : TF_Op<"LookupTableExportV2", []> { + let summary = "Outputs all keys and values in the table."; + + let arguments = (ins + Arg:$table_handle + ); + + let results = (outs + TF_Tensor:$keys, + TF_Tensor:$values + ); + + TF_DerivedResultTypeAttr Tkeys = TF_DerivedResultTypeAttr<0>; + TF_DerivedResultTypeAttr Tvalues = TF_DerivedResultTypeAttr<1>; +} + def TF_LookupTableFindV2Op : TF_Op<"LookupTableFindV2", []> { let summary = "Looks up keys in a table, outputs the corresponding values."; @@ -5231,7 +5677,7 @@ table. It must also be of the same type as the table values. }]; let arguments = (ins - TF_ResourceTensor:$table_handle, + Arg:$table_handle, TF_Tensor:$keys, TF_Tensor:$default_value ); @@ -5255,7 +5701,7 @@ The tensor `values` must be of the type of the table values. }]; let arguments = (ins - TF_ResourceTensor:$table_handle, + Arg:$table_handle, TF_Tensor:$keys, TF_Tensor:$values ); @@ -5266,6 +5712,44 @@ The tensor `values` must be of the type of the table values. 
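As a usage-level companion to the LookupTableFindV2 / LookupTableImportV2 definitions above (assuming TensorFlow 2.x; the exact ops emitted can vary by version), a static lookup table built and queried from Python:

```python
import tensorflow as tf

keys = tf.constant(["apple", "banana"])
values = tf.constant([1, 2], dtype=tf.int64)

table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys, values), default_value=-1)

print(table.lookup(tf.constant(["banana", "cherry"])))  # [2, -1]
print(table.size())                                     # 2
```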
TF_DerivedOperandTypeAttr Tout = TF_DerivedOperandTypeAttr<2>; } +def TF_LookupTableInsertV2Op : TF_Op<"LookupTableInsertV2", []> { + let summary = "Updates the table to associates keys with values."; + + let description = [{ +The tensor `keys` must be of the same type as the keys of the table. +The tensor `values` must be of the type of the table values. + }]; + + let arguments = (ins + Arg:$table_handle, + TF_Tensor:$keys, + TF_Tensor:$values + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr Tin = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr Tout = TF_DerivedOperandTypeAttr<2>; +} + +def TF_LookupTableRemoveV2Op : TF_Op<"LookupTableRemoveV2", []> { + let summary = "Removes keys and its associated values from a table."; + + let description = [{ +The tensor `keys` must of the same type as the keys of the table. Keys not +already in the table are silently ignored. + }]; + + let arguments = (ins + Arg:$table_handle, + TF_Tensor:$keys + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr Tin = TF_DerivedOperandTypeAttr<1>; +} + def TF_LookupTableSizeV2Op : TF_Op<"LookupTableSizeV2", []> { let summary = "Computes the number of elements in the given table."; @@ -5274,7 +5758,7 @@ def TF_LookupTableSizeV2Op : TF_Op<"LookupTableSizeV2", []> { ); let results = (outs - I64Tensor:$size + TF_Int64Tensor:$size ); } @@ -5316,6 +5800,24 @@ A 2-D example: TF_DerivedResultTypeAttr out_type = TF_DerivedResultTypeAttr<0>; } +def TF_MakeIteratorOp : TF_Op<"MakeIterator", []> { + let summary = [{ +Makes a new iterator from the given `dataset` and stores it in `iterator`. + }]; + + let description = [{ +This operation may be executed multiple times. Each execution will reset the +iterator in `iterator` to the first element of `dataset`. + }]; + + let arguments = (ins + TF_VariantTensor:$dataset, + Arg:$iterator + ); + + let results = (outs); +} + def TF_MatMulOp : TF_Op<"MatMul", [NoSideEffect, TF_SameOperandsAndResultElementTypeResolveRef]> { let summary = [{ Multiply the matrix "a" by the matrix "b". @@ -5332,15 +5834,15 @@ cublas. 
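A quick sketch of the MatMul op above with its transpose attributes (assuming TensorFlow 2.x):

```python
import tensorflow as tf

a = tf.random.normal([3, 2])
b = tf.random.normal([3, 4])

# transpose_a=True computes a^T @ b, i.e. a as a 2x3 matrix times b (3x4).
c = tf.linalg.matmul(a, b, transpose_a=True)
print(c.shape)  # (2, 4)
```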
}]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$a, - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$b, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$a, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$b, DefaultValuedAttr:$transpose_a, DefaultValuedAttr:$transpose_b ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$product + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$product ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -5561,7 +6063,7 @@ tf.matrix_diag_part(input, k = (1, 3), padding_value = 9) let arguments = (ins TF_Tensor:$input, - I32Tensor:$k, + TF_Int32Tensor:$k, TF_Tensor:$padding_value, DefaultValuedAttr, "RIGHT_LEFT">:$align @@ -5671,9 +6173,9 @@ tf.matrix_diag(diagonal, k = -1, num_rows = 3, padding_value = 9) let arguments = (ins TF_Tensor:$diagonal, - I32Tensor:$k, - I32Tensor:$num_rows, - I32Tensor:$num_cols, + TF_Int32Tensor:$k, + TF_Int32Tensor:$num_rows, + TF_Int32Tensor:$num_cols, TF_Tensor:$padding_value ); @@ -5810,9 +6312,9 @@ tf.matrix_diag(diagonal, k = -1, num_rows = 3, padding_value = 9) let arguments = (ins TF_Tensor:$diagonal, - I32Tensor:$k, - I32Tensor:$num_rows, - I32Tensor:$num_cols, + TF_Int32Tensor:$k, + TF_Int32Tensor:$num_rows, + TF_Int32Tensor:$num_cols, TF_Tensor:$padding_value, DefaultValuedAttr, "RIGHT_LEFT">:$align @@ -5843,13 +6345,13 @@ garbage result. }]; let arguments = (ins - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$input, + TensorOf<[TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64]>:$input, DefaultValuedAttr:$adjoint ); let results = (outs - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$output + TensorOf<[TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -5969,7 +6471,7 @@ tf.matrix_set_diag(diagonals, k = (-1, 0)) let arguments = (ins TF_Tensor:$input, TF_Tensor:$diagonal, - I32Tensor:$k + TF_Int32Tensor:$k ); let results = (outs @@ -6094,7 +6596,7 @@ tf.matrix_set_diag(input, diagonals, k = (-1, 2), align="LEFT_RIGHT") let arguments = (ins TF_Tensor:$input, TF_Tensor:$diagonal, - I32Tensor:$k, + TF_Int32Tensor:$k, DefaultValuedAttr, "RIGHT_LEFT">:$align ); @@ -6119,14 +6621,14 @@ If `adjoint` is `True` then each output matrix satisfies }]; let arguments = (ins - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$matrix, - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$rhs, + TensorOf<[TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64]>:$matrix, + TensorOf<[TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64]>:$rhs, DefaultValuedAttr:$adjoint ); let results = (outs - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$output + TensorOf<[TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -6186,15 +6688,15 @@ tf.matmul(a, x) }]; let arguments = (ins - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$matrix, - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$rhs, + TensorOf<[TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64]>:$matrix, + TensorOf<[TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64]>:$rhs, 
DefaultValuedAttr:$lower, DefaultValuedAttr:$adjoint ); let results = (outs - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$output + TensorOf<[TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -6213,39 +6715,39 @@ retained with length 1. }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Qint32, TF_Qint8, TF_Quint16, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, TF_I32OrI64Tensor:$reduction_indices, DefaultValuedAttr:$keep_dims ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Qint32, TF_Qint8, TF_Quint16, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &result, Value input, " - "Value reduction_indices, BoolAttr keep_dims" - >]; + let builders = [ + OpBuilder<"Value input, Value reduction_indices, BoolAttr keep_dims"> + ]; } def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect, TF_FoldOperandsTransposeInterface]> { let summary = "Performs max pooling on the input."; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint8, TF_Uint16, TF_Uint8]>:$input, + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint8, TF_Uint16, TF_Uint8]>:$input, Confined]>:$ksize, Confined]>:$strides, - TF_AnyStrAttrOf<["SAME", "VALID"]>:$padding, + TF_AnyStrAttrOf<["SAME", "VALID", "EXPLICIT"]>:$padding, + DefaultValuedAttr:$explicit_paddings, DefaultValuedAttr, "NHWC">:$data_format ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint8, TF_Uint16, TF_Uint8]>:$output + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint8, TF_Uint16, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -6262,7 +6764,7 @@ def TF_MaxPool3DOp : TF_Op<"MaxPool3D", [NoSideEffect]> { let summary = "Performs 3D max pooling on the input."; let arguments = (ins - TensorOf<[BF16, F16, F32]>:$input, + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32]>:$input, Confined]>:$ksize, Confined]>:$strides, @@ -6271,7 +6773,7 @@ def TF_MaxPool3DOp : TF_Op<"MaxPool3D", [NoSideEffect]> { ); let results = (outs - TensorOf<[BF16, F16, F32]>:$output + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -6281,9 +6783,9 @@ def TF_MaxPool3DGradOp : TF_Op<"MaxPool3DGrad", [NoSideEffect]> { let summary = "Computes gradients of 3D max pooling function."; let arguments = (ins - TensorOf<[BF16, F16, F32]>:$orig_input, - TensorOf<[BF16, F16, F32]>:$orig_output, - TensorOf<[BF16, F16, F32]>:$grad, + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32]>:$orig_input, + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32]>:$orig_output, + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32]>:$grad, Confined]>:$ksize, 
Confined]>:$strides, @@ -6292,7 +6794,7 @@ def TF_MaxPool3DGradOp : TF_Op<"MaxPool3DGrad", [NoSideEffect]> { ); let results = (outs - TensorOf<[BF16, F16, F32]>:$output + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32]>:$output ); TF_DerivedOperandTypeAttr TInput = TF_DerivedOperandTypeAttr<0>; @@ -6309,7 +6811,8 @@ def TF_MaxPoolGradOp : TF_Op<"MaxPoolGrad", [NoSideEffect]> { Confined]>:$ksize, Confined]>:$strides, - TF_AnyStrAttrOf<["SAME", "VALID"]>:$padding, + TF_AnyStrAttrOf<["SAME", "VALID", "EXPLICIT"]>:$padding, + DefaultValuedAttr:$explicit_paddings, DefaultValuedAttr:$data_format ); @@ -6324,27 +6827,6 @@ def TF_MaxPoolGradOp : TF_Op<"MaxPoolGrad", [NoSideEffect]> { }]; } -def TF_MaximumOp : TF_Op<"Maximum", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>, - WithBroadcastableBinOpBuilder { - let summary = "Returns the max of x and y (i.e. x > y ? x : y) element-wise."; - - let description = [{ -*NOTE*: `Maximum` supports broadcasting. More about broadcasting -[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) - }]; - - let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Uint8]>:$x, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Uint8]>:$y - ); - - let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Uint8]>:$z - ); - - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; -} - def TF_MeanOp : TF_Op<"Mean", [NoSideEffect, TF_FoldOperandsTransposeInterface]> { let summary = "Computes the mean of elements across dimensions of a tensor."; @@ -6356,14 +6838,14 @@ retained with length 1. }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, TF_I32OrI64Tensor:$reduction_indices, DefaultValuedAttr:$keep_dims ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -6440,14 +6922,14 @@ retained with length 1. 
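The MaxPool/MaxPoolGrad definitions above gain an EXPLICIT padding mode in this diff; the long-standing SAME/VALID modes look like this at the Python level (assuming TensorFlow 2.x):

```python
import tensorflow as tf

x = tf.random.normal([1, 8, 8, 3])  # NHWC
y = tf.nn.max_pool2d(x, ksize=2, strides=2, padding="VALID")
print(y.shape)  # (1, 4, 4, 3)
```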
}]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Qint32, TF_Qint8, TF_Quint16, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, TF_I32OrI64Tensor:$reduction_indices, DefaultValuedAttr:$keep_dims ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Qint32, TF_Qint8, TF_Quint16, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -6464,12 +6946,12 @@ def TF_MinimumOp : TF_Op<"Minimum", [NoSideEffect, ResultsBroadcastableShape, TF }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Uint8]>:$x, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Uint8]>:$y + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Uint8]>:$x, + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Uint8]>:$y ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Uint8]>:$z + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Uint8]>:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -6522,7 +7004,7 @@ pad(t, paddings) ==> [[2, 1, 1, 2, 3, 3, 2] } def TF_MlirLocalVarOp : TF_Op<"MlirLocalVarOp", []> { - let summary = "Creates a handle to a in-scope variable."; + let summary = "Creates a handle to an in-scope variable."; let description = [{ Used by internal passes for temporary representation of local state, which will @@ -6532,7 +7014,7 @@ be eventually removed. let arguments = (ins); let results = (outs - TF_ResourceTensor:$resource + Res:$resource ); } @@ -6623,12 +7105,12 @@ def TF_MulOp : TF_Op<"Mul", [Commutative, NoSideEffect, ResultsBroadcastableShap }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$x, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$y + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Uint16, TF_Uint8]>:$x, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Uint16, TF_Uint8]>:$y ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$z + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Uint16, TF_Uint8]>:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -6659,12 +7141,88 @@ Returns x * y element-wise. 
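The Mean reduction above exposes a `keep_dims` attribute (surfaced as `keepdims` in Python); a minimal contrast, assuming TensorFlow 2.x:

```python
import tensorflow as tf

x = tf.constant([[1.0, 2.0], [3.0, 4.0]])

print(tf.reduce_mean(x, axis=1))                 # [1.5, 3.5] -- axis removed
print(tf.reduce_mean(x, axis=1, keepdims=True))  # [[1.5], [3.5]] -- axis retained with length 1
```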
Returns zero if y is zero, even if x if infinite or TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_MultiDeviceIteratorOp : TF_Op<"MultiDeviceIterator", []> { + let summary = "Creates a MultiDeviceIterator resource."; + + let arguments = (ins + Confined]>:$devices, + StrAttr:$shared_name, + StrAttr:$container, + Confined]>:$output_types, + Confined]>:$output_shapes + ); + + let results = (outs + Res:$handle + ); +} + +def TF_MultiDeviceIteratorFromStringHandleOp : TF_Op<"MultiDeviceIteratorFromStringHandle", []> { + let summary = [{ +Generates a MultiDeviceIterator resource from its provided string handle. + }]; + + let arguments = (ins + TF_StrTensor:$string_handle, + + DefaultValuedAttr:$output_types, + DefaultValuedAttr:$output_shapes + ); + + let results = (outs + Res:$multi_device_iterator + ); +} + +def TF_MultiDeviceIteratorGetNextFromShardOp : TF_Op<"MultiDeviceIteratorGetNextFromShard", []> { + let summary = "Gets next element for the provided shard number."; + + let arguments = (ins + Arg:$multi_device_iterator, + TF_Int32Tensor:$shard_num, + TF_Int64Tensor:$incarnation_id + ); + + let results = (outs + Variadic:$components + ); + + TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>; + TF_DerivedResultTypeListAttr output_types = TF_DerivedResultTypeListAttr<0>; +} + +def TF_MultiDeviceIteratorInitOp : TF_Op<"MultiDeviceIteratorInit", []> { + let summary = "Initializes the multi device iterator with the given dataset."; + + let arguments = (ins + TF_VariantTensor:$dataset, + Arg:$multi_device_iterator, + TF_Int64Tensor:$max_buffer_size + ); + + let results = (outs + TF_Int64Tensor:$incarnation_id + ); +} + +def TF_MultiDeviceIteratorToStringHandleOp : TF_Op<"MultiDeviceIteratorToStringHandle", []> { + let summary = "Produces a string handle for the given MultiDeviceIterator."; + + let arguments = (ins + Arg:$multi_device_iterator + ); + + let results = (outs + TF_StrTensor:$string_handle + ); +} + def TF_MultinomialOp : TF_Op<"Multinomial", [TF_CannotDuplicate]> { let summary = "Draws samples from a multinomial distribution."; let arguments = (ins TF_IntOrFpTensor:$logits, - I32Tensor:$num_samples, + TF_Int32Tensor:$num_samples, DefaultValuedAttr:$seed, DefaultValuedAttr:$seed2 @@ -6678,15 +7236,94 @@ def TF_MultinomialOp : TF_Op<"Multinomial", [TF_CannotDuplicate]> { TF_DerivedResultTypeAttr output_dtype = TF_DerivedResultTypeAttr<0>; } +def TF_MutableDenseHashTableV2Op : TF_Op<"MutableDenseHashTableV2", []> { + let summary = [{ +Creates an empty hash table that uses tensors as the backing store. + }]; + + let description = [{ +It uses "open addressing" with quadratic reprobing to resolve +collisions. + +This op creates a mutable hash table, specifying the type of its keys and +values. Each value must be a scalar. Data can be inserted into the table using +the insert operations. It does not support the initialization operation. 
+ }]; + + let arguments = (ins + TF_Tensor:$empty_key, + TF_Tensor:$deleted_key, + + StrAttr:$container, + StrAttr:$shared_name, + DefaultValuedAttr:$use_node_name_sharing, + TypeAttr:$value_dtype, + DefaultValuedAttr({})">:$value_shape, + DefaultValuedAttr:$initial_num_buckets, + DefaultValuedAttr:$max_load_factor + ); + + let results = (outs + Res:$table_handle + ); + + TF_DerivedOperandTypeAttr key_dtype = TF_DerivedOperandTypeAttr<0>; +} + +def TF_MutableHashTableOfTensorsV2Op : TF_Op<"MutableHashTableOfTensorsV2", []> { + let summary = "Creates an empty hash table."; + + let description = [{ +This op creates a mutable hash table, specifying the type of its keys and +values. Each value must be a vector. Data can be inserted into the table using +the insert operations. It does not support the initialization operation. + }]; + + let arguments = (ins + StrAttr:$container, + StrAttr:$shared_name, + DefaultValuedAttr:$use_node_name_sharing, + TypeAttr:$key_dtype, + TypeAttr:$value_dtype, + DefaultValuedAttr({})">:$value_shape + ); + + let results = (outs + Res:$table_handle + ); +} + +def TF_MutableHashTableV2Op : TF_Op<"MutableHashTableV2", []> { + let summary = "Creates an empty hash table."; + + let description = [{ +This op creates a mutable hash table, specifying the type of its keys and +values. Each value must be a scalar. Data can be inserted into the table using +the insert operations. It does not support the initialization operation. + }]; + + let arguments = (ins + StrAttr:$container, + StrAttr:$shared_name, + DefaultValuedAttr:$use_node_name_sharing, + TypeAttr:$key_dtype, + TypeAttr:$value_dtype + ); + + let results = (outs + Res:$table_handle + ); +} + def TF_NdtriOp : TF_Op<"Ndtri", [NoSideEffect]> { let summary = ""; let arguments = (ins - TF_FpTensor:$x + TF_FloatTensor:$x ); let results = (outs - TF_FpTensor:$y + TF_FloatTensor:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -6700,11 +7337,11 @@ I.e., \\(y = -x\\). }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$x + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8]>:$x ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$y + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8]>:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -6712,6 +7349,34 @@ I.e., \\(y = -x\\). let hasCanonicalizer = 1; } +def TF_NextAfterOp : TF_Op<"NextAfter", [NoSideEffect, ResultsBroadcastableShape]>, + WithBroadcastableBinOpBuilder { + let summary = [{ +Returns the next representable value of `x1` in the direction of `x2`, element-wise. + }]; + + let description = [{ +This operation returns the same result as the C++ std::nextafter function. + +It can also return a subnormal number. + +@compatibility(cpp) +Equivalent to C++ std::nextafter function. +@end_compatibility + }]; + + let arguments = (ins + TF_F32OrF64Tensor:$x1, + TF_F32OrF64Tensor:$x2 + ); + + let results = (outs + TF_F32OrF64Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_NoOp : TF_Op<"NoOp", [NoSideEffect]> { let summary = "Does nothing. 
Only useful as a placeholder for control edges."; @@ -6720,6 +7385,49 @@ def TF_NoOp : TF_Op<"NoOp", [NoSideEffect]> { let results = (outs); } +def TF_NonMaxSuppressionV3Op : TF_Op<"NonMaxSuppressionV3", [NoSideEffect]> { + let summary = [{ +Greedily selects a subset of bounding boxes in descending order of score, + }]; + + let description = [{ +pruning away boxes that have high intersection-over-union (IOU) overlap +with previously selected boxes. Bounding boxes with score less than +`score_threshold` are removed. Bounding boxes are supplied as +[y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any +diagonal pair of box corners and the coordinates can be provided as normalized +(i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm +is agnostic to where the origin is in the coordinate system and more +generally is invariant to orthogonal transformations and translations +of the coordinate system; thus translating or reflections of the coordinate +system result in the same boxes being selected by the algorithm. +The output of this operation is a set of integers indexing into the input +collection of bounding boxes representing the selected boxes. The bounding +box coordinates corresponding to the selected indices can then be obtained +using the `tf.gather operation`. For example: + selected_indices = tf.image.non_max_suppression_v2( + boxes, scores, max_output_size, iou_threshold, score_threshold) + selected_boxes = tf.gather(boxes, selected_indices) + }]; + + let arguments = (ins + TensorOf<[TF_Float16, TF_Float32]>:$boxes, + TensorOf<[TF_Float16, TF_Float32]>:$scores, + TF_Int32Tensor:$max_output_size, + TensorOf<[TF_Float16, TF_Float32]>:$iou_threshold, + TensorOf<[TF_Float16, TF_Float32]>:$score_threshold + ); + + let results = (outs + TF_Int32Tensor:$selected_indices + ); + + TF_DerivedOperandTypeAttr T_threshold = TF_DerivedOperandTypeAttr<3>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + + let hasCanonicalizer = 1; +} + def TF_NonMaxSuppressionV4Op : TF_Op<"NonMaxSuppressionV4", [NoSideEffect]> { let summary = [{ Greedily selects a subset of bounding boxes in descending order of score, @@ -6746,18 +7454,18 @@ using the `tf.gather operation`. For example: }]; let arguments = (ins - TensorOf<[F16, F32]>:$boxes, - TensorOf<[F16, F32]>:$scores, - I32Tensor:$max_output_size, - TensorOf<[F16, F32]>:$iou_threshold, - TensorOf<[F16, F32]>:$score_threshold, + TensorOf<[TF_Float16, TF_Float32]>:$boxes, + TensorOf<[TF_Float16, TF_Float32]>:$scores, + TF_Int32Tensor:$max_output_size, + TensorOf<[TF_Float16, TF_Float32]>:$iou_threshold, + TensorOf<[TF_Float16, TF_Float32]>:$score_threshold, DefaultValuedAttr:$pad_to_max_output_size ); let results = (outs - I32Tensor:$selected_indices, - I32Tensor:$valid_outputs + TF_Int32Tensor:$selected_indices, + TF_Int32Tensor:$valid_outputs ); TF_DerivedOperandTypeAttr T_threshold = TF_DerivedOperandTypeAttr<3>; @@ -6795,20 +7503,20 @@ larger than 0. 
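A compact sketch of the NonMaxSuppression flow described above, assuming TensorFlow 2.x: indices are selected greedily by score, then the surviving boxes are gathered.

```python
import tensorflow as tf

boxes = tf.constant([[0.0, 0.0, 1.0, 1.0],
                     [0.0, 0.1, 1.0, 1.1],    # heavy overlap with box 0
                     [0.0, 2.0, 1.0, 3.0]])   # disjoint
scores = tf.constant([0.9, 0.8, 0.7])

selected = tf.image.non_max_suppression(
    boxes, scores, max_output_size=3, iou_threshold=0.5)
print(selected)                    # [0, 2] -- box 1 is suppressed by box 0
print(tf.gather(boxes, selected))  # the surviving boxes
```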
}]; let arguments = (ins - TensorOf<[F16, F32]>:$boxes, - TensorOf<[F16, F32]>:$scores, - I32Tensor:$max_output_size, - TensorOf<[F16, F32]>:$iou_threshold, - TensorOf<[F16, F32]>:$score_threshold, - TensorOf<[F16, F32]>:$soft_nms_sigma, + TensorOf<[TF_Float16, TF_Float32]>:$boxes, + TensorOf<[TF_Float16, TF_Float32]>:$scores, + TF_Int32Tensor:$max_output_size, + TensorOf<[TF_Float16, TF_Float32]>:$iou_threshold, + TensorOf<[TF_Float16, TF_Float32]>:$score_threshold, + TensorOf<[TF_Float16, TF_Float32]>:$soft_nms_sigma, DefaultValuedAttr:$pad_to_max_output_size ); let results = (outs - I32Tensor:$selected_indices, - TensorOf<[F16, F32]>:$selected_scores, - I32Tensor:$valid_outputs + TF_Int32Tensor:$selected_indices, + TensorOf<[TF_Float16, TF_Float32]>:$selected_scores, + TF_Int32Tensor:$valid_outputs ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -6823,21 +7531,20 @@ def TF_NotEqualOp : TF_Op<"NotEqual", [Commutative, NoSideEffect]> { }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x, - TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y, + TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x, + TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y, DefaultValuedAttr:$incompatible_shape_error ); let results = (outs - I1Tensor:$z + TF_BoolTensor:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; let builders = [ - OpBuilder<"OpBuilder& builder, OperationState& result, Value x, " - "Value y, BoolAttr incompatible_shape_error"> + OpBuilder<"Value x, Value y, BoolAttr incompatible_shape_error"> ]; let verifier = [{ @@ -6939,8 +7646,8 @@ output = }]; let arguments = (ins - TensorOf<[I32, I64, TF_Uint8]>:$indices, - I32Tensor:$depth, + TensorOf<[TF_Int32, TF_Int64, TF_Uint8]>:$indices, + TF_Int32Tensor:$depth, TF_Tensor:$on_value, TF_Tensor:$off_value, @@ -6955,8 +7662,7 @@ output = TF_DerivedOperandTypeAttr TI = TF_DerivedOperandTypeAttr<0>; let builders = [ - OpBuilder<"OpBuilder& builder, OperationState& result, Value indices, " - "Value depth, Value on_value, Value off_value, " + OpBuilder<"Value indices, Value depth, Value on_value, Value off_value, " "IntegerAttr axis"> ]; @@ -6965,6 +7671,44 @@ output = }]; } +def TF_OneShotIteratorOp : TF_Op<"OneShotIterator", []> { + let summary = [{ +Makes a "one-shot" iterator that can be iterated only once. + }]; + + let description = [{ +A one-shot iterator bundles the logic for defining the dataset and +the state of the iterator in a single op, which allows simple input +pipelines to be defined without an additional initialization +("MakeIterator") step. + +One-shot iterators have the following limitations: + +* They do not support parameterization: all logic for creating the underlying + dataset must be bundled in the `dataset_factory` function. +* They are not resettable. 
Once a one-shot iterator reaches the end of its + underlying dataset, subsequent "IteratorGetNext" operations on that + iterator will always produce an `OutOfRange` error. + +For greater flexibility, use "Iterator" and "MakeIterator" to define +an iterator using an arbitrary subgraph, which may capture tensors +(including fed values) as parameters, and which may be reset multiple +times by rerunning "MakeIterator". + }]; + + let arguments = (ins + SymbolRefAttr:$dataset_factory, + Confined]>:$output_types, + Confined]>:$output_shapes, + StrAttr:$container, + StrAttr:$shared_name + ); + + let results = (outs + Res:$handle + ); +} + def TF_OutfeedEnqueueTupleOp : TF_Op<"OutfeedEnqueueTuple", []> { let summary = "Enqueue multiple Tensor values on the computation outfeed."; @@ -7128,17 +7872,17 @@ stores the parameters for each batch. let arguments = (ins TF_I32OrI64Tensor:$shape, - TF_FpTensor:$means, - TF_FpTensor:$stdevs, - TF_FpTensor:$minvals, - TF_FpTensor:$maxvals, + TF_FloatTensor:$means, + TF_FloatTensor:$stdevs, + TF_FloatTensor:$minvals, + TF_FloatTensor:$maxvals, DefaultValuedAttr:$seed, DefaultValuedAttr:$seed2 ); let results = (outs - TF_FpTensor:$output + TF_FloatTensor:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -7161,12 +7905,12 @@ tf.pow(x, y) ==> [[256, 65536], [9, 27]] }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x, - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$x, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$y ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$z + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -7232,14 +7976,14 @@ retained with length 1. 
}]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, TF_I32OrI64Tensor:$reduction_indices, DefaultValuedAttr:$keep_dims ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -7263,14 +8007,14 @@ q_full, r_full = qr(a, full_matrices=True) }]; let arguments = (ins - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$input, + TensorOf<[TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64]>:$input, DefaultValuedAttr:$full_matrices ); let results = (outs - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$q, - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$r + TensorOf<[TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64]>:$q, + TensorOf<[TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64]>:$r ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -7284,7 +8028,7 @@ def TF_QuantizeAndDequantizeOp : TF_Op<"QuantizeAndDequantize", [NoSideEffect, S let summary = "Use QuantizeAndDequantizeV2 instead."; let arguments = (ins - TF_FpTensor:$input, + TF_FloatTensor:$input, DefaultValuedAttr:$signed_input, DefaultValuedAttr:$num_bits, @@ -7294,7 +8038,7 @@ def TF_QuantizeAndDequantizeOp : TF_Op<"QuantizeAndDequantize", [NoSideEffect, S ); let results = (outs - TF_FpTensor:$output + TF_FloatTensor:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -7358,9 +8102,9 @@ The above round function rounds the value based on the given round_mode. }]; let arguments = (ins - TF_FpTensor:$input, - TF_FpTensor:$input_min, - TF_FpTensor:$input_max, + TF_FloatTensor:$input, + TF_FloatTensor:$input_min, + TF_FloatTensor:$input_max, DefaultValuedAttr:$signed_input, DefaultValuedAttr:$num_bits, @@ -7371,7 +8115,7 @@ The above round function rounds the value based on the given round_mode. ); let results = (outs - TF_FpTensor:$output + TF_FloatTensor:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -7386,10 +8130,10 @@ tensor, so its value can change during training. }]; let arguments = (ins - TF_FpTensor:$input, - TF_FpTensor:$input_min, - TF_FpTensor:$input_max, - I32Tensor:$num_bits, + TF_FloatTensor:$input, + TF_FloatTensor:$input_min, + TF_FloatTensor:$input_max, + TF_Int32Tensor:$num_bits, DefaultValuedAttr:$signed_input, DefaultValuedAttr:$range_given, @@ -7398,7 +8142,7 @@ tensor, so its value can change during training. ); let results = (outs - TF_FpTensor:$output + TF_FloatTensor:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -7422,7 +8166,7 @@ the dimension is padded with zeros. let arguments = (ins TF_F32OrF64Tensor:$input, - I32Tensor:$fft_length + TF_Int32Tensor:$fft_length ); let results = (outs @@ -7452,7 +8196,7 @@ the dimension is padded with zeros. 
let arguments = (ins TF_F32OrF64Tensor:$input, - I32Tensor:$fft_length + TF_Int32Tensor:$fft_length ); let results = (outs @@ -7482,7 +8226,7 @@ the dimension is padded with zeros. let arguments = (ins TF_F32OrF64Tensor:$input, - I32Tensor:$fft_length + TF_Int32Tensor:$fft_length ); let results = (outs @@ -7518,11 +8262,11 @@ array([0.6666667, 1. , 1. ], dtype=float32) }]; let arguments = (ins - TF_FpTensor:$images + TF_FloatTensor:$images ); let results = (outs - TF_FpTensor:$output + TF_FloatTensor:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -7541,14 +8285,14 @@ See http://dl.acm.org/citation.cfm?id=358414 let arguments = (ins TF_I32OrI64Tensor:$shape, - TensorOf<[F16, F32, F64]>:$alpha, + TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$alpha, DefaultValuedAttr:$seed, DefaultValuedAttr:$seed2 ); let results = (outs - TensorOf<[F16, F32, F64]>:$output + TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$output ); TF_DerivedOperandTypeAttr S = TF_DerivedOperandTypeAttr<0>; @@ -7578,14 +8322,14 @@ def TF_RandomPoissonOp : TF_Op<"RandomPoisson", [TF_CannotDuplicate]> { let arguments = (ins TF_I32OrI64Tensor:$shape, - TensorOf<[F16, F32, F64]>:$rate, + TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$rate, DefaultValuedAttr:$seed, DefaultValuedAttr:$seed2 ); let results = (outs - TensorOf<[F16, F32, F64]>:$output + TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$output ); TF_DerivedOperandTypeAttr S = TF_DerivedOperandTypeAttr<0>; @@ -7611,14 +8355,14 @@ Programming, Volume 2. Addison Wesley let arguments = (ins TF_I32OrI64Tensor:$shape, - TensorOf<[F16, F32, F64, I32, I64]>:$rate, + TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$rate, DefaultValuedAttr:$seed, DefaultValuedAttr:$seed2 ); let results = (outs - TensorOf<[F16, F32, F64, I32, I64]>:$output + TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$output ); TF_DerivedOperandTypeAttr R = TF_DerivedOperandTypeAttr<1>; @@ -7670,7 +8414,7 @@ The generated values will have mean 0 and standard deviation 1. ); let results = (outs - TF_FpTensor:$output + TF_FloatTensor:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -7693,7 +8437,7 @@ lower bound 0 is included in the range, while the upper bound 1 is excluded. ); let results = (outs - TF_FpTensor:$output + TF_FloatTensor:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -7764,8 +8508,7 @@ tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15] TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<0>; let builders = [ - OpBuilder<"OpBuilder& builder, OperationState& result, Value start, " - "Value limit, Value delta"> + OpBuilder<"Value start, Value limit, Value delta"> ]; } @@ -7775,9 +8518,9 @@ Creates a dataset with a range of values. Corresponds to python's xrange. }]; let arguments = (ins - I64Tensor:$start, - I64Tensor:$stop, - I64Tensor:$step, + TF_Int64Tensor:$start, + TF_Int64Tensor:$stop, + TF_Int64Tensor:$step, Confined]>:$output_types, Confined]>:$output_shapes @@ -7812,13 +8555,13 @@ of the tensor. Rank is also known as "order", "degree", or "ndims." 
); let results = (outs - I32Tensor:$output + TF_Int32Tensor:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; let builders = [ - OpBuilder<"OpBuilder& builder, OperationState& result, Value input"> + OpBuilder<"Value input"> ]; let hasFolder = 1; @@ -7878,33 +8621,6 @@ tf.real(input) ==> [-2.25, 3.25] TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>; } -def TF_RealDivOp : TF_Op<"RealDiv", [NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary]>, - WithBroadcastableBinOpBuilder { - let summary = "Returns x / y element-wise for real types."; - - let description = [{ -If `x` and `y` are reals, this will return the floating-point division. - -*NOTE*: `Div` supports broadcasting. More about broadcasting -[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) - }]; - - let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$x, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$y - ); - - let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$z - ); - - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; - - let hasCanonicalizer = 1; - - let hasFolder = 1; -} - def TF_ReciprocalOp : TF_Op<"Reciprocal", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes the reciprocal of x element-wise."; @@ -7913,11 +8629,11 @@ I.e., \\(y = 1 / x\\). }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$x + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8]>:$x ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$y + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8]>:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -7962,13 +8678,13 @@ most one RecvTPUEmbeddingActivations op in the TPU graph. 
); let results = (outs - Variadic:$outputs + Variadic:$outputs ); TF_DerivedResultSizeAttr num_outputs = TF_DerivedResultSizeAttr<0>; } -def TF_ReluOp : TF_Op<"Relu", [NoSideEffect, SameOperandsAndResultType, TF_LayoutAgnostic]> { +def TF_ReluOp : TF_Op<"Relu", [NoSideEffect, SameOperandsAndResultType, TF_ContractionFusableInterface, TF_LayoutAgnostic]> { let summary = "Computes rectified linear: `max(features, 0)`."; let description = [{ @@ -7979,14 +8695,19 @@ array([ 0., 0., -0., 3.], dtype=float32) }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$features + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$features ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$activations + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$activations ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + + let extraClassDeclaration = [{ + // TF_ContractionFusableInterface: + Optional GetContractionFusion(); + }]; } def TF_Relu6Op : TF_Op<"Relu6", [NoSideEffect, SameOperandsAndResultType]> { @@ -8111,8 +8832,7 @@ reshape(t, []) ==> 7 TF_DerivedOperandTypeAttr Tshape = TF_DerivedOperandTypeAttr<1>; let builders = [ - OpBuilder< - "OpBuilder& builder, OperationState& result, Value tensor, Value shape"> + OpBuilder<"Value tensor, Value shape"> ]; let verifier = [{ @@ -8131,15 +8851,15 @@ Input images can be of different types but output images are always float. }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Uint16, TF_Uint8]>:$images, - I32Tensor:$size, + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Uint16, TF_Uint8]>:$images, + TF_Int32Tensor:$size, DefaultValuedAttr:$align_corners, DefaultValuedAttr:$half_pixel_centers ); let results = (outs - F32Tensor:$resized_images + TF_Float32Tensor:$resized_images ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -8149,15 +8869,15 @@ def TF_ResizeBilinearGradOp : TF_Op<"ResizeBilinearGrad", [NoSideEffect]> { let summary = "Computes the gradient of bilinear interpolation."; let arguments = (ins - F32Tensor:$grads, - TF_FpTensor:$original_image, + TF_Float32Tensor:$grads, + TF_FloatTensor:$original_image, DefaultValuedAttr:$align_corners, DefaultValuedAttr:$half_pixel_centers ); let results = (outs - TF_FpTensor:$output + TF_FloatTensor:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; @@ -8169,20 +8889,137 @@ Resize `images` to `size` using nearest neighbor interpolation. 
}]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Uint16, TF_Uint8]>:$images, - I32Tensor:$size, + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Uint16, TF_Uint8]>:$images, + TF_Int32Tensor:$size, DefaultValuedAttr:$align_corners, DefaultValuedAttr:$half_pixel_centers ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Uint16, TF_Uint8]>:$resized_images + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Uint16, TF_Uint8]>:$resized_images ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_ResizeNearestNeighborGradOp : TF_Op<"ResizeNearestNeighborGrad", [NoSideEffect]> { + let summary = "Computes the gradient of nearest neighbor interpolation."; + + let arguments = (ins + TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int8, TF_Uint8]>:$grads, + TF_Int32Tensor:$size, + + DefaultValuedAttr:$align_corners, + DefaultValuedAttr:$half_pixel_centers + ); + + let results = (outs + TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int8, TF_Uint8]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_ResourceApplyAdaMaxOp : TF_Op<"ResourceApplyAdaMax", []> { + let summary = "Update '*var' according to the AdaMax algorithm."; + + let description = [{ +m_t <- beta1 * m_{t-1} + (1 - beta1) * g +v_t <- max(beta2 * v_{t-1}, abs(g)) +variable <- variable - learning_rate / (1 - beta1^t) * m_t / (v_t + epsilon) + }]; + + let arguments = (ins + Arg:$var, + Arg:$m, + Arg:$v, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta1_power, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta1, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta2, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$epsilon, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, + + DefaultValuedAttr:$use_locking + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>; +} + +def TF_ResourceApplyAdadeltaOp : TF_Op<"ResourceApplyAdadelta", []> { + let summary = "Update '*var' according to the adadelta scheme."; + + let description = [{ +accum = rho() * accum + (1 - rho()) * grad.square(); +update = (update_accum + epsilon).sqrt() * (accum + epsilon()).rsqrt() * grad; +update_accum = rho() * update_accum + (1 - rho()) * update.square(); +var -= update; + }]; + + let arguments = (ins + Arg:$var, + Arg:$accum, + Arg:$accum_update, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, 
TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$rho, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$epsilon, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, + + DefaultValuedAttr:$use_locking + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>; +} + +def TF_ResourceApplyAdagradOp : TF_Op<"ResourceApplyAdagrad", []> { + let summary = "Update '*var' according to the adagrad scheme."; + + let description = [{ +accum += grad * grad +var -= lr * grad * (1 / sqrt(accum)) + }]; + + let arguments = (ins + Arg:$var, + Arg:$accum, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, + + DefaultValuedAttr:$use_locking, + DefaultValuedAttr:$update_slots + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>; +} + +def TF_ResourceApplyAdagradDAOp : TF_Op<"ResourceApplyAdagradDA", []> { + let summary = "Update '*var' according to the proximal adagrad scheme."; + + let arguments = (ins + Arg:$var, + Arg:$gradient_accumulator, + Arg:$gradient_squared_accumulator, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$l1, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$l2, + TF_Int64Tensor:$global_step, + + DefaultValuedAttr:$use_locking + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>; +} + def TF_ResourceApplyAdagradV2Op : TF_Op<"ResourceApplyAdagradV2", []> { let summary = "Update '*var' according to the adagrad scheme."; @@ -8192,11 +9029,11 @@ var -= lr * grad * (1 / (sqrt(accum) + epsilon)) }]; let arguments = (ins - TF_ResourceTensor:$var, - TF_ResourceTensor:$accum, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, 
I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$epsilon, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, + Arg:$var, + Arg:$accum, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$epsilon, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, DefaultValuedAttr:$use_locking, DefaultValuedAttr:$update_slots @@ -8218,16 +9055,16 @@ $$\text{variable} := \text{variable} - \text{lr}_t * m_t / (\sqrt{v_t} + \epsilo }]; let arguments = (ins - TF_ResourceTensor:$var, - TF_ResourceTensor:$m, - TF_ResourceTensor:$v, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta1_power, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta2_power, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta1, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta2, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$epsilon, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, + Arg:$var, + Arg:$m, + Arg:$v, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta1_power, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta2_power, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta1, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta2, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, 
TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$epsilon, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, DefaultValuedAttr:$use_locking, DefaultValuedAttr:$use_nesterov @@ -8238,6 +9075,32 @@ $$\text{variable} := \text{variable} - \text{lr}_t * m_t / (\sqrt{v_t} + \epsilo TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>; } +def TF_ResourceApplyAddSignOp : TF_Op<"ResourceApplyAddSign", []> { + let summary = "Update '*var' according to the AddSign update."; + + let description = [{ +m_t <- beta1 * m_{t-1} + (1 - beta1) * g +update <- (alpha + sign_decay * sign(g) *sign(m)) * g +variable <- variable - lr_t * update + }]; + + let arguments = (ins + Arg:$var, + Arg:$m, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$alpha, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$sign_decay, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, + + DefaultValuedAttr:$use_locking + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>; +} + def TF_ResourceApplyCenteredRMSPropOp : TF_Op<"ResourceApplyCenteredRMSProp", []> { let summary = "Update '*var' according to the centered RMSProp algorithm."; @@ -8263,15 +9126,15 @@ var <- var - mom }]; let arguments = (ins - TF_ResourceTensor:$var, - TF_ResourceTensor:$mg, - TF_ResourceTensor:$ms, - TF_ResourceTensor:$mom, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$rho, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$momentum, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$epsilon, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, + Arg:$var, + Arg:$mg, + Arg:$ms, + Arg:$mom, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, + 
TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$rho, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$momentum, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$epsilon, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, DefaultValuedAttr:$use_locking ); @@ -8281,13 +9144,76 @@ var <- var - mom TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<4>; } +def TF_ResourceApplyFtrlOp : TF_Op<"ResourceApplyFtrl", []> { + let summary = "Update '*var' according to the Ftrl-proximal scheme."; + + let description = [{ +accum_new = accum + grad * grad +linear += grad - (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var +quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 +var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 +accum = accum_new + }]; + + let arguments = (ins + Arg:$var, + Arg:$accum, + Arg:$linear, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$l1, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$l2, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr_power, + + DefaultValuedAttr:$use_locking, + DefaultValuedAttr:$multiply_linear_by_lr + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>; +} + +def TF_ResourceApplyFtrlV2Op : TF_Op<"ResourceApplyFtrlV2", []> { + let summary = "Update '*var' according to the Ftrl-proximal scheme."; + + let description = [{ +grad_with_shrinkage = grad + 2 * l2_shrinkage * var +accum_new = accum + grad_with_shrinkage * grad_with_shrinkage +linear += grad_with_shrinkage + + (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var +quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 +var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 +accum = accum_new + }]; + + let arguments = (ins + Arg:$var, + Arg:$accum, + Arg:$linear, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, + TensorOf<[TF_Bfloat16, 
TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$l1, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$l2, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$l2_shrinkage, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr_power, + + DefaultValuedAttr:$use_locking, + DefaultValuedAttr:$multiply_linear_by_lr + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>; +} + def TF_ResourceApplyGradientDescentOp : TF_Op<"ResourceApplyGradientDescent", []> { let summary = "Update '*var' by subtracting 'alpha' * 'delta' from it."; let arguments = (ins - TF_ResourceTensor:$var, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$alpha, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$delta, + Arg:$var, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$alpha, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$delta, DefaultValuedAttr:$use_locking ); @@ -8308,11 +9234,11 @@ var += accum }]; let arguments = (ins - TF_ResourceTensor:$var, - TF_ResourceTensor:$accum, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$momentum, + Arg:$var, + Arg:$accum, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$momentum, DefaultValuedAttr:$use_locking, 
DefaultValuedAttr:$use_nesterov @@ -8334,11 +9260,11 @@ var -= lr * accum }]; let arguments = (ins - TF_ResourceTensor:$var, - TF_ResourceTensor:$accum, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$momentum, + Arg:$var, + Arg:$accum, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$momentum, DefaultValuedAttr:$use_locking, DefaultValuedAttr:$use_nesterov @@ -8349,6 +9275,116 @@ var -= lr * accum TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>; } +def TF_ResourceApplyPowerSignOp : TF_Op<"ResourceApplyPowerSign", []> { + let summary = "Update '*var' according to the AddSign update."; + + let description = [{ +m_t <- beta1 * m_{t-1} + (1 - beta1) * g +update <- exp(logbase * sign_decay * sign(g) * sign(m_t)) * g +variable <- variable - lr_t * update + }]; + + let arguments = (ins + Arg:$var, + Arg:$m, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$logbase, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$sign_decay, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, + + DefaultValuedAttr:$use_locking + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>; +} + +def TF_ResourceApplyProximalAdagradOp : TF_Op<"ResourceApplyProximalAdagrad", []> { + let summary = [{ +Update '*var' and '*accum' according to FOBOS with Adagrad learning rate. 
+ }]; + + let description = [{ +accum += grad * grad +prox_v = var - lr * grad * (1 / sqrt(accum)) +var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0} + }]; + + let arguments = (ins + Arg:$var, + Arg:$accum, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$l1, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$l2, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, + + DefaultValuedAttr:$use_locking + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>; +} + +def TF_ResourceApplyProximalGradientDescentOp : TF_Op<"ResourceApplyProximalGradientDescent", []> { + let summary = "Update '*var' as FOBOS algorithm with fixed learning rate."; + + let description = [{ +prox_v = var - alpha * delta +var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0} + }]; + + let arguments = (ins + Arg:$var, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$alpha, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$l1, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$l2, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$delta, + + DefaultValuedAttr:$use_locking + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; +} + +def TF_ResourceApplyRMSPropOp : TF_Op<"ResourceApplyRMSProp", []> { + let summary = "Update '*var' according to the RMSProp algorithm."; + + let description = [{ +Note that in dense implementation of this algorithm, ms and mom will +update even if the grad is zero, but in this sparse implementation, ms +and mom will not update in iterations during which the grad is zero. 
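+
+As an illustration of the dense update spelled out in the equations below, one
+step can be sketched with NumPy (a stand-in with made-up values; names mirror
+the op's operands, and this is not the actual kernel):
+
+```python
+import numpy as np
+var, ms, mom = np.array([1.0, 2.0]), np.zeros(2), np.zeros(2)
+grad = np.array([0.1, -0.2])
+lr, rho, momentum, epsilon = 0.01, 0.9, 0.0, 1e-10
+ms = rho * ms + (1.0 - rho) * grad * grad                  # ms <- rho*ms + (1-rho)*grad^2
+mom = momentum * mom + lr * grad / np.sqrt(ms + epsilon)   # mom <- momentum*mom + lr*grad/sqrt(ms+eps)
+var = var - mom                                            # var <- var - mom
+```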
+ +mean_square = decay * mean_square + (1-decay) * gradient ** 2 +Delta = learning_rate * gradient / sqrt(mean_square + epsilon) + +ms <- rho * ms_{t-1} + (1-rho) * grad * grad +mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) +var <- var - mom + }]; + + let arguments = (ins + Arg:$var, + Arg:$ms, + Arg:$mom, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$rho, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$momentum, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$epsilon, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, + + DefaultValuedAttr:$use_locking + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>; +} + def TF_ResourceGatherOp : TF_Op<"ResourceGather", []> { let summary = [{ Gather slices from the variable pointed to by `resource` according to `indices`. @@ -8371,7 +9407,7 @@ Produces an output tensor with shape `indices.shape + params.shape[1:]` where: }]; let arguments = (ins - TF_ResourceTensor:$resource, + Arg:$resource, TF_I32OrI64Tensor:$indices, DefaultValuedAttr:$batch_dims, @@ -8386,6 +9422,405 @@ Produces an output tensor with shape `indices.shape + params.shape[1:]` where: TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; } +def TF_ResourceScatterAddOp : TF_Op<"ResourceScatterAdd", []> { + let summary = "Adds sparse updates to the variable referenced by `resource`."; + + let description = [{ +This operation computes + + # Scalar indices + ref[indices, ...] += updates[...] + + # Vector indices (for each i) + ref[indices[i], ...] += updates[i, ...] + + # High rank indices (for each i, ..., j) + ref[indices[i, ..., j], ...] += updates[i, ..., j, ...] + +Duplicate entries are handled correctly: if multiple `indices` reference +the same location, their contributions add. + +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. + +
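+The duplicate-index behaviour can be illustrated with NumPy's unbuffered
+`np.add.at` (a sketch with made-up values, not the op's implementation):
+
+```python
+import numpy as np
+ref = np.array([1., 2., 3., 4.])
+indices = np.array([0, 2, 0])        # index 0 appears twice
+updates = np.array([10., 20., 30.])
+np.add.at(ref, indices, updates)     # duplicate contributions add
+# ref is now [41., 2., 23., 4.]
+```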
+ }]; + + let arguments = (ins + Arg:$resource, + TF_I32OrI64Tensor:$indices, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$updates + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>; +} + +def TF_ResourceScatterDivOp : TF_Op<"ResourceScatterDiv", []> { + let summary = [{ +Divides sparse updates into the variable referenced by `resource`. + }]; + + let description = [{ +This operation computes + + # Scalar indices + ref[indices, ...] /= updates[...] + + # Vector indices (for each i) + ref[indices[i], ...] /= updates[i, ...] + + # High rank indices (for each i, ..., j) + ref[indices[i, ..., j], ...] /= updates[i, ..., j, ...] + +Duplicate entries are handled correctly: if multiple `indices` reference +the same location, their contributions multiply. + +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. + +
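+For example, using NumPy's unbuffered `np.divide.at` as a stand-in (made-up
+values, not the op's implementation):
+
+```python
+import numpy as np
+ref = np.array([100., 10.])
+np.divide.at(ref, [0, 0], [2., 5.])   # divisors at a duplicate index compound
+# ref is now [10., 10.]
+```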
+ }]; + + let arguments = (ins + Arg:$resource, + TF_I32OrI64Tensor:$indices, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$updates + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>; +} + +def TF_ResourceScatterMaxOp : TF_Op<"ResourceScatterMax", []> { + let summary = [{ +Reduces sparse updates into the variable referenced by `resource` using the `max` operation. + }]; + + let description = [{ +This operation computes + + # Scalar indices + ref[indices, ...] = max(ref[indices, ...], updates[...]) + + # Vector indices (for each i) + ref[indices[i], ...] = max(ref[indices[i], ...], updates[i, ...]) + + # High rank indices (for each i, ..., j) + ref[indices[i, ..., j], ...] = max(ref[indices[i, ..., j], ...], updates[i, ..., j, ...]) + +Duplicate entries are handled correctly: if multiple `indices` reference +the same location, their contributions are combined. + +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. + +
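+For example, using NumPy's unbuffered `np.maximum.at` as a stand-in (made-up
+values, not the op's implementation):
+
+```python
+import numpy as np
+ref = np.array([1., 5., 3.])
+np.maximum.at(ref, [0, 2, 0], [4., 2., 9.])   # duplicates combine via max
+# ref is now [9., 5., 3.]
+```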
+ }]; + + let arguments = (ins + Arg:$resource, + TF_I32OrI64Tensor:$indices, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$updates + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>; +} + +def TF_ResourceScatterMinOp : TF_Op<"ResourceScatterMin", []> { + let summary = [{ +Reduces sparse updates into the variable referenced by `resource` using the `min` operation. + }]; + + let description = [{ +This operation computes + + # Scalar indices + ref[indices, ...] = min(ref[indices, ...], updates[...]) + + # Vector indices (for each i) + ref[indices[i], ...] = min(ref[indices[i], ...], updates[i, ...]) + + # High rank indices (for each i, ..., j) + ref[indices[i, ..., j], ...] = min(ref[indices[i, ..., j], ...], updates[i, ..., j, ...]) + +Duplicate entries are handled correctly: if multiple `indices` reference +the same location, their contributions are combined. + +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. + +
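+For example, using NumPy's unbuffered `np.minimum.at` as a stand-in (made-up
+values, not the op's implementation):
+
+```python
+import numpy as np
+ref = np.array([4., 5., 3.])
+np.minimum.at(ref, [0, 2, 0], [6., 1., 2.])   # duplicates combine via min
+# ref is now [2., 5., 1.]
+```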
+ }]; + + let arguments = (ins + Arg:$resource, + TF_I32OrI64Tensor:$indices, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$updates + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>; +} + +def TF_ResourceScatterMulOp : TF_Op<"ResourceScatterMul", []> { + let summary = [{ +Multiplies sparse updates into the variable referenced by `resource`. + }]; + + let description = [{ +This operation computes + + # Scalar indices + ref[indices, ...] *= updates[...] + + # Vector indices (for each i) + ref[indices[i], ...] *= updates[i, ...] + + # High rank indices (for each i, ..., j) + ref[indices[i, ..., j], ...] *= updates[i, ..., j, ...] + +Duplicate entries are handled correctly: if multiple `indices` reference +the same location, their contributions multiply. + +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. + +
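+For example, using NumPy's unbuffered `np.multiply.at` as a stand-in (made-up
+values, not the op's implementation):
+
+```python
+import numpy as np
+ref = np.array([2., 3.])
+np.multiply.at(ref, [0, 0], [4., 5.])   # duplicate contributions multiply
+# ref is now [40., 3.]
+```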
+ }]; + + let arguments = (ins + Arg:$resource, + TF_I32OrI64Tensor:$indices, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$updates + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>; +} + +def TF_ResourceScatterNdAddOp : TF_Op<"ResourceScatterNdAdd", []> { + let summary = [{ +Applies sparse addition to individual values or slices in a Variable. + }]; + + let description = [{ +`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. + +`indices` must be integer tensor, containing indices into `ref`. +It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. + +The innermost dimension of `indices` (with length `K`) corresponds to +indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th +dimension of `ref`. + +`updates` is `Tensor` of rank `Q-1+P-K` with shape: + +``` +[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]] +``` + +For example, say we want to add 4 scattered elements to a rank-1 tensor to +8 elements. In Python, that addition would look like this: + +```python +ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8], use_resource=True) +indices = tf.constant([[4], [3], [1], [7]]) +updates = tf.constant([9, 10, 11, 12]) +add = tf.scatter_nd_add(ref, indices, updates) +with tf.Session() as sess: + print sess.run(add) +``` + +The resulting update to ref would look like this: + + [1, 13, 3, 14, 14, 6, 7, 20] + +See `tf.scatter_nd` for more details about how to make updates to +slices. + }]; + + let arguments = (ins + Arg:$ref, + TF_I32OrI64Tensor:$indices, + TF_Tensor:$updates, + + DefaultValuedAttr:$use_locking + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>; +} + +def TF_ResourceScatterNdSubOp : TF_Op<"ResourceScatterNdSub", []> { + let summary = [{ +Applies sparse subtraction to individual values or slices in a Variable. + }]; + + let description = [{ +`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. + +`indices` must be integer tensor, containing indices into `ref`. +It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. + +The innermost dimension of `indices` (with length `K`) corresponds to +indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th +dimension of `ref`. + +`updates` is `Tensor` of rank `Q-1+P-K` with shape: + +``` +[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]] +``` + +For example, say we want to subtract 4 scattered elements from a rank-1 tensor +with 8 elements. In Python, that subtraction would look like this: + +```python +ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8], use_resource=True) +indices = tf.constant([[4], [3], [1], [7]]) +updates = tf.constant([9, 10, 11, 12]) +sub = tf.scatter_nd_sub(ref, indices, updates) +with tf.Session() as sess: + print sess.run(sub) +``` + +The resulting update to ref would look like this: + + [1, -9, 3, -6, -4, 6, 7, -4] + +See `tf.scatter_nd` for more details about how to make updates to +slices. 
+ }]; + + let arguments = (ins + Arg:$ref, + TF_I32OrI64Tensor:$indices, + TF_Tensor:$updates, + + DefaultValuedAttr:$use_locking + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>; +} + +def TF_ResourceScatterNdUpdateOp : TF_Op<"ResourceScatterNdUpdate", []> { + let summary = [{ +Applies sparse `updates` to individual values or slices within a given + }]; + + let description = [{ +variable according to `indices`. + +`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. + +`indices` must be integer tensor, containing indices into `ref`. +It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. + +The innermost dimension of `indices` (with length `K`) corresponds to +indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th +dimension of `ref`. + +`updates` is `Tensor` of rank `Q-1+P-K` with shape: + +``` +[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. +``` + +For example, say we want to update 4 scattered elements to a rank-1 tensor to +8 elements. In Python, that update would look like this: + +```python + ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) + indices = tf.constant([[4], [3], [1] ,[7]]) + updates = tf.constant([9, 10, 11, 12]) + update = tf.scatter_nd_update(ref, indices, updates) + with tf.Session() as sess: + print sess.run(update) +``` + +The resulting update to ref would look like this: + + [1, 11, 3, 10, 9, 6, 7, 12] + +See `tf.scatter_nd` for more details about how to make updates to +slices. + }]; + + let arguments = (ins + Arg:$ref, + TF_I32OrI64Tensor:$indices, + TF_Tensor:$updates, + + DefaultValuedAttr:$use_locking + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>; +} + +def TF_ResourceScatterSubOp : TF_Op<"ResourceScatterSub", []> { + let summary = [{ +Subtracts sparse updates from the variable referenced by `resource`. + }]; + + let description = [{ +This operation computes + + # Scalar indices + ref[indices, ...] -= updates[...] + + # Vector indices (for each i) + ref[indices[i], ...] -= updates[i, ...] + + # High rank indices (for each i, ..., j) + ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...] + +Duplicate entries are handled correctly: if multiple `indices` reference +the same location, their contributions add. + +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. + +
+ }]; + + let arguments = (ins + Arg:$resource, + TF_I32OrI64Tensor:$indices, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$updates + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>; +} + def TF_ResourceScatterUpdateOp : TF_Op<"ResourceScatterUpdate", []> { let summary = [{ Assigns sparse updates to the variable referenced by `resource`. @@ -8405,7 +9840,7 @@ This operation computes }]; let arguments = (ins - TF_ResourceTensor:$resource, + Arg:$resource, TF_I32OrI64Tensor:$indices, TF_Tensor:$updates ); @@ -8416,6 +9851,38 @@ This operation computes TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>; } +def TF_ResourceStridedSliceAssignOp : TF_Op<"ResourceStridedSliceAssign", []> { + let summary = "Assign `value` to the sliced l-value reference of `ref`."; + + let description = [{ +The values of `value` are assigned to the positions in the variable +`ref` that are selected by the slice parameters. The slice parameters +`begin, `end`, `strides`, etc. work exactly as in `StridedSlice`. + +NOTE this op currently does not support broadcasting and so `value`'s +shape must be exactly the shape produced by the slice of `ref`. + }]; + + let arguments = (ins + Arg:$ref, + TF_I32OrI64Tensor:$begin, + TF_I32OrI64Tensor:$end, + TF_I32OrI64Tensor:$strides, + TF_Tensor:$value, + + DefaultValuedAttr:$begin_mask, + DefaultValuedAttr:$end_mask, + DefaultValuedAttr:$ellipsis_mask, + DefaultValuedAttr:$new_axis_mask, + DefaultValuedAttr:$shrink_axis_mask + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<4>; + TF_DerivedOperandTypeAttr Index = TF_DerivedOperandTypeAttr<1>; +} + def TF_RestoreV2Op : TF_Op<"RestoreV2", []> { let summary = "Restores tensors from a V2 checkpoint."; @@ -8577,12 +10044,12 @@ reverse(t, dims) ==> [[[[8, 9, 10, 11], }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Str, TF_Uint16, TF_Uint8]>:$tensor, + TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Str, TF_Uint16, TF_Uint8]>:$tensor, TF_I32OrI64Tensor:$axis ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Str, TF_Uint16, TF_Uint8]>:$output + TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Str, TF_Uint16, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -8657,11 +10124,11 @@ rint([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) ==> [-2., -2., -0., 0., 2., 2., 2.] }]; let arguments = (ins - TF_FpTensor:$x + TF_FloatTensor:$x ); let results = (outs - TF_FpTensor:$y + TF_FloatTensor:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -8719,11 +10186,11 @@ according to the current system rounding mode use std::cint. 
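A quick eager check of the rounds-half-to-even behaviour described here, assuming the public `tf.round` wrapper:

```python
import tensorflow as tf

# Halfway values go to the nearest even integer ("bankers' rounding").
x = tf.constant([-1.5, -0.5, 0.5, 1.5, 2.5])
print(tf.round(x).numpy())  # [-2. -0.  0.  2.  2.]
```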
}]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$x + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8]>:$x ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$y + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8]>:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -8938,12 +10405,12 @@ tf.segment_mean(c, tf.constant([0, 0, 1])) }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$data, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$data, TF_I32OrI64Tensor:$segment_ids ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; @@ -9020,12 +10487,12 @@ tf.segment_prod(c, tf.constant([0, 0, 1])) }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$data, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$data, TF_I32OrI64Tensor:$segment_ids ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; @@ -9061,12 +10528,12 @@ tf.segment_sum(c, tf.constant([0, 0, 1])) }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$data, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$data, TF_I32OrI64Tensor:$segment_ids ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; @@ -9118,7 +10585,7 @@ select(condition, t, e) ==> [[1, 2], }]; let arguments = (ins - 
I1Tensor:$condition, + TF_BoolTensor:$condition, TF_Tensor:$t, TF_Tensor:$e ); @@ -9140,7 +10607,7 @@ def TF_SelectV2Op : TF_Op<"SelectV2", [NoSideEffect, ResultsBroadcastableShape]> let summary = ""; let arguments = (ins - I1Tensor:$condition, + TF_BoolTensor:$condition, TF_Tensor:$t, TF_Tensor:$e ); @@ -9152,7 +10619,7 @@ def TF_SelectV2Op : TF_Op<"SelectV2", [NoSideEffect, ResultsBroadcastableShape]> TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; let builders = [ - OpBuilder<"OpBuilder& builder, OperationState& result, Value condition, Value e, Value t"> + OpBuilder<"Value condition, Value e, Value t"> ]; } @@ -9176,14 +10643,14 @@ e = self_adjoint_eig(a, compute_v=False) }]; let arguments = (ins - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$input, + TensorOf<[TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64]>:$input, DefaultValuedAttr:$compute_v ); let results = (outs - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$e, - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$v + TensorOf<[TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64]>:$e, + TensorOf<[TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64]>:$v ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -9205,11 +10672,11 @@ See [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515) }]; let arguments = (ins - TF_FpTensor:$features + TF_FloatTensor:$features ); let results = (outs - TF_FpTensor:$activations + TF_FloatTensor:$activations ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -9221,17 +10688,33 @@ Computes gradients for the scaled exponential linear (Selu) operation. }]; let arguments = (ins - TF_FpTensor:$gradients, - TF_FpTensor:$outputs + TF_FloatTensor:$gradients, + TF_FloatTensor:$outputs ); let results = (outs - TF_FpTensor:$backprops + TF_FloatTensor:$backprops ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_SerializeIteratorOp : TF_Op<"SerializeIterator", []> { + let summary = [{ +Converts the given `resource_handle` representing an iterator to a variant tensor. + }]; + + let arguments = (ins + Arg:$resource_handle, + + DefaultValuedAttr:$external_state_policy + ); + + let results = (outs + TF_VariantTensor:$serialized + ); +} + def TF_ShapeOp : TF_Op<"Shape", [NoSideEffect]> { let summary = "Returns the shape of a tensor."; @@ -9262,7 +10745,7 @@ shape(t) ==> [2, 2, 3] }]; let builders = [ - OpBuilder<"OpBuilder& builder, OperationState& result, Value input, BoolAttr use32Bit"> + OpBuilder<"Value input, BoolAttr use32Bit"> ]; let hasFolder = 1; @@ -9305,8 +10788,8 @@ Generate a sharded filename. The filename is printf formatted as let arguments = (ins TF_StrTensor:$basename, - I32Tensor:$shard, - I32Tensor:$num_shards + TF_Int32Tensor:$shard, + TF_Int32Tensor:$num_shards ); let results = (outs @@ -9314,6 +10797,76 @@ Generate a sharded filename. 
The filename is printf formatted as ); } +def TF_ShuffleAndRepeatDatasetV2Op : TF_Op<"ShuffleAndRepeatDatasetV2", []> { + let summary = ""; + + let arguments = (ins + TF_VariantTensor:$input_dataset, + TF_Int64Tensor:$buffer_size, + TF_Int64Tensor:$seed, + TF_Int64Tensor:$seed2, + TF_Int64Tensor:$count, + Arg:$seed_generator, + + DefaultValuedAttr:$reshuffle_each_iteration, + Confined]>:$output_types, + Confined]>:$output_shapes + ); + + let results = (outs + TF_VariantTensor:$handle + ); +} + +def TF_ShuffleDatasetV2Op : TF_Op<"ShuffleDatasetV2", []> { + let summary = ""; + + let arguments = (ins + TF_VariantTensor:$input_dataset, + TF_Int64Tensor:$buffer_size, + Arg:$seed_generator, + + Confined]>:$output_types, + Confined]>:$output_shapes + ); + + let results = (outs + TF_VariantTensor:$handle + ); +} + +def TF_ShuffleDatasetV3Op : TF_Op<"ShuffleDatasetV3", []> { + let summary = ""; + + let arguments = (ins + TF_VariantTensor:$input_dataset, + TF_Int64Tensor:$buffer_size, + TF_Int64Tensor:$seed, + TF_Int64Tensor:$seed2, + Arg:$seed_generator, + + DefaultValuedAttr:$reshuffle_each_iteration, + Confined]>:$output_types, + Confined]>:$output_shapes + ); + + let results = (outs + TF_VariantTensor:$handle + ); +} + +def TF_ShutdownDistributedTPUOp : TF_Op<"ShutdownDistributedTPU", []> { + let summary = "Shuts down a running distributed TPU system."; + + let description = [{ +The op returns an error if no system is running. + }]; + + let arguments = (ins); + + let results = (outs); +} + def TF_SigmoidOp : TF_Op<"Sigmoid", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes sigmoid of `x` element-wise."; @@ -9366,11 +10919,11 @@ Example usage: }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$x ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -9513,11 +11066,11 @@ For each batch `i` and class `j` we have }]; let arguments = (ins - TF_FpTensor:$logits + TF_FloatTensor:$logits ); let results = (outs - TF_FpTensor:$softmax + TF_FloatTensor:$softmax ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -9537,13 +11090,13 @@ Inputs are the logits, not probabilities. 
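For this cross-entropy op, a minimal eager sketch, assuming the public `tf.nn.softmax_cross_entropy_with_logits` wrapper; note that it takes unnormalized logits, as stated above:

```python
import tensorflow as tf

logits = tf.constant([[2.0, 1.0, 0.1]])  # unnormalized scores, not probabilities
labels = tf.constant([[1.0, 0.0, 0.0]])  # one-hot (or any valid distribution)
loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
print(loss.numpy())  # ~[0.417] == -log(softmax(logits)[0][0])
```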
}]; let arguments = (ins - TF_FpTensor:$features, - TF_FpTensor:$labels + TF_FloatTensor:$features, + TF_FloatTensor:$labels ); let results = (outs - TF_FpTensor:$loss, - TF_FpTensor:$backprop + TF_FloatTensor:$loss, + TF_FloatTensor:$backprop ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -9557,11 +11110,11 @@ def TF_SoftplusOp : TF_Op<"Softplus", [NoSideEffect, SameOperandsAndResultType]> let summary = "Computes softplus: `log(exp(features) + 1)`."; let arguments = (ins - TF_FpTensor:$features + TF_FloatTensor:$features ); let results = (outs - TF_FpTensor:$activations + TF_FloatTensor:$activations ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -9571,12 +11124,12 @@ def TF_SoftplusGradOp : TF_Op<"SoftplusGrad", [NoSideEffect, SameOperandsAndResu let summary = "Computes softplus gradients for a softplus operation."; let arguments = (ins - TF_FpTensor:$gradients, - TF_FpTensor:$features + TF_FloatTensor:$gradients, + TF_FloatTensor:$features ); let results = (outs - TF_FpTensor:$backprops + TF_FloatTensor:$backprops ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -9586,11 +11139,11 @@ def TF_SoftsignOp : TF_Op<"Softsign", [NoSideEffect, SameOperandsAndResultType]> let summary = "Computes softsign: `features / (abs(features) + 1)`."; let arguments = (ins - TF_FpTensor:$features + TF_FloatTensor:$features ); let results = (outs - TF_FpTensor:$activations + TF_FloatTensor:$activations ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -9600,12 +11153,12 @@ def TF_SoftsignGradOp : TF_Op<"SoftsignGrad", [NoSideEffect, SameOperandsAndResu let summary = "Computes softsign gradients for a softsign operation."; let arguments = (ins - TF_FpTensor:$gradients, - TF_FpTensor:$features + TF_FloatTensor:$gradients, + TF_FloatTensor:$features ); let results = (outs - TF_FpTensor:$backprops + TF_FloatTensor:$backprops ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -9639,7 +11192,7 @@ block size. TF_DerivedOperandTypeAttr Tpaddings = TF_DerivedOperandTypeAttr<1>; } -def TF_SpaceToBatchNDOp : TF_Op<"SpaceToBatchND", [NoSideEffect]> { +def TF_SpaceToBatchNDOp : TF_Op<"SpaceToBatchND", [DeclareOpInterfaceMethods, NoSideEffect]> { let summary = "SpaceToBatch for N-D tensors of type T."; let description = [{ @@ -9666,6 +11219,14 @@ precise description. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr Tpaddings = TF_DerivedOperandTypeAttr<2>; TF_DerivedOperandTypeAttr Tblock_shape = TF_DerivedOperandTypeAttr<1>; + + let verifier = [{ return Verify(*this); }]; + + let extraClassDeclaration = [{ + static bool isCompatibleReturnTypes(ArrayRef l, ArrayRef r) { + return ArraysAreCastCompatible(l, r); + } + }]; } def TF_SpaceToDepthOp : TF_Op<"SpaceToDepth", [NoSideEffect]> { @@ -9816,22 +11377,57 @@ backpropagation, }]; let arguments = (ins - I64Tensor:$indices, + TF_Int64Tensor:$indices, TF_Tensor:$values, - I64Tensor:$dense_shape, + TF_Int64Tensor:$dense_shape, TF_Tensor:$default_value ); let results = (outs - I64Tensor:$output_indices, + TF_Int64Tensor:$output_indices, TF_Tensor:$output_values, - I1Tensor:$empty_row_indicator, - I64Tensor:$reverse_index_map + TF_BoolTensor:$empty_row_indicator, + TF_Int64Tensor:$reverse_index_map ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; } +def TF_SparseMatMulOp : TF_Op<"SparseMatMul", [NoSideEffect]> { + let summary = [{ +Multiply matrix "a" by matrix "b". 
+ }]; + + let description = [{ +The inputs must be two-dimensional matrices and the inner dimension of "a" must +match the outer dimension of "b". Both "a" and "b" must be `Tensor`s not +`SparseTensor`s. This op is optimized for the case where at least one of "a" or +"b" is sparse, in the sense that they have a large proportion of zero values. +The breakeven for using this versus a dense matrix multiply on one platform was +30% zero values in the sparse matrix. + +The gradient computation of this operation will only take advantage of sparsity +in the input gradient when that gradient comes from a Relu. + }]; + + let arguments = (ins + TensorOf<[TF_Bfloat16, TF_Float32]>:$a, + TensorOf<[TF_Bfloat16, TF_Float32]>:$b, + + DefaultValuedAttr:$transpose_a, + DefaultValuedAttr:$transpose_b, + DefaultValuedAttr:$a_is_sparse, + DefaultValuedAttr:$b_is_sparse + ); + + let results = (outs + TF_Float32Tensor:$product + ); + + TF_DerivedOperandTypeAttr Ta = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Tb = TF_DerivedOperandTypeAttr<1>; +} + def TF_SparseReshapeOp : TF_Op<"SparseReshape", [NoSideEffect]> { let summary = [{ Reshapes a SparseTensor to represent values in a new dense shape. @@ -9856,14 +11452,14 @@ has length `R_out`, then `input_indices` has shape `[N, R_in]`, }]; let arguments = (ins - I64Tensor:$input_indices, - I64Tensor:$input_shape, - I64Tensor:$new_shape + TF_Int64Tensor:$input_indices, + TF_Int64Tensor:$input_shape, + TF_Int64Tensor:$new_shape ); let results = (outs - I64Tensor:$output_indices, - I64Tensor:$output_shape + TF_Int64Tensor:$output_indices, + TF_Int64Tensor:$output_shape ); } @@ -9879,13 +11475,13 @@ See `tf.sparse.segment_sum` for usage examples. }]; let arguments = (ins - TensorOf<[BF16, F32, F64]>:$data, + TensorOf<[TF_Bfloat16, TF_Float32, TF_Float64]>:$data, TF_I32OrI64Tensor:$indices, TF_I32OrI64Tensor:$segment_ids ); let results = (outs - TensorOf<[BF16, F32, F64]>:$output + TensorOf<[TF_Bfloat16, TF_Float32, TF_Float64]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -9908,13 +11504,13 @@ Inputs are the logits, not probabilities. }]; let arguments = (ins - TF_FpTensor:$features, + TF_FloatTensor:$features, TF_I32OrI64Tensor:$labels ); let results = (outs - TF_FpTensor:$loss, - TF_FpTensor:$backprop + TF_FloatTensor:$loss, + TF_FloatTensor:$backprop ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -9969,7 +11565,7 @@ def TF_SplitOp : TF_Op<"Split", [NoSideEffect]> { let summary = "Splits a tensor into `num_split` tensors along one dimension."; let arguments = (ins - I32Tensor:$split_dim, + TF_Int32Tensor:$split_dim, TF_Tensor:$value ); @@ -9989,7 +11585,7 @@ def TF_SplitVOp : TF_Op<"SplitV", [NoSideEffect]> { let arguments = (ins TF_Tensor:$value, TF_I32OrI64Tensor:$size_splits, - I32Tensor:$split_dim + TF_Int32Tensor:$split_dim ); let results = (outs @@ -10049,11 +11645,11 @@ I.e., \\(y = x * x = x^2\\). 
}]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$x + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8]>:$x ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$y + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8]>:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -10071,12 +11667,12 @@ def TF_SquaredDifferenceOp : TF_Op<"SquaredDifference", [Commutative, NoSideEffe }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x, - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$x, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$y ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$z + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -10123,7 +11719,7 @@ def TF_StackCloseV2Op : TF_Op<"StackCloseV2", []> { let summary = "Delete the stack from its resource container."; let arguments = (ins - TF_ResourceTensor:$handle + Arg:$handle ); let results = (outs); @@ -10133,7 +11729,7 @@ def TF_StackPopV2Op : TF_Op<"StackPopV2", []> { let summary = "Pop the element at the top of the stack."; let arguments = (ins - TF_ResourceTensor:$handle + Arg:$handle ); let results = (outs @@ -10147,7 +11743,7 @@ def TF_StackPushV2Op : TF_Op<"StackPushV2", []> { let summary = "Push an element onto the stack."; let arguments = (ins - TF_ResourceTensor:$handle, + Arg:$handle, TF_Tensor:$elem, DefaultValuedAttr:$swap_memory @@ -10164,23 +11760,23 @@ def TF_StackV2Op : TF_Op<"StackV2", []> { let summary = "A stack that produces elements in first-in last-out order."; let arguments = (ins - I32Tensor:$max_size, + TF_Int32Tensor:$max_size, TypeAttr:$elem_type, StrAttr:$stack_name ); let results = (outs - TF_ResourceTensor:$handle + Res:$handle ); } -def TF_StatelessMultinomialOp : TF_Op<"StatelessMultinomial", [NoSideEffect]> { +def TF_StatelessMultinomialOp : TF_Op<"StatelessMultinomial", [NoSideEffect, TF_NoConstantFold]> { let summary = "Draws samples from a multinomial distribution."; let arguments = (ins TF_IntOrFpTensor:$logits, - I32Tensor:$num_samples, + TF_Int32Tensor:$num_samples, TF_I32OrI64Tensor:$seed ); @@ -10193,7 +11789,82 @@ def TF_StatelessMultinomialOp : TF_Op<"StatelessMultinomial", [NoSideEffect]> { TF_DerivedResultTypeAttr output_dtype = TF_DerivedResultTypeAttr<0>; } -def TF_StatelessRandomNormalOp : TF_Op<"StatelessRandomNormal", [NoSideEffect]> { +def TF_StatelessParameterizedTruncatedNormalOp : TF_Op<"StatelessParameterizedTruncatedNormal", [NoSideEffect, TF_NoConstantFold]> { + let summary = ""; + + let arguments = (ins + TF_I32OrI64Tensor:$shape, + TF_I32OrI64Tensor:$seed, + TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$means, + TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$stddevs, + TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$minvals, + TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$maxvals + ); + + let results = (outs + TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$output + ); + + 
TF_DerivedOperandTypeAttr S = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>; +} + +def TF_StatelessRandomBinomialOp : TF_Op<"StatelessRandomBinomial", [NoSideEffect, TF_NoConstantFold]> { + let summary = [{ +Outputs deterministic pseudorandom random numbers from a binomial distribution. + }]; + + let description = [{ +Outputs random values from a binomial distribution. + +The outputs are a deterministic function of `shape`, `seed`, `counts`, and `probs`. + }]; + + let arguments = (ins + TF_I32OrI64Tensor:$shape, + TF_I32OrI64Tensor:$seed, + TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$counts, + TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$probs + ); + + let results = (outs + TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$output + ); + + TF_DerivedOperandTypeAttr S = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>; + TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>; + TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; +} + +def TF_StatelessRandomGammaV2Op : TF_Op<"StatelessRandomGammaV2", [NoSideEffect, TF_NoConstantFold]> { + let summary = [{ +Outputs deterministic pseudorandom random numbers from a gamma distribution. + }]; + + let description = [{ +Outputs random values from a gamma distribution. + +The outputs are a deterministic function of `shape`, `seed`, and `alpha`. + }]; + + let arguments = (ins + TF_I32OrI64Tensor:$shape, + TF_I32OrI64Tensor:$seed, + TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$alpha + ); + + let results = (outs + TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>; +} + +def TF_StatelessRandomNormalOp : TF_Op<"StatelessRandomNormal", [NoSideEffect, TF_NoConstantFold]> { let summary = [{ Outputs deterministic pseudorandom values from a normal distribution. }]; @@ -10210,7 +11881,7 @@ The outputs are a deterministic function of `shape` and `seed`. ); let results = (outs - TF_FpTensor:$output + TF_FloatTensor:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -10218,7 +11889,34 @@ The outputs are a deterministic function of `shape` and `seed`. TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; } -def TF_StatelessRandomUniformOp : TF_Op<"StatelessRandomUniform", [NoSideEffect]> { +def TF_StatelessRandomPoissonOp : TF_Op<"StatelessRandomPoisson", [NoSideEffect, TF_NoConstantFold]> { + let summary = [{ +Outputs deterministic pseudorandom random numbers from a Poisson distribution. + }]; + + let description = [{ +Outputs random values from a Poisson distribution. + +The outputs are a deterministic function of `shape`, `seed`, and `lam`. 
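The Stateless* RNG ops in this block share one contract: identical `shape`/`seed` (plus distribution parameters) yield identical outputs. A sketch of that property, assuming the public `tf.random.stateless_normal` wrapper:

```python
import tensorflow as tf

seed = tf.constant([7, 42], dtype=tf.int32)  # stateless RNG ops take a 2-element seed
a = tf.random.stateless_normal(shape=[3], seed=seed)
b = tf.random.stateless_normal(shape=[3], seed=seed)
# Unlike the stateful RNG ops, repeated calls with the same arguments agree exactly.
print(bool(tf.reduce_all(a == b)))  # True
```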
+ }]; + + let arguments = (ins + TF_I32OrI64Tensor:$shape, + TF_I32OrI64Tensor:$seed, + TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$lam + ); + + let results = (outs + TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>; + TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; + TF_DerivedOperandTypeAttr Rtype = TF_DerivedOperandTypeAttr<2>; +} + +def TF_StatelessRandomUniformOp : TF_Op<"StatelessRandomUniform", [NoSideEffect, TF_NoConstantFold]> { let summary = [{ Outputs deterministic pseudorandom random values from a uniform distribution. }]; @@ -10236,7 +11934,7 @@ The outputs are a deterministic function of `shape` and `seed`. ); let results = (outs - TF_FpTensor:$output + TF_FloatTensor:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -10244,7 +11942,32 @@ The outputs are a deterministic function of `shape` and `seed`. TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; } -def TF_StatelessRandomUniformIntOp : TF_Op<"StatelessRandomUniformInt", [NoSideEffect]> { +def TF_StatelessRandomUniformFullIntOp : TF_Op<"StatelessRandomUniformFullInt", [NoSideEffect, TF_NoConstantFold]> { + let summary = [{ +Outputs deterministic pseudorandom random integers from a uniform distribution. + }]; + + let description = [{ +The generated values are uniform integers covering the whole range of `dtype`. + +The outputs are a deterministic function of `shape` and `seed`. + }]; + + let arguments = (ins + TF_I32OrI64Tensor:$shape, + TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$seed + ); + + let results = (outs + TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>; + TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; +} + +def TF_StatelessRandomUniformIntOp : TF_Op<"StatelessRandomUniformInt", [NoSideEffect, TF_NoConstantFold]> { let summary = [{ Outputs deterministic pseudorandom random integers from a uniform distribution. }]; @@ -10271,7 +11994,7 @@ The outputs are a deterministic function of `shape`, `seed`, `minval`, and `maxv TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>; } -def TF_StatelessTruncatedNormalOp : TF_Op<"StatelessTruncatedNormal", [NoSideEffect]> { +def TF_StatelessTruncatedNormalOp : TF_Op<"StatelessTruncatedNormal", [NoSideEffect, TF_NoConstantFold]> { let summary = [{ Outputs deterministic pseudorandom values from a truncated normal distribution. }]; @@ -10290,7 +12013,7 @@ The outputs are a deterministic function of `shape` and `seed`. 
); let results = (outs - TF_FpTensor:$output + TF_FloatTensor:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -10566,7 +12289,7 @@ array([0, 2, 2]) ); let results = (outs - I64Tensor:$output + TF_Int64Tensor:$output ); } @@ -10580,12 +12303,12 @@ def TF_SubOp : TF_Op<"Sub", [NoSideEffect, ResultsBroadcastableShape, TF_CwiseBi }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint32, TF_Uint8]>:$x, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint32, TF_Uint8]>:$y + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Uint16, TF_Uint32, TF_Uint8]>:$x, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Uint16, TF_Uint32, TF_Uint8]>:$y ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint32, TF_Uint8]>:$z + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Uint16, TF_Uint32, TF_Uint8]>:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -10606,23 +12329,22 @@ retained with length 1. }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, TF_I32OrI64Tensor:$reduction_indices, DefaultValuedAttr:$keep_dims ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &result, Value input, " - "Value reduction_indices, BoolAttr keep_dims" - >]; + let builders = [ + OpBuilder<"Value input, Value reduction_indices, BoolAttr keep_dims"> + ]; } def TF_SymbolicGradientOp : TF_Op<"SymbolicGradient", [NoSideEffect]> { @@ -10687,7 +12409,7 @@ For internal use only. let arguments = (ins TF_Tensor:$input, - I64Tensor:$layout + TF_Int64Tensor:$layout ); let results = (outs @@ -10697,6 +12419,30 @@ For internal use only. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_TPUEmbeddingActivationsOp : TF_Op<"TPUEmbeddingActivations", [NoSideEffect]> { + let summary = "An op enabling differentiation of TPU Embeddings."; + + let description = [{ +This op simply returns its first input, which is assumed to have been sliced +from the Tensors returned by TPUEmbeddingDequeueActivations. The presence of +this op, and its first argument being a trainable Variable, enables automatic +differentiation of graphs containing embeddings via the TPU Embedding Python +libraries. 
+ }]; + + let arguments = (ins + TF_Float32Tensor:$embedding_variable, + TF_Float32Tensor:$sliced_activations, + + Confined]>:$table_id, + Confined]>:$lookup_id + ); + + let results = (outs + TF_Float32Tensor:$output + ); +} + def TF_TPUExecuteOp : TF_Op<"TPUExecute", []> { let summary = "Op that loads and executes a TPU program on a TPU device."; @@ -10765,7 +12511,7 @@ For internal use only. ); let results = (outs - I64Tensor:$layout + TF_Int64Tensor:$layout ); } @@ -10781,7 +12527,7 @@ consumed by TPUPartitionedCall. let arguments = (ins); let results = (outs - I32Tensor:$device_ordinals + TF_Int32Tensor:$device_ordinals ); } @@ -10858,9 +12604,9 @@ variables. }]; let arguments = (ins - Variadic:$vars, + Arg, "", [TF_VariableRead, TF_VariableWrite]>:$vars, TF_StrTensor:$new_format_key, - TF_ResourceTensor:$format_state_var + Arg:$format_state_var ); let results = (outs); @@ -10884,11 +12630,11 @@ Given an input tensor, this function computes tangent of every }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$x + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8]>:$x ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$y + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8]>:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -10902,10 +12648,11 @@ Given an input tensor, this function computes hyperbolic tangent of every element in the tensor. Input range is `[-inf, inf]` and output range is `[-1,1]`. - ```python - x = tf.constant([-float("inf"), -5, -0.5, 1, 1.2, 2, 3, float("inf")]) - tf.math.tanh(x) ==> [-1. -0.99990916 -0.46211717 0.7615942 0.8336547 0.9640276 0.9950547 1.] - ``` + >>> x = tf.constant([-float("inf"), -5, -0.5, 1, 1.2, 2, 3, float("inf")]) + >>> tf.math.tanh(x) + }]; let arguments = (ins @@ -10948,7 +12695,7 @@ of a step/run. }]; let arguments = (ins - TF_ResourceTensor:$handle + Arg:$handle ); let results = (outs); @@ -10972,15 +12719,15 @@ All elements must have the same shape (excepting the first dimension). }]; let arguments = (ins - TF_ResourceTensor:$handle, - F32Tensor:$flow_in, + Arg:$handle, + TF_Float32Tensor:$flow_in, DefaultValuedAttr:$element_shape_except0 ); let results = (outs TF_Tensor:$value, - I64Tensor:$lengths + TF_Int64Tensor:$lengths ); TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; @@ -10996,9 +12743,9 @@ All elements selected by `indices` must have the same shape. }]; let arguments = (ins - TF_ResourceTensor:$handle, - I32Tensor:$indices, - F32Tensor:$flow_in, + Arg:$handle, + TF_Int32Tensor:$indices, + TF_Float32Tensor:$flow_in, DefaultValuedAttr:$element_shape ); @@ -11055,15 +12802,15 @@ calculation gets its own TensorArray accumulator. 
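These TensorArrayV3 ops are normally reached through the Python-level `tf.TensorArray` wrapper rather than called directly; a minimal eager sketch (an illustration only, since the exact lowering differs between graph and eager mode):

```python
import tensorflow as tf

ta = tf.TensorArray(dtype=tf.float32, size=3)
ta = ta.write(0, 10.0)   # push elements by index
ta = ta.write(1, 20.0)
ta = ta.write(2, 30.0)
print(ta.read(1).numpy())   # 20.0
print(ta.stack().numpy())   # [10. 20. 30.]
```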
}]; let arguments = (ins - TF_ResourceTensor:$handle, - F32Tensor:$flow_in, + Arg:$handle, + TF_Float32Tensor:$flow_in, StrAttr:$source ); let results = (outs - TF_ResourceTensor:$grad_handle, - F32Tensor:$flow_out + Res:$grad_handle, + TF_Float32Tensor:$flow_out ); } @@ -11071,9 +12818,9 @@ def TF_TensorArrayReadV3Op : TF_Op<"TensorArrayReadV3", []> { let summary = "Read an element from the TensorArray into output `value`."; let arguments = (ins - TF_ResourceTensor:$handle, - I32Tensor:$index, - F32Tensor:$flow_in + Arg:$handle, + TF_Int32Tensor:$index, + TF_Float32Tensor:$flow_in ); let results = (outs @@ -11093,14 +12840,14 @@ Scatter the data from the input value into specific TensorArray elements. }]; let arguments = (ins - TF_ResourceTensor:$handle, - I32Tensor:$indices, + Arg:$handle, + TF_Int32Tensor:$indices, TF_Tensor:$value, - F32Tensor:$flow_in + TF_Float32Tensor:$flow_in ); let results = (outs - F32Tensor:$flow_out + TF_Float32Tensor:$flow_out ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>; @@ -11110,12 +12857,12 @@ def TF_TensorArraySizeV3Op : TF_Op<"TensorArraySizeV3", []> { let summary = "Get the current size of the TensorArray."; let arguments = (ins - TF_ResourceTensor:$handle, - F32Tensor:$flow_in + Arg:$handle, + TF_Float32Tensor:$flow_in ); let results = (outs - I32Tensor:$size + TF_Int32Tensor:$size ); } @@ -11145,14 +12892,14 @@ and having size }]; let arguments = (ins - TF_ResourceTensor:$handle, + Arg:$handle, TF_Tensor:$value, - I64Tensor:$lengths, - F32Tensor:$flow_in + TF_Int64Tensor:$lengths, + TF_Float32Tensor:$flow_in ); let results = (outs - F32Tensor:$flow_out + TF_Float32Tensor:$flow_out ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; @@ -11166,7 +12913,7 @@ Write data via Write and read via Read or Pack. }]; let arguments = (ins - I32Tensor:$size, + TF_Int32Tensor:$size, TypeAttr:$dtype, DefaultValuedAttr:$element_shape, @@ -11177,8 +12924,8 @@ Write data via Write and read via Read or Pack. ); let results = (outs - TF_ResourceTensor:$handle, - F32Tensor:$flow + Res:$handle, + TF_Float32Tensor:$flow ); } @@ -11186,14 +12933,14 @@ def TF_TensorArrayWriteV3Op : TF_Op<"TensorArrayWriteV3", []> { let summary = "Push an element onto the tensor_array."; let arguments = (ins - TF_ResourceTensor:$handle, - I32Tensor:$index, + Arg:$handle, + TF_Int32Tensor:$index, TF_Tensor:$value, - F32Tensor:$flow_in + TF_Float32Tensor:$flow_in ); let results = (outs - F32Tensor:$flow_out + TF_Float32Tensor:$flow_out ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>; @@ -11219,12 +12966,12 @@ lengths: Output tensor containing sizes of the 0th dimension of tensors in the l let arguments = (ins TF_VariantTensor:$input_handle, TF_I32OrI64Tensor:$element_shape, - I64Tensor:$leading_dims + TF_Int64Tensor:$leading_dims ); let results = (outs TF_Tensor:$tensor, - I64Tensor:$lengths + TF_Int64Tensor:$lengths ); TF_DerivedOperandTypeAttr shape_type = TF_DerivedOperandTypeAttr<1>; @@ -11291,8 +13038,8 @@ values: The tensor. 
let arguments = (ins TF_VariantTensor:$input_handle, - I32Tensor:$indices, - I32Tensor:$element_shape + TF_Int32Tensor:$indices, + TF_Int32Tensor:$element_shape ); let results = (outs @@ -11307,8 +13054,8 @@ def TF_TensorListGetItemOp : TF_Op<"TensorListGetItem", [NoSideEffect]> { let arguments = (ins TF_VariantTensor:$input_handle, - I32Tensor:$index, - I32Tensor:$element_shape + TF_Int32Tensor:$index, + TF_Int32Tensor:$element_shape ); let results = (outs @@ -11331,7 +13078,7 @@ length: the number of tensors in the list ); let results = (outs - I32Tensor:$length + TF_Int32Tensor:$length ); } @@ -11351,7 +13098,7 @@ element_shape: the shape of the output tensor let arguments = (ins TF_VariantTensor:$input_handle, - I32Tensor:$element_shape + TF_Int32Tensor:$element_shape ); let results = (outs @@ -11397,7 +13144,7 @@ size: size of the output list let arguments = (ins TF_VariantTensor:$input_handle, - I32Tensor:$size + TF_Int32Tensor:$size ); let results = (outs @@ -11421,7 +13168,7 @@ output_handle: The TensorList. let arguments = (ins TF_VariantTensor:$input_handle, TF_Tensor:$tensor, - I32Tensor:$indices + TF_Int32Tensor:$indices ); let results = (outs @@ -11436,7 +13183,7 @@ def TF_TensorListSetItemOp : TF_Op<"TensorListSetItem", [NoSideEffect]> { let arguments = (ins TF_VariantTensor:$input_handle, - I32Tensor:$index, + TF_Int32Tensor:$index, TF_Tensor:$item ); @@ -11460,7 +13207,7 @@ num_elements: optional. If not -1, the number of elements in the list. let arguments = (ins TF_VariantTensor:$input_handle, - I32Tensor:$element_shape, + TF_Int32Tensor:$element_shape, DefaultValuedAttr:$num_elements ); @@ -11573,14 +13320,46 @@ On GPU, if an out of bound index is found, the index is ignored. let verifier = [{ return Verify(*this); }]; let builders = [ - OpBuilder< - "OpBuilder& builder, OperationState& result, " - "Value tensor, Value indices, Value updates", - [{build(builder, result, tensor.getType(), tensor, indices, updates);}] + OpBuilder<"Value tensor, Value indices, Value updates", + [{build($_builder, $_state, tensor.getType(), tensor, indices, updates);}] > ]; } +def TF_TensorStridedSliceUpdateOp : TF_Op<"TensorStridedSliceUpdate", [NoSideEffect]> { + let summary = "Assign `value` to the sliced l-value reference of `input`."; + + let description = [{ +The values of `value` are assigned to the positions in the tensor `input` that +are selected by the slice parameters. The slice parameters `begin` `end` +`strides` etc. work exactly as in `StridedSlice`. + +NOTE this op currently does not support broadcasting and so `value`'s shape +must be exactly the shape produced by the slice of `input`. + }]; + + let arguments = (ins + TF_Tensor:$input, + TF_I32OrI64Tensor:$begin, + TF_I32OrI64Tensor:$end, + TF_I32OrI64Tensor:$strides, + TF_Tensor:$value, + + DefaultValuedAttr:$begin_mask, + DefaultValuedAttr:$end_mask, + DefaultValuedAttr:$ellipsis_mask, + DefaultValuedAttr:$new_axis_mask, + DefaultValuedAttr:$shrink_axis_mask + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Index = TF_DerivedOperandTypeAttr<1>; +} + def TF_TileOp : TF_Op<"Tile", [NoSideEffect]> { let summary = "Constructs a tensor by tiling a given tensor."; @@ -11625,9 +13404,9 @@ array([[1, 2, 3, 1, 2, 3], TF_DerivedOperandTypeAttr Tmultiples = TF_DerivedOperandTypeAttr<1>; TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; - // TODO(parkers): Add folds for multiples = [1,...]. 
- // TODO(parkers): Add errors for negative multiples and multiples.size() != - // input.rank() + let verifier = [{ return Verify(*this); }]; + + let hasFolder = 1; } def TF_TopKV2Op : TF_Op<"TopKV2", [NoSideEffect]> { @@ -11650,14 +13429,14 @@ If two elements are equal, the lower-index element appears first. let arguments = (ins TF_IntOrFpTensor:$input, - I32Tensor:$k, + TF_Int32Tensor:$k, DefaultValuedAttr:$sorted ); let results = (outs TF_IntOrFpTensor:$values, - I32Tensor:$indices + TF_Int32Tensor:$indices ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -11686,8 +13465,7 @@ The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy: TF_DerivedOperandTypeAttr Tperm = TF_DerivedOperandTypeAttr<1>; let builders = [ - OpBuilder< - "OpBuilder& builder, OperationState& result, Value x, Value perm"> + OpBuilder<"Value x, Value perm"> ]; let verifier = [{ @@ -11712,12 +13490,12 @@ Python Semantics. }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$x, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$y + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Uint16, TF_Uint8]>:$x, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Uint16, TF_Uint8]>:$y ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$z + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Uint16, TF_Uint8]>:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -11768,7 +13546,7 @@ deviations from the mean are dropped and re-picked. ); let results = (outs - TF_FpTensor:$output + TF_FloatTensor:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -11986,13 +13764,13 @@ dropped, and will not be included in the result. 
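A small eager check of the drop-negative-ids rule above, assuming the public `tf.math.unsorted_segment_sum` wrapper:

```python
import tensorflow as tf

data = tf.constant([1, 2, 3, 4])
segment_ids = tf.constant([0, -1, 1, 0])  # the -1 entry is dropped, as described above
print(tf.math.unsorted_segment_sum(data, segment_ids, num_segments=2).numpy())  # [5 3]
```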
}]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$data, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$data, TF_I32OrI64Tensor:$segment_ids, TF_I32OrI64Tensor:$num_segments ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; @@ -12035,13 +13813,13 @@ tf.unsorted_segment_sum(c, tf.constant([0, 1, 0]), num_segments=2) }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$data, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$data, TF_I32OrI64Tensor:$segment_ids, TF_I32OrI64Tensor:$num_segments ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; @@ -12095,11 +13873,11 @@ Checks whether a resource handle-based variable has been initialized. }]; let arguments = (ins - TF_ResourceTensor:$resource + Arg:$resource ); let results = (outs - I1Tensor:$is_initialized + TF_BoolTensor:$is_initialized ); let hasCanonicalizer = 1; @@ -12120,7 +13898,7 @@ shape(t) ==> [2, 2, 3] }]; let arguments = (ins - TF_ResourceTensor:$input + Arg:$input ); let results = (outs @@ -12226,161 +14004,27 @@ where(input) ==> [[0, 0, 0], }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input + TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input ); let results = (outs - I64Tensor:$index + TF_Int64Tensor:$index ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_WriteAudioSummaryOp : TF_Op<"WriteAudioSummary", []> { - let summary = "Writes an audio summary."; - - let description = [{ -Writes encoded audio summary `tensor` at `step` with `tag` using summary `writer`. -`sample_rate` is the audio sample rate is Hz. 
- }]; - - let arguments = (ins - TF_ResourceTensor:$writer, - I64Tensor:$step, - TF_StrTensor:$tag, - F32Tensor:$tensor, - F32Tensor:$sample_rate, - - Confined, [IntMinValue<1>]>:$max_outputs - ); - - let results = (outs); -} - -def TF_WriteGraphSummaryOp : TF_Op<"WriteGraphSummary", []> { - let summary = "Writes a graph summary."; - - let description = [{ -Writes TensorFlow graph `tensor` at `step` using summary `writer`. - }]; - - let arguments = (ins - TF_ResourceTensor:$writer, - I64Tensor:$step, - TF_StrTensor:$tensor - ); - - let results = (outs); -} - -def TF_WriteHistogramSummaryOp : TF_Op<"WriteHistogramSummary", []> { - let summary = "Writes a histogram summary."; - - let description = [{ -Writes histogram `values` at `step` with `tag` using summary `writer`. - }]; - - let arguments = (ins - TF_ResourceTensor:$writer, - I64Tensor:$step, - TF_StrTensor:$tag, - TF_IntOrFpTensor:$values - ); - - let results = (outs); - - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>; -} - -def TF_WriteImageSummaryOp : TF_Op<"WriteImageSummary", []> { - let summary = "Writes an image summary."; - - let description = [{ -Writes image `tensor` at `step` with `tag` using summary `writer`. -`tensor` is image with shape [height, width, channels]. - }]; - - let arguments = (ins - TF_ResourceTensor:$writer, - I64Tensor:$step, - TF_StrTensor:$tag, - TensorOf<[F16, F32, TF_Uint8]>:$tensor, - TF_Uint8Tensor:$bad_color, - - Confined, [IntMinValue<1>]>:$max_images - ); - - let results = (outs); - - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>; -} - -def TF_WriteRawProtoSummaryOp : TF_Op<"WriteRawProtoSummary", []> { - let summary = "Writes a serialized proto summary."; - - let description = [{ -Writes `tensor`, a serialized proto at `step` using summary `writer`. - }]; - - let arguments = (ins - TF_ResourceTensor:$writer, - I64Tensor:$step, - TF_StrTensor:$tensor - ); - - let results = (outs); -} - -def TF_WriteScalarSummaryOp : TF_Op<"WriteScalarSummary", []> { - let summary = "Writes a scalar summary."; - - let description = [{ -Writes scalar `value` at `step` with `tag` using summary `writer`. - }]; - - let arguments = (ins - TF_ResourceTensor:$writer, - I64Tensor:$step, - TF_StrTensor:$tag, - TF_IntOrFpTensor:$value - ); - - let results = (outs); - - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>; -} - -def TF_WriteSummaryOp : TF_Op<"WriteSummary", []> { - let summary = "Writes a tensor summary."; - - let description = [{ -Writes `tensor` at `step` with `tag` using summary `writer`. 
- }]; - - let arguments = (ins - TF_ResourceTensor:$writer, - I64Tensor:$step, - TF_Tensor:$tensor, - TF_StrTensor:$tag, - TF_StrTensor:$summary_metadata - ); - - let results = (outs); - - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>; -} - def TF_XdivyOp : TF_Op<"Xdivy", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>, WithBroadcastableBinOpBuilder { let summary = "Returns 0 if x == 0, and x / y otherwise, elementwise."; let arguments = (ins - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$x, - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$y + TensorOf<[TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64]>:$x, + TensorOf<[TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64]>:$y ); let results = (outs - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$z + TensorOf<[TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64]>:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -12398,14 +14042,14 @@ for binary operators. }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lhs, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$rhs, - TF_I32OrI64Tensor:$broadcast_dims + Arg, [{the LHS input tensor}]>:$lhs, + Arg, [{the RHS input tensor}]>:$rhs, + Arg:$broadcast_dims ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lhs_output, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$rhs_output + Res, [{the broadcasted LHS tensor}]>:$lhs_output, + Res, [{the broadcasted RHS tensor}]>:$rhs_output ); TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<2>; @@ -12421,20 +14065,20 @@ https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lhs, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$rhs, - TF_I32OrI64Tensor:$window_strides, - TF_I32OrI64Tensor:$padding, - TF_I32OrI64Tensor:$lhs_dilation, - TF_I32OrI64Tensor:$rhs_dilation, - TF_I32OrI64Tensor:$feature_group_count, + Arg, [{the input tensor}]>:$lhs, + Arg, [{the kernel tensor}]>:$rhs, + Arg:$window_strides, + Arg:$padding, + Arg:$lhs_dilation, + Arg:$rhs_dilation, + Arg:$feature_group_count, StrAttr:$dimension_numbers, StrAttr:$precision_config ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<2>; @@ -12450,15 +14094,15 @@ https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral }]; let arguments = (ins - 
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lhs, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$rhs, + Arg, [{the LHS tensor}]>:$lhs, + Arg, [{the RHS tensor}]>:$rhs, StrAttr:$dimension_numbers, StrAttr:$precision_config ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -12479,8 +14123,11 @@ with dimension size equal to the rank of operand. }]; let arguments = (ins - TF_Tensor:$input, - TF_I32OrI64Tensor:$start_indices, + Arg:$input, + Arg:$start_indices, TF_I32OrI64Tensor:$size_indices ); @@ -12508,19 +14155,44 @@ Handling of out-of-bounds slice indices is implementation-defined. }]; let arguments = (ins - TF_Tensor:$input, - TF_Tensor:$update, - TF_I32OrI64Tensor:$indices + Arg:$input, + Arg:$update, + Arg:$indices ); let results = (outs - TF_Tensor:$output + Res:$output ); TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<2>; TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_XlaEinsumOp : TF_Op<"XlaEinsum", [NoSideEffect]> { + let summary = [{ +An op which supports basic einsum op with 2 inputs and 1 output. + }]; + + let description = [{ +This op has better TPU performance since it doesn't have explicitly reshape and +transpose operations as tf.einsum does. + }]; + + let arguments = (ins + TensorOf<[TF_Bfloat16, TF_Complex64, TF_Float32]>:$a, + TensorOf<[TF_Bfloat16, TF_Complex64, TF_Float32]>:$b, + + StrAttr:$equation + ); + + let results = (outs + TensorOf<[TF_Bfloat16, TF_Complex64, TF_Float32]>:$product + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_XlaGatherOp : TF_Op<"XlaGather", [NoSideEffect]> { let summary = "Wraps the XLA Gather operator documented at"; @@ -12529,16 +14201,16 @@ https://www.tensorflow.org/xla/operation_semantics#gather }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$operand, - TF_I32OrI64Tensor:$start_indices, - TF_I32OrI64Tensor:$slice_sizes, + Arg, [{The array we're gathering from.}]>:$operand, + Arg:$start_indices, + Arg:$slice_sizes, StrAttr:$dimension_numbers, BoolAttr:$indices_are_sorted ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; @@ -12580,13 +14252,13 @@ Sorts a tensor. Currently only sorts in ascending order are supported. 
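The key/value sort described here can be emulated at the Python level with `tf.argsort` plus `tf.gather` (a sketch under that assumption; the XLA op itself is usually emitted by the compiler bridge rather than invoked by hand):

```python
import tensorflow as tf

keys = tf.constant([3, 1, 2])
values = tf.constant([30, 10, 20])
order = tf.argsort(keys)                 # ascending order, matching this op's contract
print(tf.gather(keys, order).numpy())    # [1 2 3]
print(tf.gather(values, order).numpy())  # [10 20 30]
```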
}]; let arguments = (ins - TF_IntOrFpTensor:$keys, - TF_Tensor:$values + Arg:$keys, + Arg:$values ); let results = (outs - TF_IntOrFpTensor:$sorted_keys, - TF_Tensor:$sorted_values + Res:$sorted_keys, + Res:$sorted_values ); TF_DerivedOperandTypeAttr V = TF_DerivedOperandTypeAttr<1>; @@ -12602,15 +14274,15 @@ https://www.tensorflow.org/performance/xla/operation_semantics#pad }]; let arguments = (ins - TF_Tensor:$input, - TF_Tensor:$padding_value, - TF_I32OrI64Tensor:$padding_low, - TF_I32OrI64Tensor:$padding_high, - TF_I32OrI64Tensor:$padding_interior + Arg:$input, + Arg:$padding_value, + Arg:$padding_low, + Arg:$padding_high, + Arg:$padding_interior ); let results = (outs - TF_Tensor:$output + Res:$output ); TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<2>; @@ -12647,15 +14319,15 @@ https://www.tensorflow.org/performance/xla/operation_semantics#reduce . }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$init_value, + Arg, [{the input tensor}]>:$input, + Arg, [{a scalar representing the initial value for the reduction}]>:$init_value, I64ArrayAttr:$dimensions_to_reduce, SymbolRefAttr:$reducer ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -12667,7 +14339,7 @@ def TF_XlaReplicaIdOp : TF_Op<"XlaReplicaId", [NoSideEffect]> { let arguments = (ins); let results = (outs - I32Tensor:$id + TF_Int32Tensor:$id ); } @@ -12679,9 +14351,10 @@ https://www.tensorflow.org/xla/operation_semantics#scatter. }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$operand, - TF_I32OrI64Tensor:$scatter_indices, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$updates, + Arg, [{Array to be scattered into.}]>:$operand, + Arg:$scatter_indices, + Arg, [{Array containing the values that must be used for scattering.}]>:$updates, SymbolRefAttr:$update_computation, StrAttr:$dimension_numbers, @@ -12689,7 +14362,7 @@ https://www.tensorflow.org/xla/operation_semantics#scatter. ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; @@ -12710,7 +14383,7 @@ i=0...N-1. 
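The eigendecomposition contract above (`a @ v[..., :, i] == w[..., i] * v[..., :, i]`) matches the standard self-adjoint eigensolver; a sketch using the non-XLA `tf.linalg.eigh` wrapper to illustrate the same relation:

```python
import tensorflow as tf

a = tf.constant([[2.0, 1.0],
                 [1.0, 2.0]])   # self-adjoint (symmetric) input
w, v = tf.linalg.eigh(a)        # eigenvalues in ascending order, matching eigenvectors
print(w.numpy())                # [1. 3.]
# Check  a @ v[:, i] == w[i] * v[:, i]  for every column i.
print(tf.reduce_max(tf.abs(tf.matmul(a, v) - v * w)).numpy() < 1e-5)  # True
```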
}]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$a, + Arg, [{the input tensor.}]>:$a, BoolAttr:$lower, I64Attr:$max_iter, @@ -12718,8 +14391,10 @@ i=0...N-1. ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$w, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$v + Res, [{The eigenvalues in ascending order, each repeated according to its +multiplicity.}]>:$w, + Res, [{The column v[..., :, i] is the normalized eigenvector corresponding to the +eigenvalue w[..., i].}]>:$v ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -12758,7 +14433,7 @@ tensor such that tensor[...,:,:] = u[..., :, :] * Diag(s[..., :]) * Transpose(v[ }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$a, + Arg, [{the input tensor.}]>:$a, I64Attr:$max_iter, F32Attr:$epsilon, @@ -12766,9 +14441,10 @@ tensor such that tensor[...,:,:] = u[..., :, :] * Diag(s[..., :]) * Transpose(v[ ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$s, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$u, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$v + Res, [{Singular values. 
The values are sorted in reverse order of magnitude, so +s[..., 0] is the largest value, s[..., 1] is the second largest, etc.}]>:$s, + Res, [{Left singular vectors.}]>:$u, + Res, [{Right singular vectors.}]>:$v ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -12778,12 +14454,12 @@ def TF_Xlog1pyOp : TF_Op<"Xlog1py", [NoSideEffect, TF_SameOperandsAndResultEleme let summary = "Returns 0 if x == 0, and x * log1p(y) otherwise, elementwise."; let arguments = (ins - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$x, - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$y + TensorOf<[TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64]>:$x, + TensorOf<[TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64]>:$y ); let results = (outs - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$z + TensorOf<[TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64]>:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -12794,12 +14470,12 @@ def TF_XlogyOp : TF_Op<"Xlogy", [NoSideEffect, ResultsBroadcastableShape, TF_Sam let summary = "Returns 0 if x == 0, and x * log(y) otherwise, elementwise."; let arguments = (ins - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$x, - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$y + TensorOf<[TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64]>:$x, + TensorOf<[TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64]>:$y ); let results = (outs - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$z + TensorOf<[TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64]>:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -12828,12 +14504,12 @@ expected to create these operators. }]; let arguments = (ins - TensorOf<[F16, F32]>:$x, - F32Tensor:$scale, - F32Tensor:$offset, - F32Tensor:$mean, - F32Tensor:$variance, - Variadic>:$side_input, + TensorOf<[TF_Float16, TF_Float32]>:$x, + TF_Float32Tensor:$scale, + TF_Float32Tensor:$offset, + TF_Float32Tensor:$mean, + TF_Float32Tensor:$variance, + Variadic>:$side_input, DefaultValuedAttr:$epsilon, DefaultValuedAttr:$exponential_avg_factor, @@ -12843,12 +14519,12 @@ expected to create these operators. ); let results = (outs - TensorOf<[F16, F32]>:$y, - F32Tensor:$batch_mean, - F32Tensor:$batch_variance, - F32Tensor:$reserve_space_1, - F32Tensor:$reserve_space_2, - F32Tensor:$reserve_space_3 + TensorOf<[TF_Float16, TF_Float32]>:$y, + TF_Float32Tensor:$batch_mean, + TF_Float32Tensor:$batch_variance, + TF_Float32Tensor:$reserve_space_1, + TF_Float32Tensor:$reserve_space_2, + TF_Float32Tensor:$reserve_space_3 ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -12893,7 +14569,8 @@ create these operators. DefaultValuedAttr:$dilations, DefaultValuedAttr:$use_cudnn_on_gpu, DefaultValuedAttr:$fused_ops, - DefaultValuedAttr:$epsilon + DefaultValuedAttr:$epsilon, + DefaultValuedAttr:$leakyrelu_alpha ); let results = (outs @@ -12930,9 +14607,9 @@ expected to create these operators. }]; let arguments = (ins - TensorOf<[BF16, F32]>:$a, - TensorOf<[BF16, F32]>:$b, - Variadic>:$args, + TensorOf<[TF_Bfloat16, TF_Float32]>:$a, + TensorOf<[TF_Bfloat16, TF_Float32]>:$b, + Variadic>:$args, DefaultValuedAttr:$transpose_a, DefaultValuedAttr:$transpose_b, @@ -12941,7 +14618,7 @@ expected to create these operators. 
); let results = (outs - TensorOf<[BF16, F32]>:$product + TensorOf<[TF_Bfloat16, TF_Float32]>:$product ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -12959,13 +14636,19 @@ Tensor of activations per table specified in the model. }]; let arguments = (ins - TF_VariantTensor:$deduplication_data, + Arg:$deduplication_data, StrAttr:$config ); let results = (outs - Variadic:$outputs + Res, [{A TensorList of embedding activations containing one Tensor per +embedding table in the model.}]>:$outputs ); TF_DerivedResultSizeAttr num_tables = TF_DerivedResultSizeAttr<0>; @@ -12991,7 +14674,7 @@ look up the program in the compilation cache. }]; let arguments = (ins - Variadic:$dynamic_shapes, + Variadic:$dynamic_shapes, StrAttr:$mlir_module, StrAttr:$metadata @@ -13034,13 +14717,13 @@ expected to create these operators. }]; let arguments = (ins - TensorOf<[F16, F32, F64]>:$x, + TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$x, StrArrayAttr:$op_names ); let results = (outs - TensorOf<[F16, F32, F64]>:$y + TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -13052,7 +14735,7 @@ A pseudo-op to represent host-side computation in an XLA program. }]; let arguments = (ins - Variadic:$inputs, + Arg, [{A list of tensors that will be sent to the host.}]>:$inputs, StrAttr:$send_key, StrAttr:$recv_key, @@ -13060,7 +14743,7 @@ A pseudo-op to represent host-side computation in an XLA program. ); let results = (outs - Variadic:$outputs + Res, [{A list of tensors that will be returned to the device.}]>:$outputs ); TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>; @@ -13073,14 +14756,15 @@ A placeholder op to receive values from a running XLA computation. }]; let arguments = (ins - TF_StrTensor:$dynamic_key, + Arg:$dynamic_key, StrAttr:$key, I64Attr:$device_ordinal ); let results = (outs - Variadic:$outputs + Res, [{A list of tensors that will be received from the XLA computation.}]>:$outputs ); TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>; @@ -13090,8 +14774,9 @@ def TF__XlaSendFromHostOp : TF_Op<"_XlaSendFromHost", []> { let summary = "A placeholder op to send values to a running XLA computation."; let arguments = (ins - Variadic:$inputs, - TF_StrTensor:$dynamic_key, + Arg, [{A list of tensors that will be sent to the XLA computation.}]>:$inputs, + Arg:$dynamic_key, StrAttr:$key, I64Attr:$device_ordinal diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td index 1755c975c23..15c0d7b10f7 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td @@ -46,7 +46,7 @@ Invariants: TODO: Make invariants more structured so that we can reference them in ops. }]; - let cppNamespace = "TF"; + let cppNamespace = "::mlir::TF"; } //===----------------------------------------------------------------------===// @@ -73,6 +73,9 @@ def TF_LayoutAgnostic : NativeOpTrait<"TF::LayoutAgnostic">; // certain state around within their implementations. def TF_CannotDuplicate : NativeOpTrait<"TF::CannotDuplicate">; +// Trait to indicate an operation cannot be constant folded. +def TF_NoConstantFold : NativeOpTrait<"TF::NoConstantFold">; + // Coefficient wise binary operation with implicit broadcasting support, for // example tf.Sub operation. 
def TF_CwiseBinary : NativeOpTrait<"TF::CwiseBinary">; @@ -108,14 +111,44 @@ class TF_ResourceBase : def TF_VariableResource : TF_ResourceBase<"Variable">; def TF_StackResource : TF_ResourceBase<"Stack">; def TF_TensorArrayResource : TF_ResourceBase<"TensorArray">; +def TF_SummaryResource : TF_ResourceBase<"Summary">; +def TF_LookupTableResource : TF_ResourceBase<"LookupTable">; +def TF_DatasetSeedGeneratorResource : TF_ResourceBase<"DatasetSeedGenerator">; +def TF_DatasetMemoryCacheResource : TF_ResourceBase<"DatasetMemoryCache">; +def TF_DatasetIteratorResource : TF_ResourceBase<"DatasetIterator">; def TF_VariableRead : MemRead; def TF_StackRead : MemRead; def TF_TensorArrayRead : MemRead; +def TF_LookupTableRead : MemRead; +def TF_DatasetSeedGeneratorRead : MemRead; +def TF_DatasetMemoryCacheRead : MemRead; +def TF_DatasetIteratorRead : MemRead; def TF_VariableWrite : MemWrite; def TF_StackWrite : MemWrite; def TF_TensorArrayWrite : MemWrite; +def TF_SummaryWrite : MemWrite; +def TF_LookupTableWrite : MemWrite; +def TF_DatasetSeedGeneratorWrite : MemWrite; +def TF_DatasetMemoryCacheWrite : MemWrite; +def TF_DatasetIteratorWrite : MemWrite; + +def TF_VariableAlloc : MemAlloc; +def TF_StackAlloc : MemAlloc; +def TF_TensorArrayAlloc : MemAlloc; +def TF_SummaryAlloc : MemAlloc; +def TF_LookupTableAlloc : MemAlloc; +def TF_DatasetSeedGeneratorAlloc : MemAlloc; +def TF_DatasetMemoryCacheAlloc : MemAlloc; +def TF_DatasetIteratorAlloc : MemAlloc; + +def TF_StackFree : MemFree; +def TF_TensorArrayFree : MemFree; +def TF_SummaryFree : MemFree; +def TF_DatasetSeedGeneratorFree : MemFree; +def TF_DatasetMemoryCacheFree : MemFree; +def TF_DatasetIteratorFree : MemFree; //===----------------------------------------------------------------------===// // TensorFlow op definitions @@ -157,118 +190,194 @@ class TF_TensorFlowType : "TensorFlow " # description # " type">, BuildableType<"getType()">; -// Any tensor element type allowed in TensorFlow ops -def TF_ElementType : Type, - "tf.dtype">; +//===----------------------------------------------------------------------===// +// Reference types -// Any TensorFlow tensor type -def TF_Tensor : TensorOf<[TF_ElementType]>; +// Float reference types +def TF_Float16Ref : TF_TensorFlowType<"HalfRef", "f16ref">; +def TF_Float32Ref : TF_TensorFlowType<"FloatRef", "f32ref">; +def TF_Float64Ref : TF_TensorFlowType<"DoubleRef", "f64ref">; +def TF_Bfloat16Ref : TF_TensorFlowType<"Bfloat16Ref", "bf16ref">; + +// Complex reference types +def TF_Complex64Ref : TF_TensorFlowType<"Complex64Ref", "complex64ref">; +def TF_Complex128Ref : TF_TensorFlowType<"Complex128Ref", "complex128ref">; + +// Integer reference types +def TF_Int8Ref : TF_TensorFlowType<"Int8Ref", "i8ref">; +def TF_Int16Ref : TF_TensorFlowType<"Int16Ref", "i16ref">; +def TF_Int32Ref : TF_TensorFlowType<"Int32Ref", "i32ref">; +def TF_Int64Ref : TF_TensorFlowType<"Int64Ref", "i64ref">; + +def TF_Uint8Ref : TF_TensorFlowType<"Uint8Ref", "ui8ref">; +def TF_Uint16Ref : TF_TensorFlowType<"Uint16Ref", "ui16ref">; +def TF_Uint32Ref : TF_TensorFlowType<"Uint32Ref", "ui32ref">; +def TF_Uint64Ref : TF_TensorFlowType<"Uint64Ref", "ui64ref">; + +// Quantized reference types +def TF_Qint8Ref : TF_TensorFlowType<"Qint8Ref", "qint8ref">; +def TF_Qint16Ref : TF_TensorFlowType<"Qint16Ref", "qint16ref">; +def TF_Qint32Ref : TF_TensorFlowType<"Qint32Ref", "qint32ref">; +def TF_Quint8Ref : TF_TensorFlowType<"Quint8Ref", "quint8ref">; +def TF_Quint16Ref : TF_TensorFlowType<"Quint16Ref", "quint16ref">; + +// Other reference types 
+def TF_BoolRef : TF_TensorFlowType<"BoolRef", "boolref">; +def TF_ResourceRef : TF_TensorFlowType<"ResourceRef", "resourceref">; +def TF_StrRef : TF_TensorFlowType<"StringRef", "stringref">; +def TF_VariantRef : TF_TensorFlowType<"VariantRef", "variantref">; //===----------------------------------------------------------------------===// -// Integer types +// Integer types (including corresponding reference types) -def TF_I32Or64 : SignlessIntOfWidths<[32, 64]>; +def TF_Bool : AnyTypeOf<[I<1>, TF_BoolRef], "bool">; -def TF_I32OrI64Tensor : TensorOf<[TF_I32Or64]>; +def TF_Int8 : AnyTypeOf<[I8, TF_Int8Ref], "8-bit integer">; +def TF_Int16 : AnyTypeOf<[I16, TF_Int16Ref], "16-bit integer">; +def TF_Int32 : AnyTypeOf<[I32, TF_Int32Ref], "32-bit integer">; +def TF_Int64 : AnyTypeOf<[I64, TF_Int64Ref], "64-bit integer">; +def TF_I32OrI64 : AnyTypeOf<[I32, I64, TF_Int32Ref, TF_Int64Ref], + "32/64-bit signed integer">; -def TF_Uint8 : UI<8>; -def TF_Uint8Tensor : TensorOf<[TF_Uint8]>; - -def TF_Uint16 : UI<16>; -def TF_Uint16Tensor : TensorOf<[TF_Uint16]>; - -def TF_Uint32 : UI<32>; -def TF_Uint32Tensor : TensorOf<[TF_Uint32]>; - -def TF_Uint64 : UI<64>; -def TF_Uint64Tensor : TensorOf<[TF_Uint64]>; +def TF_Uint8 : AnyTypeOf<[UI<8>, TF_Uint8Ref], "8-bit unsigned integer">; +def TF_Uint16 : AnyTypeOf<[UI<16>, TF_Uint16Ref], "16-bit unsigned integer">; +def TF_Uint32 : AnyTypeOf<[UI<32>, TF_Uint32Ref], "32-bit unsigned integer">; +def TF_Uint64 : AnyTypeOf<[UI<64>, TF_Uint64Ref], "64-bit unsigned integer">; // Any unsigned integer type -def TF_UInt : UnsignedIntOfWidths<[8, 16, 32, 64]>; +def TF_UInt : AnyTypeOf<[TF_Uint8, TF_Uint16, TF_Uint32, TF_Uint64], + "unsigned integer">; // Any signed integer type -def TF_SInt : SignlessIntOfWidths<[8, 16, 32, 64]>; +def TF_SInt : AnyTypeOf<[TF_Int8, TF_Int16, TF_Int32, TF_Int64], + "signed integer">; // Any integer type -def TF_Int : AnyTypeOf<[TF_SInt, TF_UInt]>; +def TF_Int : AnyTypeOf<[TF_SInt, TF_UInt], "integer">; + +// Tensor types +def TF_BoolTensor : TensorOf<[TF_Bool]>; -// Any integer tensor types def TF_IntTensor : TensorOf<[TF_Int]>; +def TF_Int8Tensor : TensorOf<[TF_Int8]>; +def TF_Int16Tensor : TensorOf<[TF_Int16]>; +def TF_Int32Tensor : TensorOf<[TF_Int32]>; +def TF_Int64Tensor : TensorOf<[TF_Int64]>; +def TF_I32OrI64Tensor : TensorOf<[TF_I32OrI64]>; + +def TF_Uint8Tensor : TensorOf<[TF_Uint8]>; +def TF_Uint16Tensor : TensorOf<[TF_Uint16]>; +def TF_Uint32Tensor : TensorOf<[TF_Uint32]>; +def TF_Uint64Tensor : TensorOf<[TF_Uint64]>; //===----------------------------------------------------------------------===// -// Quantized types -def TF_Qint8 : TF_TensorFlowType<"Qint8", "qint8">; -def TF_Qint16 : TF_TensorFlowType<"Qint16", "qint16">; -def TF_Qint32 : TF_TensorFlowType<"Qint32", "qint32">; -def TF_Quint8 : TF_TensorFlowType<"Quint8", "quint8">; -def TF_Quint16 : TF_TensorFlowType<"Quint16", "quint16">; +// Quantized types (including corresponding reference types) + +def TF_Qint8 : AnyTypeOf< + [TF_TensorFlowType<"Qint8", "qint8">, TF_Qint8Ref], + "8-bit quantized integer">; +def TF_Qint16 : AnyTypeOf< + [TF_TensorFlowType<"Qint16", "qint16">, TF_Qint16Ref], + "16-bit quantized integer">; +def TF_Qint32 : AnyTypeOf< + [TF_TensorFlowType<"Qint32", "qint32">, TF_Qint32Ref], + "32-bit quantized integer">; +def TF_Quint8 : AnyTypeOf< + [TF_TensorFlowType<"Quint8", "quint8">, TF_Quint8Ref], + "8-bit quantized unsigned integer">; +def TF_Quint16 : AnyTypeOf< + [TF_TensorFlowType<"Quint16", "quint16">, TF_Quint16Ref], + "16-bit quantized unsigned 
integer">; // Any quantized type -def TF_AnyQuantized : AnyTypeOf<[TF_Qint8, TF_Qint16, TF_Qint32, TF_Quint8, - TF_Quint16]>; -//===----------------------------------------------------------------------===// -// Floating-point types - -def TF_F32Or64 : FloatOfWidths<[32, 64]>; - -def TF_F32OrF64Tensor : TensorOf<[TF_F32Or64]>; - -// Any floating-point tensor types -def TF_FpTensor : TensorOf<[AnyFloat]>; +def TF_Quantized : AnyTypeOf< + [TF_Qint8, TF_Qint16, TF_Qint32, TF_Quint8, TF_Quint16], "quantized">; //===----------------------------------------------------------------------===// -// Complex types +// Floating-point types (including corresponding reference types) + +def TF_Float16 : AnyTypeOf<[F16, TF_Float16Ref], "16-bit float">; +def TF_Float32 : AnyTypeOf<[F32, TF_Float32Ref], "32-bit float">; +def TF_Float64 : AnyTypeOf<[F64, TF_Float64Ref], "64-bit float">; +def TF_Bfloat16 : AnyTypeOf<[BF16, TF_Bfloat16Ref], "bfloat16">; + +def TF_F32OrF64 : AnyTypeOf<[TF_Float32, TF_Float64], "32/64-bit float">; + +def TF_Float : AnyTypeOf< + [TF_Float16, TF_Float32, TF_Float64, TF_Bfloat16, + TF_Float16Ref, TF_Float32Ref, TF_Float64Ref, TF_Bfloat16Ref], + "floating-point">; + +// Tensor types +def TF_FloatTensor : TensorOf<[TF_Float]>; +def TF_F32OrF64Tensor : TensorOf<[TF_F32OrF64]>; +def TF_Float16Tensor : TensorOf<[TF_Float16]>; +def TF_Float32Tensor : TensorOf<[TF_Float32]>; +def TF_Float64Tensor : TensorOf<[TF_Float64]>; +def TF_Bfloat16Tensor : TensorOf<[TF_Bfloat16]>; + +//===----------------------------------------------------------------------===// +// Complex types (including corresponding reference types) // TODO(suderman): Remove TF_Complex64 and use a standard ops declaration, along // with the associated cleanup. -def TF_Complex64 : Complex>; -def TF_Complex64Tensor : TensorOf<[TF_Complex64]>; +def TF_Complex64 : AnyTypeOf<[Complex>, TF_Complex64Ref], + "64-bit complex">; +def TF_Complex128 : AnyTypeOf<[Complex>, TF_Complex128Ref], + "128-bit complex">; +def TF_Complex : AnyTypeOf<[TF_Complex64, TF_Complex128], "complex">; -def TF_Complex128 : Complex>; +// Tensor types +def TF_ComplexTensor : TensorOf<[TF_Complex]>; +def TF_Complex64Tensor : TensorOf<[TF_Complex64]>; def TF_Complex128Tensor : TensorOf<[TF_Complex128]>; -def TF_AnyComplex : AnyTypeOf<[TF_Complex64, TF_Complex128], - "64/128-bit complex type">; - -def TF_ComplexTensor : TensorOf<[TF_AnyComplex]>; - //===----------------------------------------------------------------------===// -// String/variant/resource types +// String/variant/resource types (including corresponding reference types) -def TF_Str : TF_TensorFlowType<"String", "string">; +def TF_Str : AnyTypeOf< + [TF_TensorFlowType<"String", "str">, TF_StrRef], "string">; def TF_StrTensor : TensorOf<[TF_Str]>; -def TF_Variant : TF_TensorFlowType<"Variant", "variant">; +def TF_Variant : AnyTypeOf< + [TF_TensorFlowType<"Variant", "var">, TF_VariantRef], "variant">; def TF_VariantTensor : TensorOf<[TF_Variant]>; -def TF_Resource : TF_TensorFlowType<"Resource", "resource">; +def TF_Resource : AnyTypeOf< + [TF_TensorFlowType<"Resource", "res">, TF_ResourceRef], "resource">; def TF_ResourceTensor : TensorOf<[TF_Resource]>; //===----------------------------------------------------------------------===// // Multi-category type constraints -def TF_IntOrF32OrF64Tensor: TensorOf<[TF_Int, TF_F32Or64]>; +def TF_IntOrF32OrF64Tensor: TensorOf<[TF_Int, TF_F32OrF64]>; +def TF_FpOrI32OrI64Tensor : TensorOf<[TF_Float, TF_I32OrI64]>; +def TF_IntOrFpTensor : TensorOf<[TF_Int, 
TF_Float]>; +def TF_SintOrFpTensor : TensorOf<[TF_SInt, TF_Float]>; +def TF_FpOrComplexTensor : TensorOf<[TF_Float, TF_Complex]>; -def TF_FpOrI32OrI64Tensor : TensorOf<[AnyFloat, TF_I32Or64]>; +def TF_Number : AnyTypeOf< + [TF_Int, TF_Float, TF_Quantized, TF_Complex], "number">; +def TF_NumberTensor : TensorOf<[TF_Number]>; -// Any integer or floating-point tensor types -def TF_IntOrFpTensor : TensorOf<[TF_Int, AnyFloat]>; +def TF_NumberNotQuantizedOrStr : + AnyTypeOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Str]>; +def TF_NumberNotQuantizedOrStrTensor : TensorOf<[TF_NumberNotQuantizedOrStr]>; -def TF_SintOrFpTensor : TensorOf<[TF_SInt, AnyFloat]>; +//===----------------------------------------------------------------------===// +// Tensor and tensor element types -def TF_FpOrComplexTensor : TensorOf<[AnyFloat, TF_AnyComplex]>; +// Any tensor element type allowed in TensorFlow ops +// (see https://www.tensorflow.org/api_docs/python/tf/dtypes/DType) +def TF_ElementType : Type, + "tf.dtype">; -def TF_AnyNumber : AnyTypeOf<[TF_Int, AnyFloat, TF_AnyQuantized, TF_AnyComplex], - "number">; - -def TF_NumberTensor : TensorOf<[TF_AnyNumber]>; - -def TF_NumberOrStr : AnyTypeOf<[AnyFloat, TF_SInt, TF_AnyComplex, TF_Uint8, TF_Str]>; -def TF_NumberOrStrTensor : TensorOf<[TF_NumberOrStr]>; +// Any TensorFlow tensor type +def TF_Tensor : TensorOf<[TF_ElementType]>; //===----------------------------------------------------------------------===// // TensorFlow attribute definitions @@ -423,7 +532,7 @@ class TF_DerivedResultShapeListAttr : DerivedAttr< // A derived attribute that returns the shape of the first result type. def TF_DerivedResultShapeAttr : DerivedAttr<"ShapedType", "return (*getOperation()->result_type_begin()).cast();", - [{ TypeAttr::get($_self) }]>; + [{ mlir::TF::ShapeAttr::get($_ctx, $_self) }]>; // A derived attribute that returns the element type of the tensor held by a // named resource-type operand or result. @@ -443,7 +552,7 @@ class TF_DerivedOperandOrResultHandleShapeAttr : DerivedAttr< " .cast();\n" "assert(!resource_type.getSubtypes().empty() && \"unknown shape\");\n" "return resource_type.getSubtypes().begin()->cast();", - [{ TypeAttr::get($_self) }]>; + [{ mlir::TF::ShapeAttr::get($_ctx, $_self) }]>; def TF_IntTypeAttr : TypeAttrBase<"IntegerType", "integer type"> { let returnType = "Type"; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h index ec1f748367d..3a6a9336a24 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h @@ -15,14 +15,121 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_ + +#include + +#include "llvm/ADT/DenseMapInfo.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h" +#include "tensorflow/core/framework/resource_mgr.h" namespace mlir { namespace TF { + +//===----------------------------------------------------------------------===// +// TensorFlow Contraction Fusion. 
+//===----------------------------------------------------------------------===// + +struct ContractionFusion { + explicit ContractionFusion( + StringRef output_kernel, ArrayRef additional_arguments = {}, + ArrayRef additional_attributes = {}) + : output_kernel(output_kernel.str()), + additional_arguments(additional_arguments.begin(), + additional_arguments.end()), + additional_attributes(additional_attributes.begin(), + additional_attributes.end()) {} + + // Name of the output kernel implementing the contraction fusion. + std::string output_kernel; + + // Indices of additional arguments that will be forwarded to the fused + // operation (e.g. forward bias vector if fusing BiasAdd operation). + SmallVector additional_arguments; + + // Add additional attributes to the fused node. + SmallVector additional_attributes; +}; + +//===----------------------------------------------------------------------===// +// TensorFlow Resource Handles. +//===----------------------------------------------------------------------===// + +inline bool IsResourceHandleAnonymous(StringRef name) { + return name == ::tensorflow::ResourceHandle::ANONYMOUS_NAME; +} + +// Helper struct representing an identifier for a resource handle. For resource +// handles created explicitly and shared across resource allocator ops, +// `container`, `name`, and `device` can be set. If an resource handle is tied +// to an instance of an operation (e.g. TensorFlow runtime operation caching), +// `op` can be set instead. +struct ResourceHandle { + ResourceHandle(StringRef container, StringRef name, StringRef device, + Operation* op) + : container(container), name(name), device(device), op(op) {} + + bool operator==(const ResourceHandle& rhs) const { + return container == rhs.container && name == rhs.name && + device == rhs.device && op == rhs.op; + } + + // Make ResourceHandle hashable. + friend ::llvm::hash_code hash_value(const ResourceHandle& resource_handle); + + std::string container; + std::string name; + std::string device; + Operation* op = nullptr; +}; + +// Make ResourceHandle hashable. +inline ::llvm::hash_code hash_value(const ResourceHandle& resource_handle) { + return ::llvm::hash_combine(resource_handle.container, resource_handle.name, + resource_handle.device, resource_handle.op); +} + +// Helper struct holding a resource handle value and unique id associated to the +// resource handle. 
+struct ResourceHandleValueAndId { + ResourceHandleValueAndId(Value value, int64_t id) : value(value), id(id) {} + + Value value; + int64_t id = -1; +}; + #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h.inc" } // namespace TF } // namespace mlir +namespace llvm { +template <> +struct DenseMapInfo { + static mlir::TF::ResourceHandle getEmptyKey() { + return {/*container=*/"", /*name=*/"", /*device=*/"", + /*op=*/DenseMapInfo::getEmptyKey()}; + } + + static mlir::TF::ResourceHandle getTombstoneKey() { + return {/*container=*/"", /*name=*/"", /*device=*/"", + /*op=*/DenseMapInfo::getTombstoneKey()}; + } + + static unsigned getHashValue( + const mlir::TF::ResourceHandle& resource_handle) { + return mlir::TF::hash_value(resource_handle); + } + + static bool isEqual(const mlir::TF::ResourceHandle& lhs, + const mlir::TF::ResourceHandle& rhs) { + return lhs == rhs; + } +}; +} // namespace llvm + #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td index 3743bdda043..1ed30c89a77 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td @@ -21,7 +21,7 @@ limitations under the License. include "mlir/IR/OpBase.td" //===----------------------------------------------------------------------===// -// TensorFlow interfaces +// TensorFlow Layout Optimization Interfaces. //===----------------------------------------------------------------------===// def TF_LayoutSensitiveInterface : OpInterface<"LayoutSensitiveInterface"> { @@ -104,4 +104,48 @@ def TF_FoldOperandsTransposeInterface : OpInterface<"FoldOperandsTransposeInterf }]; } +//===----------------------------------------------------------------------===// +// TensorFlow Contraction Fusion Interfaces. +//===----------------------------------------------------------------------===// + +def TF_ContractionFusableInterface : OpInterface<"ContractionFusableInterface"> { + let description = [{ + A contraction fusable operation is one that can be fused into the output of + a tensor contraction (MatMul, Conv2D, etc...) operation. + + For example all element wise operations are trivially contraction fusable. + }]; + + let methods = [ + InterfaceMethod< + [{Returns contraction fusion if the operation satisfies all the fusion + requirements. Otherwise returns empty optional.}], + "Optional", "GetContractionFusion", (ins) + >, + ]; +} + +//===----------------------------------------------------------------------===// +// TensorFlow Resource Handle Interfaces. +//===----------------------------------------------------------------------===// + +def TF_ResourceHandleAllocatorInterface : OpInterface<"ResourceHandleAllocatorInterface"> { + let description = [{ + A resource handle allocator operation is one that creates a resource handle, + or looks up and reuses an existing resource handle. + }]; + + let methods = [ + InterfaceMethod< + /*desc=*/[{Returns the resource handle value and unique id associated with + the resource handle. 
If a resource handle is reused, then an + existing id will be returned.}], + /*retTy=*/"ResourceHandleValueAndId", + /*methodName=*/"GetResourceHandleValueAndId", + /*args=*/(ins "llvm::SmallDenseMap&":$resource_handle_id_map, + "int64_t&":$next_id) + >, + ]; +} + #endif // TF_OP_INTERFACES diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index 737442d5f8c..634004038d0 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -233,6 +233,10 @@ TensorFlowDialect::TensorFlowDialect(MLIRContext *context) #define GET_OP_LIST #include "tensorflow/compiler/mlir/tensorflow/ir/tf_all_ops.cc.inc" >(); + addOperations< +#define GET_OP_LIST +#include "tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc.inc" + >(); addTypes< #define HANDLE_TF_TYPE(tftype, enumerant, name) tftype##Type, #define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h index 3169f7fba8d..9ebd59007e3 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h @@ -43,6 +43,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.h" namespace mlir { namespace TF { @@ -112,8 +113,7 @@ class TensorFlowDialect : public Dialect { // same interface. template void addOperations() { - (void)std::initializer_list{ - 0, (addOperation(AbstractOperation::get(*this)), 0)...}; + Dialect::addOperations(); } using ConstantFoldHook = LogicalResult (*)(Operation *, ArrayRef, diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index db0a97d4b96..c814153eb43 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -116,6 +116,24 @@ An n-way switch statement, implementing the following: let verifier = [{ return Verify(*this); }]; + + + let extraClassDeclaration = [{ + int num_branches() { return branches().size(); } + + // Gets function corresponding branch # `index`. + FuncOp branch_function(int index) { + auto flat_sym_ref = branches()[index].cast(); + return SymbolTable::lookupNearestSymbolFrom(*this, flat_sym_ref); + } + + // Gets all branch functions. 
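+    // Resolves each entry in `branches()` through branch_function(index) and
+    // appends the resulting FuncOps to `functions`, preserving branch order.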
+ void get_branch_functions(SmallVectorImpl &functions) { + functions.reserve(num_branches()); + for (int idx : llvm::seq(0, num_branches())) + functions.push_back(branch_function(idx)); + } + }]; } def TF_CaseRegionOp : TF_Op<"CaseRegion", @@ -160,6 +178,9 @@ An n-way switch statement, implementing the following: let verifier = [{ return Verify(*this); }]; + + let hasCanonicalizer = 1; + } // In MLIR, the TensorFlow tensor value is represented as an ElementsAttr, with @@ -206,12 +227,12 @@ source_target_pairs=`[[0,1],[1,2],[2,3],[3,0]]` gets the outputs: }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, I32Tensor:$source_target_pairs ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -231,7 +252,7 @@ element_shape: a shape compatible with that of elements in the list. let arguments = (ins TF_I32OrI64Tensor:$element_shape, - I32Tensor:$max_num_elements + TF_Int32Tensor:$max_num_elements ); } @@ -305,12 +326,12 @@ else_branch: A function that takes 'inputs' and returns a list of let extraClassDeclaration = [{ // Get the then branch function. - FuncOp then_func() { + FuncOp then_function() { return SymbolTable::lookupNearestSymbolFrom(*this, then_branch()); } // Get the else branch function. - FuncOp else_func() { + FuncOp else_function() { return SymbolTable::lookupNearestSymbolFrom(*this, else_branch()); } }]; @@ -369,6 +390,12 @@ else_branch: A region that computes the outputs of the op if cond = false. 
return Verify(*this); }]; + let builders = [OpBuilder< + "OpBuilder &builder, OperationState &result, TypeRange resultTypes, ValueRange operands, llvm::ArrayRef<::mlir::NamedAttribute> attributes, unsigned numRegions", [{ + assert(numRegions == 2u && "mismatched number of regions"); + build(builder, result, resultTypes, operands, attributes); + }]>]; + let hasCanonicalizer = 1; } @@ -424,7 +451,7 @@ def TF_ParseExampleOp : TF_Op<"ParseExample", TF_StrTensor:$names, Variadic:$sparse_keys, Variadic:$dense_keys, - Variadic>:$dense_defaults, + Variadic>:$dense_defaults, TF_ShapeAttrArray:$dense_shapes, I32ElementsAttr:$result_segment_sizes, @@ -432,10 +459,10 @@ def TF_ParseExampleOp : TF_Op<"ParseExample", ); let results = (outs - Variadic:$sparse_indices, // len(sparse_types) - Variadic>:$sparse_values, // len(sparse_types) - Variadic:$sparse_shapes, // len(sparse_types) - Variadic>:$dense_values // len(Tdense) + Variadic:$sparse_indices, // len(sparse_types) + Variadic>:$sparse_values, // len(sparse_types) + Variadic:$sparse_shapes, // len(sparse_types) + Variadic>:$dense_values // len(Tdense) ); TF_DerivedOperandSizeAttr Nsparse = TF_DerivedOperandSizeAttr<2>; @@ -459,7 +486,7 @@ def TF_ParseExampleV2Op : TF_Op<"ParseExampleV2", TF_StrTensor:$sparse_keys, TF_StrTensor:$dense_keys, TF_StrTensor:$ragged_keys, - Variadic>:$dense_defaults, + Variadic>:$dense_defaults, Confined]>:$num_sparse, TF_ShapeAttrArray:$dense_shapes, @@ -467,13 +494,13 @@ def TF_ParseExampleV2Op : TF_Op<"ParseExampleV2", ); let results = (outs - Variadic:$sparse_indices, // len(sparse_types) - Variadic>:$sparse_values, // len(sparse_types) - Variadic:$sparse_shapes, // len(sparse_types) - Variadic>:$dense_values, // len(Tdense) - Variadic>:$ragged_values, // len(ragged_value_types) + Variadic:$sparse_indices, // len(sparse_types) + Variadic>:$sparse_values, // len(sparse_types) + Variadic:$sparse_shapes, // len(sparse_types) + Variadic>:$dense_values, // len(Tdense) + Variadic>:$ragged_values, // len(ragged_value_types) // = len(ragged_split_types) - Variadic>:$ragged_row_splits // len(ragged_split_types) + Variadic>:$ragged_row_splits // len(ragged_split_types) // = len(ragged_value_types) ); @@ -570,36 +597,6 @@ def TF_PlaceholderWithDefaultOp : TF_Op<"PlaceholderWithDefault", [NoSideEffect] DerivedAttr shape = TF_DerivedResultShapeAttr; } -def TF_SparseMatMulOp : TF_Op<"SparseMatMul", [NoSideEffect]> { - let summary = [{ -SparseMatMul is MatMul with hints on the sparseness of the matrices. - }]; - - let description = [{ -Similar to MatMul, with a_is_sparse and b_is_sparse indicating whether a and b -are sparse matrices. - }]; - - let arguments = (ins - TensorOf<[BF16, F32]>:$a, - TensorOf<[BF16, F32]>:$b, - - DefaultValuedAttr:$a_is_sparse, - DefaultValuedAttr:$b_is_sparse, - - DefaultValuedAttr:$transpose_a, - DefaultValuedAttr:$transpose_b - ); - - let results = (outs - TensorOf<[F32]>:$product - ); - - TF_DerivedOperandTypeAttr Ta = TF_DerivedOperandTypeAttr<0>; - TF_DerivedOperandTypeAttr Tb = TF_DerivedOperandTypeAttr<1>; -} - - def TF_StatefulPartitionedCallOp : TF_Op<"StatefulPartitionedCall", [CallOpInterface]> { let summary = @@ -691,18 +688,18 @@ body: A function that takes a list of tensors and returns another let extraClassDeclaration = [{ // Get the condition function. - FuncOp cond_func() { + FuncOp cond_function() { return SymbolTable::lookupNearestSymbolFrom(*this, cond()); } // Get the body function. 
- FuncOp body_func() { + FuncOp body_function() { return SymbolTable::lookupNearestSymbolFrom(*this, body()); } }]; } -def TL_WhileRegionOp : TF_Op<"WhileRegion", +def TF_WhileRegionOp : TF_Op<"WhileRegion", [DeclareOpInterfaceMethods, SingleBlockImplicitTerminator<"YieldOp">]> { let summary = "while operation"; @@ -765,7 +762,7 @@ element_dtype: the desired type of elements in the list. let arguments = (ins TF_I32OrI64Tensor:$element_shape, - I32Tensor:$num_elements + TF_Int32Tensor:$num_elements ); } @@ -799,7 +796,7 @@ This operation holds the metadata common to operations of a `tpu.replicate()` co let results = (outs); } -def TF_VarHandleOp : TF_Op<"VarHandleOp", []> { +def TF_VarHandleOp : TF_Op<"VarHandleOp", [TF_ResourceHandleAllocatorInterface]> { let summary = "Creates a handle to a Variable resource from its name."; let description = [{ @@ -821,13 +818,20 @@ Example: ); let results = (outs - TF_ResourceTensor:$resource + Res:$resource ); TF_DerivedOperandOrResultHandleTypeAttr dtype = TF_DerivedOperandOrResultHandleTypeAttr<"resource">; TF_DerivedOperandOrResultHandleShapeAttr shape = TF_DerivedOperandOrResultHandleShapeAttr<"resource">; + + let extraClassDeclaration = [{ + // TF_ResourceHandleAllocatorInterface: + ResourceHandleValueAndId GetResourceHandleValueAndId( + llvm::SmallDenseMap &resource_handle_id_map, + int64_t &next_id); + }]; } // Multiple variadic operands with different sizes are not supported by the @@ -986,8 +990,8 @@ Creates a dataset that batches `batch_size` elements from `input_dataset`. let arguments = (ins TF_VariantTensor:$input_dataset, - I64Tensor:$batch_size, - I1Tensor:$drop_remainder, + TF_Int64Tensor:$batch_size, + TF_BoolTensor:$drop_remainder, DefaultValuedAttr:$parallel_copy, Confined]>:$output_types, @@ -1036,9 +1040,9 @@ to `batch_size * num_parallel_batches` copies of `f` in parallel. let arguments = (ins TF_VariantTensor:$input_dataset, Variadic:$other_arguments, - I64Tensor:$batch_size, - I64Tensor:$num_parallel_calls, - I1Tensor:$drop_remainder, + TF_Int64Tensor:$batch_size, + TF_Int64Tensor:$num_parallel_calls, + TF_BoolTensor:$drop_remainder, SymbolRefAttr:$f, Confined]>:$output_types, @@ -1066,7 +1070,7 @@ def TF_ParallelMapDatasetOp : TF_Op<"ParallelMapDataset", [NoSideEffect]> { let arguments = (ins TF_VariantTensor:$input_dataset, Variadic:$other_arguments, - I32Tensor:$num_parallel_calls, + TF_Int32Tensor:$num_parallel_calls, SymbolRefAttr:$f, Confined]>:$output_types, @@ -1148,11 +1152,11 @@ This function is faster and numerically stabler than `bessel_i0(x)`. }]; let arguments = (ins - TF_FpTensor:$x + TF_FloatTensor:$x ); let results = (outs - TF_FpTensor:$y + TF_FloatTensor:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -1169,11 +1173,11 @@ This function is faster and numerically stabler than `bessel_i1(x)`. 
}]; let arguments = (ins - TF_FpTensor:$x + TF_FloatTensor:$x ); let results = (outs - TF_FpTensor:$y + TF_FloatTensor:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -1184,7 +1188,7 @@ def TF_TPUPartitionedCallOp : TF_Op<"TPUPartitionedCall", [CallOpInterface]> { let arguments = (ins Variadic:$args, - I32Tensor:$device_ordinal, + TF_Int32Tensor:$device_ordinal, SymbolRefAttr:$f, DefaultValuedAttr:$autotuner_thresh @@ -1213,63 +1217,6 @@ def TF_TPUPartitionedCallOp : TF_Op<"TPUPartitionedCall", [CallOpInterface]> { let verifier = [{ return VerifyPartitionedCall(*this); }]; } -class TF_FusedBatchNormOpBase : TF_Op { - let summary = "Batch normalization."; - - let description = [{ -Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". -The size of 1D Tensors matches the dimension C of the 4D Tensors. - }]; - - let arguments = (ins - TensorOf<[BF16, F16, F32]>:$x, - F32Tensor:$scale, - F32Tensor:$offset, - F32Tensor:$mean, - F32Tensor:$variance, - - DefaultValuedAttr:$epsilon, - DefaultValuedAttr:$exponential_avg_factor, - DefaultValuedAttr:$data_format, - DefaultValuedAttr:$is_training - ); - - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; - TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>; - - let extraClassDeclaration = [{ - // TF_FoldOperandsTransposeInterface: - SmallVector GetLayoutDependentArgs() { return {0}; } - SmallVector GetLayoutDependentResults() { return {0}; } - LogicalResult FoldOperandsPermutation(ArrayRef permutation); - - // TF_LayoutSensitiveInterface: - StringRef GetOptimalLayout(const RuntimeDevices& devices); - LogicalResult UpdateDataFormat(StringRef data_format); - }]; -} - -def TF_FusedBatchNormV2Op : TF_FusedBatchNormOpBase<"FusedBatchNormV2"> { - let results = (outs - TensorOf<[BF16, F16, F32]>:$y, - F32Tensor:$batch_mean, - F32Tensor:$batch_variance, - F32Tensor:$reserve_space_1, - F32Tensor:$reserve_space_2 - ); -} - -def TF_FusedBatchNormV3Op : TF_FusedBatchNormOpBase<"FusedBatchNormV3"> { - let results = (outs - TensorOf<[BF16, F16, F32]>:$y, - F32Tensor:$batch_mean, - F32Tensor:$batch_variance, - F32Tensor:$reserve_space_1, - F32Tensor:$reserve_space_2, - F32Tensor:$reserve_space_3 - ); -} - def TF_BatchFunctionOp : TF_Op<"BatchFunction", [AttrSizedOperandSegments]> { let summary = [{ Batches all the inputs tensors to the computation done by the function. @@ -1341,4 +1288,649 @@ must be a Tensor or a list/tuple of Tensors. TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>; } +def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary, TF_LayoutAgnostic, TF_SameOperandsAndResultElementTypeResolveRef]>, + WithBroadcastableBinOpBuilder { + let summary = "Returns x + y element-wise."; + + let description = [{ +*NOTE*: `Add` supports broadcasting. `AddN` does not. 
More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + }]; + + let arguments = (ins + TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Uint32]>:$x, + TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Uint32]>:$y + ); + + let results = (outs + TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Uint32]>:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + + let hasCanonicalizer = 1; + + let hasFolder = 1; +} + +def TF_DivNoNanOp : TF_Op<"DivNoNan", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>, + WithBroadcastableBinOpBuilder { + let summary = "Returns 0 if the denominator is zero."; + + let description = [{ +*NOTE*: `DivNoNan` supports broadcasting. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + }]; + + let arguments = (ins + TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Complex]>:$x, + TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Complex]>:$y + ); + + let results = (outs + TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Complex]>:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_MaximumOp : TF_Op<"Maximum", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>, + WithBroadcastableBinOpBuilder { + let summary = "Returns the max of x and y (i.e. x > y ? x : y) element-wise."; + + let description = [{ +*NOTE*: `Maximum` supports broadcasting. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + }]; + + let arguments = (ins + TensorOf<[TF_Float, TF_Int16, TF_Int32, TF_Int64, TF_Uint8]>:$x, + TensorOf<[TF_Float, TF_Int16, TF_Int32, TF_Int64, TF_Uint8]>:$y + ); + + let results = (outs + TensorOf<[TF_Float, TF_Int16, TF_Int32, TF_Int64, TF_Uint8]>:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_RealDivOp : TF_Op<"RealDiv", [NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary]>, + WithBroadcastableBinOpBuilder { + let summary = "Returns x / y element-wise for real types."; + + let description = [{ +If `x` and `y` are reals, this will return the floating-point division. + +*NOTE*: `Div` supports broadcasting. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + }]; + + let arguments = (ins + TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint16, TF_Uint8]>:$x, + TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint16, TF_Uint8]>:$y + ); + + let results = (outs + TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint16, TF_Uint8]>:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + + let hasCanonicalizer = 1; + + let hasFolder = 1; +} + +def TF_AddOp : TF_Op<"Add", [NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic, TF_SameOperandsAndResultElementTypeResolveRef]>, + WithBroadcastableBinOpBuilder { + let summary = "Returns x + y element-wise."; + + let description = [{ +*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + +Given two input tensors, the `tf.add` operation computes the sum for every element in the tensor. + +Both input and output have a range `(-inf, inf)`. 
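+For example, broadcasting lets a rank-1 tensor and a scalar be added directly:
+`[1, 2] + 3` produces `[4, 5]`.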
+ }]; + + let arguments = (ins + TensorOf<[TF_NumberNotQuantizedOrStr]>:$x, + TensorOf<[TF_NumberNotQuantizedOrStr]>:$y + ); + + let results = (outs + TensorOf<[TF_NumberNotQuantizedOrStr]>:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + + let hasCanonicalizer = 1; +} + +def TF_StatefulStandardNormalV2Op : TF_Op<"StatefulStandardNormalV2", []> { + let summary = "Outputs random values from a normal distribution."; + + let description = [{ +The generated values will have mean 0 and standard deviation 1. + }]; + + let arguments = (ins + Arg:$resource, + TF_Int64Tensor:$algorithm, + TF_I32OrI64Tensor:$shape + ); + + let results = (outs + TF_FloatTensor:$output + ); + + TF_DerivedOperandTypeAttr shape_dtype = TF_DerivedOperandTypeAttr<2>; + TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; +} + +def TF_StatefulTruncatedNormalOp : TF_Op<"StatefulTruncatedNormal", []> { + let summary = "Outputs random values from a truncated normal distribution."; + + let description = [{ +The generated values follow a normal distribution with mean 0 and standard +deviation 1, except that values whose magnitude is more than 2 standard +deviations from the mean are dropped and re-picked. + }]; + + let arguments = (ins + Arg:$resource, + TF_Int64Tensor:$algorithm, + TF_I32OrI64Tensor:$shape + ); + + let results = (outs + TF_FloatTensor:$output + ); + + TF_DerivedOperandTypeAttr shape_dtype = TF_DerivedOperandTypeAttr<2>; + TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; +} + +def TF_StatefulUniformOp : TF_Op<"StatefulUniform", []> { + let summary = "Outputs random values from a uniform distribution."; + + let description = [{ +The generated values follow a uniform distribution in the range `[0, 1)`. The +lower bound 0 is included in the range, while the upper bound 1 is excluded. + }]; + + let arguments = (ins + Arg:$resource, + TF_Int64Tensor:$algorithm, + TF_I32OrI64Tensor:$shape + ); + + let results = (outs + TF_FloatTensor:$output + ); + + TF_DerivedOperandTypeAttr shape_dtype = TF_DerivedOperandTypeAttr<2>; + TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; +} + +def TF_StatefulUniformFullIntOp : TF_Op<"StatefulUniformFullInt", []> { + let summary = "Outputs random integers from a uniform distribution."; + + let description = [{ +The generated values are uniform integers covering the whole range of `dtype`. + }]; + + let arguments = (ins + Arg:$resource, + TF_Int64Tensor:$algorithm, + TF_I32OrI64Tensor:$shape + ); + + let results = (outs + TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$output + ); + + TF_DerivedOperandTypeAttr shape_dtype = TF_DerivedOperandTypeAttr<2>; + TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; +} + +// TODO(lyandy): Investigate supported dtypes (`minval`, `maxval`, `output`) for +// `tf.StatefulUniformInt`. tf2xla kernels support i32, i64, ui32, and ui64 +// while TensorFlow CPU/GPU kernels only support i32 and i64. +def TF_StatefulUniformIntOp : TF_Op<"StatefulUniformInt", []> { + let summary = "Outputs random integers from a uniform distribution."; + + let description = [{ +The generated values are uniform integers in the range `[minval, maxval)`. +The lower bound `minval` is included in the range, while the upper bound +`maxval` is excluded. + +The random integers are slightly biased unless `maxval - minval` is an exact +power of two. The bias is small for values of `maxval - minval` significantly +smaller than the range of the output (either `2^32` or `2^64`). 
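+For example, with `minval = 0` and `maxval = 3`, a 64-bit generator produces
+2^64 raw values that cannot be split into three equally sized buckets, so one
+output is marginally more likely than the others; the bias is on the order of
+2^-64.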
+ }]; + + let arguments = (ins + Arg:$resource, + TF_Int64Tensor:$algorithm, + TF_I32OrI64Tensor:$shape, + TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$minval, + TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$maxval + ); + + let results = (outs + TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$output + ); + + TF_DerivedOperandTypeAttr shape_dtype = TF_DerivedOperandTypeAttr<2>; + TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<3>; +} + +def TF_CloseSummaryWriterOp : TF_Op<"CloseSummaryWriter", []> { + let summary = "Flushes and closes the summary writer."; + + let description = [{ +Also removes it from the resource manager. To reopen, use another +CreateSummaryFileWriter op. + +writer: A handle to the summary writer resource. + }]; + + let arguments = (ins + Arg:$writer + ); + + let results = (outs); +} + +// TODO(b/168035831): Model db_uri read/write. +def TF_CreateSummaryDbWriterOp : TF_Op<"CreateSummaryDbWriter", []> { + let summary = "Creates summary database writer accessible by given resource handle."; + + let description = [{ +This can be used to write tensors from the execution graph directly +to a database. Only SQLite is supported right now. This function +will create the schema if it doesn't exist. Entries in the Users, +Experiments, and Runs tables will be created automatically if they +don't already exist. + +writer: Handle to SummaryWriter resource to overwrite. +db_uri: For example "file:/tmp/foo.sqlite". +experiment_name: Can't contain ASCII control characters or <>. Case + sensitive. If empty, then the Run will not be associated with any + Experiment. +run_name: Can't contain ASCII control characters or <>. Case sensitive. + If empty, then each Tag will not be associated with any Run. +user_name: Must be valid as both a DNS label and Linux username. If + empty, then the Experiment will not be associated with any User. + }]; + + let arguments = (ins + Arg:$writer, + TF_StrTensor:$db_uri, + TF_StrTensor:$experiment_name, + TF_StrTensor:$run_name, + TF_StrTensor:$user_name + ); + + let results = (outs); +} + +// TODO(b/168035831): Model logdir read/write. +def TF_CreateSummaryFileWriterOp : TF_Op<"CreateSummaryFileWriter", []> { + let summary = "Creates a summary file writer accessible by the given resource handle."; + + let description = [{ +writer: A handle to the summary writer resource +logdir: Directory where the event file will be written. +max_queue: Size of the queue of pending events and summaries. +flush_millis: How often, in milliseconds, to flush the pending events and + summaries to disk. +filename_suffix: Every event file's name is suffixed with this suffix. + }]; + + let arguments = (ins + Arg:$writer, + TF_StrTensor:$logdir, + TF_Int32Tensor:$max_queue, + TF_Int32Tensor:$flush_millis, + TF_StrTensor:$filename_suffix + ); + + let results = (outs); +} + +def TF_FlushSummaryWriterOp : TF_Op<"FlushSummaryWriter", []> { + let summary = "Flushes the writer's unwritten events."; + + let description = [{ +writer: A handle to the summary writer resource. + }]; + + let arguments = (ins + Arg:$writer + ); + + let results = (outs); +} + +def TF_ImportEventOp : TF_Op<"ImportEvent", []> { + let summary = "Outputs a `tf.Event` protocol buffer."; + + let description = [{ +When CreateSummaryDbWriter is being used, this op can be useful for +importing data from event logs. + +writer: A handle to a summary writer. +event: A string containing a binary-encoded tf.Event proto. 
+ }]; + + let arguments = (ins + Arg:$writer, + TF_StrTensor:$event + ); + + let results = (outs); +} + +def TF_SummaryWriterOp : TF_Op<"SummaryWriter", [TF_ResourceHandleAllocatorInterface]> { + let summary = "Returns a handle to be used to access a summary writer."; + + let description = [{ +The summary writer is an in-graph resource which can be used by ops to write +summaries to event files. + +writer: the summary writer resource. Scalar handle. + }]; + + let arguments = (ins + StrAttr:$shared_name, + StrAttr:$container + ); + + let results = (outs + Res:$writer + ); + + let extraClassDeclaration = [{ + // TF_ResourceHandleAllocatorInterface: + ResourceHandleValueAndId GetResourceHandleValueAndId( + llvm::SmallDenseMap &resource_handle_id_map, + int64_t &next_id); + }]; +} + +def TF_WriteAudioSummaryOp : TF_Op<"WriteAudioSummary", []> { + let summary = "Writes a `Summary` protocol buffer with audio."; + + let description = [{ +The summary has up to `max_outputs` summary values containing audio. The +audio is built from `tensor` which must be 3-D with shape `[batch_size, +frames, channels]` or 2-D with shape `[batch_size, frames]`. The values are +assumed to be in the range of `[-1.0, 1.0]` with a sample rate of `sample_rate`. + +The `tag` argument is a scalar `Tensor` of type `string`. It is used to +build the `tag` of the summary values: + +* If `max_outputs` is 1, the summary value tag is '*tag*/audio'. +* If `max_outputs` is greater than 1, the summary value tags are + generated sequentially as '*tag*/audio/0', '*tag*/audio/1', etc. + +writer: A handle to a summary writer. +step: The step to write the summary for. +tag: Scalar. Used to build the `tag` attribute of the summary values. +tensor: 2-D of shape `[batch_size, frames]`. +sample_rate: The sample rate of the signal in hertz. +max_outputs: Max number of batch elements to generate audio for. + }]; + + let arguments = (ins + Arg:$writer, + TF_Int64Tensor:$step, + TF_StrTensor:$tag, + TF_Float32Tensor:$tensor, + TF_Float32Tensor:$sample_rate, + + Confined, [IntMinValue<1>]>:$max_outputs + ); + + let results = (outs); +} + +def TF_WriteGraphSummaryOp : TF_Op<"WriteGraphSummary", []> { + let summary = "Writes a `GraphDef` protocol buffer to a `SummaryWriter`."; + + let description = [{ +writer: Handle of `SummaryWriter`. +step: The step to write the summary for. +tensor: A scalar string of the serialized tf.GraphDef proto. + }]; + + let arguments = (ins + Arg:$writer, + TF_Int64Tensor:$step, + TF_StrTensor:$tensor + ); + + let results = (outs); +} + +def TF_WriteHistogramSummaryOp : TF_Op<"WriteHistogramSummary", []> { + let summary = "Writes a histogram summary."; + + let description = [{ +The generated +[`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) +has one summary value containing a histogram for `values`. + +This op reports an `InvalidArgument` error if any value is not finite. + +writer: A handle to a summary writer. +step: The step to write the summary for. +tag: Scalar. Tag to use for the `Summary.Value`. +values: Any shape. Values to use to build the histogram. 
+ }]; + + let arguments = (ins + Arg:$writer, + TF_Int64Tensor:$step, + TF_StrTensor:$tag, + TF_IntOrFpTensor:$values + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>; +} + +def TF_WriteImageSummaryOp : TF_Op<"WriteImageSummary", []> { + let summary = "Writes a `Summary` protocol buffer with images."; + + let description = [{ +The summary has up to `max_images` summary values containing images. The +images are built from `tensor` which must be 4-D with shape `[batch_size, +height, width, channels]` and where `channels` can be: + +* 1: `tensor` is interpreted as Grayscale. +* 3: `tensor` is interpreted as RGB. +* 4: `tensor` is interpreted as RGBA. + +The images have the same number of channels as the input tensor. For float +input, the values are normalized one image at a time to fit in the range +`[0, 255]`. `uint8` values are unchanged. The op uses two different +normalization algorithms: + +* If the input values are all positive, they are rescaled so the largest one + is 255. + +* If any input value is negative, the values are shifted so input value 0.0 + is at 127. They are then rescaled so that either the smallest value is 0, + or the largest one is 255. + +The `tag` argument is a scalar `Tensor` of type `string`. It is used to +build the `tag` of the summary values: + +* If `max_images` is 1, the summary value tag is '*tag*/image'. +* If `max_images` is greater than 1, the summary value tags are + generated sequentially as '*tag*/image/0', '*tag*/image/1', etc. + +The `bad_color` argument is the color to use in the generated images for +non-finite input values. It is a `unit8` 1-D tensor of length `channels`. +Each element must be in the range `[0, 255]` (It represents the value of a +pixel in the output image). Non-finite values in the input tensor are +replaced by this tensor in the output image. The default value is the color +red. + +writer: A handle to a summary writer. +step: The step to write the summary for. +tag: Scalar. Used to build the `tag` attribute of the summary values. +tensor: 4-D of shape `[batch_size, height, width, channels]` where + `channels` is 1, 3, or 4. +max_images: Max number of batch elements to generate images for. +bad_color: Color to use for pixels with non-finite values. + }]; + + let arguments = (ins + Arg:$writer, + TF_Int64Tensor:$step, + TF_StrTensor:$tag, + TensorOf<[TF_Float16, TF_Float32, TF_Uint8]>:$tensor, + TF_Uint8Tensor:$bad_color, + + Confined, [IntMinValue<1>]>:$max_images + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>; +} + +def TF_WriteRawProtoSummaryOp : TF_Op<"WriteRawProtoSummary", []> { + let summary = "Writes a `Summary` protocol buffer with serialized string `Summary` protocol buffers."; + + let description = [{ +writer: A handle to a summary writer. +step: The step to write the summary for. +tensor: A tensor holding one or more serialized `Summary` protobufs to write. + }]; + + let arguments = (ins + Arg:$writer, + TF_Int64Tensor:$step, + TF_StrTensor:$tensor + ); + + let results = (outs); +} + +def TF_WriteScalarSummaryOp : TF_Op<"WriteScalarSummary", []> { + let summary = "Writes a `Summary` protocol buffer with scalar values."; + + let description = [{ +The input `tag` and `value` must have the scalars. + +writer: A handle to a summary writer. +step: The step to write the summary for. +tag: Tag for the summary. +value: Value for the summary. 
+ }]; + + let arguments = (ins + Arg:$writer, + TF_Int64Tensor:$step, + TF_StrTensor:$tag, + TF_IntOrFpTensor:$value + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>; +} + +def TF_WriteSummaryOp : TF_Op<"WriteSummary", []> { + let summary = "Outputs a `Summary` protocol buffer with a tensor."; + + let description = [{ +writer: A handle to a summary writer. +step: The step to write the summary for. +tensor: A tensor to serialize. +tag: The summary's tag. +summary_metadata: Serialized SummaryMetadata protocol buffer containing + plugin-related metadata for this summary. + }]; + + let arguments = (ins + Arg:$writer, + TF_Int64Tensor:$step, + TF_Tensor:$tensor, + TF_StrTensor:$tag, + TF_StrTensor:$summary_metadata + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>; +} + +def TF_InitializeTableFromDatasetOp : TF_Op<"InitializeTableFromDataset", []> { + let summary = ""; + + let arguments = (ins + Arg:$table_handle, + TF_VariantTensor:$dataset + ); + + let results = (outs); +} + +// TODO(b/168035831): Model filename read. +def TF_InitializeTableFromTextFileV2Op : TF_Op<"InitializeTableFromTextFileV2", []> { + let summary = "Initializes a table from a text file."; + + let description = [{ +It inserts one key-value pair into the table for each line of the file. +The key and value is extracted from the whole line content, elements from the +split line based on `delimiter` or the line number (starting from zero). +Where to extract the key and value from a line is specified by `key_index` and +`value_index`. + +- A value of -1 means use the line number(starting from zero), expects `int64`. +- A value of -2 means use the whole line content, expects `string`. +- A value >= 0 means use the index (starting at zero) of the split line based + on `delimiter`. + }]; + + let arguments = (ins + Arg:$table_handle, + TF_StrTensor:$filename, + + Confined]>:$key_index, + Confined]>:$value_index, + Confined, [IntMinValue<-1>]>:$vocab_size, + DefaultValuedAttr:$delimiter + ); + + let results = (outs); +} + +// TODO(b/168035831): Model filename read. +def TF_CacheDatasetV2Op : TF_Op<"CacheDatasetV2", []> { + let summary = ""; + + let arguments = (ins + TF_VariantTensor:$input_dataset, + TF_StrTensor:$filename, + Arg:$cache, + + Confined]>:$output_types, + Confined]>:$output_shapes + ); + + let results = (outs + TF_VariantTensor:$handle + ); +} + #endif // TF_OPS diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc index b465c1da68c..8bbc6a843e5 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc @@ -176,6 +176,72 @@ static LogicalResult Verify(BatchMatMulV2Op op) { if (!HasRankAtLeast(op.y(), 2)) { return op.emitOpError("requires rhs operand to have rank at least two"); } + + RankedTensorType x_ty = GetRankedTensorTypeForOperand(op.x()); + RankedTensorType y_ty = GetRankedTensorTypeForOperand(op.y()); + + if (!x_ty || !y_ty) return success(); + + ArrayRef x_shape = x_ty.getShape(); + ArrayRef y_shape = y_ty.getShape(); + + // Check broadcast compatibility if both input shapes are known. + // + // The last two dimensions are non-batch dimensions that don't need to + // participate in batch dimension compatibility check. 
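+  // Illustrative example (added for exposition; the shapes are hypothetical):
+  // for x : tensor<2x1x5x3x7xf32> and y : tensor<4x5x7x9xf32>, the batch
+  // prefixes [2, 1, 5] and [4, 5] broadcast to [2, 4, 5], so with
+  // adj_x = adj_y = false the checks below expect an output of shape
+  // [2, 4, 5, 3, 9] (rows taken from x, columns taken from y).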
+ + llvm::SmallVector result_batch_shape; + if (!OpTrait::util::getBroadcastedShape( + x_shape.drop_back(2), y_shape.drop_back(2), result_batch_shape)) + return op.emitOpError() + << "found incompatible broadcast batch dimensions for lhs shape " + << x_ty << " and rhs shape " << y_ty; + + RankedTensorType output_ty = GetRankedTensorTypeForOperand(op.output()); + if (!output_ty) return success(); + + int64_t expected_output_rank = std::max(x_ty.getRank(), y_ty.getRank()); + if (output_ty.getRank() != expected_output_rank) + return op.emitOpError() + << "found invalid output rank, expected " << expected_output_rank + << " but got " << output_ty.getRank(); + + // Check output batch dim with potential broadcasting. + ArrayRef output_shape = output_ty.getShape(); + for (int i = 0; i < result_batch_shape.size(); ++i) { + if (output_shape[i] != ShapedType::kDynamicSize && + output_shape[i] != result_batch_shape[i]) + return op.emitOpError() + << "has mismatching input batch dimension " + << result_batch_shape[i] << " and output batch dimension " + << output_shape[i]; + } + + // Check output shape for non-batch dimension, following documentation below. + // https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-mat-mul + int64_t x_row_dim = x_shape[x_shape.size() - 2]; + int64_t x_col_dim = x_shape[x_shape.size() - 1]; + int64_t y_row_dim = y_shape[y_shape.size() - 2]; + int64_t y_col_dim = y_shape[y_shape.size() - 1]; + int64_t out_row_dim = output_shape[output_shape.size() - 2]; + int64_t out_col_dim = output_shape[output_shape.size() - 1]; + + int64_t expected_out_row_dim = op.adj_x() ? x_col_dim : x_row_dim; + int64_t expected_out_col_dim = op.adj_y() ? y_row_dim : y_col_dim; + + if (expected_out_row_dim != ShapedType::kDynamicSize && + out_row_dim != ShapedType::kDynamicSize && + out_row_dim != expected_out_row_dim) + return op.emitOpError() + << "found invalid output dimension on row, expected " + << expected_out_row_dim << " but got " << out_row_dim; + if (expected_out_col_dim != ShapedType::kDynamicSize && + out_col_dim != ShapedType::kDynamicSize && + out_col_dim != expected_out_col_dim) + return op.emitOpError() + << "found invalid output dimension on col, expected " + << expected_out_col_dim << " but got " << out_col_dim; + return success(); } @@ -190,7 +256,7 @@ void BatchMatMulV2Op::getCanonicalizationPatterns( static LogicalResult Verify(BatchToSpaceOp op) { // Op already has a constraint that block_size >= 2. - int64_t block_size = op.block_size().getSExtValue(); + int64_t block_size = op.block_size(); llvm::SmallVector input_shape(4, ShapedType::kDynamicSize); auto input_type = op.input().getType().cast(); @@ -381,6 +447,13 @@ static LogicalResult Verify(BiasAddOp op) { return success(); } +Optional BiasAddOp::GetContractionFusion() { + // Only NHWC in f32 is supported for fusion. 
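+  // Reading of the descriptor returned below (hedged; not spelled out in this
+  // patch): it names the op being fused ("BiasAdd") and lists the operand
+  // indices of that op which become additional arguments of the fused
+  // contraction; index 1 is the bias operand of tf.BiasAdd(value, bias).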
+ if (data_format() != "NHWC" || !T().isF32()) return None; + + return ContractionFusion("BiasAdd", /*additional_arguments=*/{1}); +} + //===----------------------------------------------------------------------===// // BiasAddGradOp //===----------------------------------------------------------------------===// @@ -473,8 +546,7 @@ LogicalResult FoldConstantCaseOp::matchAndRewrite( if (!matchPattern(op.branch_index(), m_Constant(&branch))) return failure(); int index = *branch.getValues().begin(); - if (index < 0 || index >= op.branches().size()) - index = op.branches().size() - 1; + if (index < 0 || index >= op.num_branches()) index = op.num_branches() - 1; auto func = op.branches()[index].cast(); auto empty = rewriter.getStringAttr(""); @@ -507,8 +579,9 @@ static LogicalResult VerifyCaseOrIfOpBranchFunctions( // Functions have one less operand compared to op as first operand is elided // (`cond` of `tf.If` and `branch_index` of `tf.Case`). - int expected_num_inputs = op->getNumOperands() - 1; - int expected_num_results = op->getNumResults(); + TypeRangeWithDesc input{op->getOperands().drop_front().getTypes(), "input"}; + TypeRangeWithDesc result{op->getResultTypes(), "result"}; + for (auto branch : llvm::enumerate(branches)) { auto branch_func = SymbolTable::lookupNearestSymbolFrom( op, branch.value().cast()); @@ -518,47 +591,22 @@ static LogicalResult VerifyCaseOrIfOpBranchFunctions( << branch.value() << ") to point to a defined function"; FunctionType branch_type = branch_func.getType(); - if (branch_type.getNumInputs() != expected_num_inputs) - return op->emitOpError() - << "expects all branches to have " << expected_num_inputs - << " input(s), but " << branch_name(branch.index()) << " has " - << branch_type.getNumInputs() << " input(s)"; + std::string desc = branch_name(branch.index()) + " input"; + TypeRangeWithDesc branch_input{branch_type.getInputs(), desc}; + if (failed(VerifyTypeRangesAreCompatible(op, branch_input, input))) + return failure(); - if (branch_type.getNumResults() != expected_num_results) - return op->emitOpError() - << "expects all branches to have " << expected_num_results - << " result(s), but " << branch_name(branch.index()) << " has " - << branch_type.getNumResults() << " result(s)"; - - // Non-conditional operands starting with the second operand are passed to - // branches and should be compatible across all branches' inputs. - for (auto operand_type : - llvm::enumerate(llvm::drop_begin(op->getOperandTypes(), 1))) { - Type branch_input_i_type = branch_type.getInput(operand_type.index()); - if (!AreCastCompatible({operand_type.value(), branch_input_i_type})) - return op->emitOpError() - << "expects operand type " << operand_type.value() - << " to be cast compatible with " << branch_name(branch.index()) - << " input type " << branch_input_i_type << " at index " - << operand_type.index(); - } - - // Branches' results should be pair-wise compatible with the op results. 
- for (auto result_type : llvm::enumerate(op->getResultTypes())) { - Type branch_result_i_type = branch_type.getResult(result_type.index()); - if (!AreCastCompatible({result_type.value(), branch_result_i_type})) - return op->emitOpError() - << "expects result type " << result_type.value() - << " to be cast compatible with " << branch_name(branch.index()) - << " result type " << branch_result_i_type << " at index " - << result_type.index(); - } + desc = branch_name(branch.index()) + " result"; + TypeRangeWithDesc branch_result{branch_type.getResults(), desc}; + if (failed(VerifyTypeRangesAreCompatible(op, branch_result, result))) + return failure(); branch_types.push_back(branch_type); } // If branches have incompatible input types that means that no tensor can // serve as input to all the functions. Hence, the op is invalid. + int expected_num_inputs = op->getNumOperands() - 1; for (int i = 0; i < expected_num_inputs; ++i) { SmallVector branch_input_i_types; branch_input_i_types.reserve(branches.size()); @@ -597,16 +645,89 @@ static LogicalResult Verify(CaseRegionOp op) { if (failed(VerifyCaseOpBase(op, op.branch_index()))) return failure(); + TypeRangeWithDesc results{op.getResultTypes(), "result"}; + for (auto region_and_idx : llvm::enumerate(op.branches())) { - std::string region_name = - llvm::formatv("region #{0}", region_and_idx.index()).str(); - if (failed(VerifyRegionResults(op, region_and_idx.value(), region_name))) + std::string description = + llvm::formatv("branch #{0} result", region_and_idx.index()).str(); + Operation *yield = region_and_idx.value().front().getTerminator(); + TypeRangeWithDesc branch_results{yield->getOperandTypes(), description}; + if (failed(VerifyTypeRangesAreCompatible(op, branch_results, results))) return failure(); } return success(); } +namespace { +// Eliminate values that pass through the CaseRegionOp or IfRegionOp branches. +template +class CaseOrIfRegionEliminatePassThrough + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(CaseOrIfRegionOp op, + PatternRewriter &rewriter) const override { + RegionRange branches = op.getRegions(); + SmallVector new_result_types; + // Maps pass through results to extern values. + llvm::SmallDenseMap result_to_extern_value; + + for (auto result : op.getResults()) { + unsigned index = result.getResultNumber(); + Region *first_branch = *branches.begin(); + Operation *first_terminator = first_branch->front().getTerminator(); + Value returned_val = first_terminator->getOperand(index); + + // Pass through values would be defined outside the branch region. Keep + // the type of non pass through results to create a new op later, if + // required. + if (returned_val.getParentBlock() == &first_branch->front()) { + new_result_types.push_back(result.getType()); + continue; + } + // Check if the same extern value is returned in each branch. + for (Region *region : branches.drop_front()) { + Operation *terminator = region->front().getTerminator(); + if (terminator->getOperand(index) != returned_val) return failure(); + } + result_to_extern_value[result] = returned_val; + } + + // If no pass through values are found, no change is required. + if (result_to_extern_value.empty()) return failure(); + + // Create new case/if region op. 
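+  // Illustrative effect of this pattern (example IR is hypothetical): if every
+  // branch of a tf.IfRegion/tf.CaseRegion terminates with
+  //   "tf.Yield"(%inner, %outer)
+  // where %outer is defined above the region op, uses of the second result are
+  // rewritten to %outer directly, and the op created below keeps only the
+  // first result type.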
+ auto new_op = rewriter.create( + op.getLoc(), new_result_types, op.getOperand(), op.getAttrs(), + op.getNumRegions()); + + int next_index = 0; + for (auto result : op.getResults()) { + if (!result_to_extern_value.count(result)) { + result.replaceAllUsesWith(new_op.getResult(next_index++)); + continue; + } + result.replaceAllUsesWith(result_to_extern_value[result]); + for (Region *branch : branches) + branch->front().getTerminator()->eraseOperand(next_index); + } + + // Move region bodies to the new op. + for (auto region_index : llvm::seq(0, branches.size())) + new_op.getRegion(region_index).takeBody(op.getRegion(region_index)); + + op.erase(); + return success(); + } +}; +} // namespace + +void CaseRegionOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert>(context); +} + //===----------------------------------------------------------------------===// // CastOp //===----------------------------------------------------------------------===// @@ -1639,7 +1760,7 @@ static LogicalResult Verify(FakeQuantWithMinMaxArgsOp op) { return op.emitOpError("range is invalid: [" + Twine(std::to_string(rmin)) + "," + Twine(std::to_string(rmax)) + "]"); } - int64_t num_bits = op.num_bits().getSExtValue(); + int64_t num_bits = op.num_bits(); if (num_bits < 2 || num_bits > 16) { return op.emitOpError( "requires num_bits to be between 2 and 16, inclusive"); @@ -1659,7 +1780,7 @@ static LogicalResult Verify(FakeQuantWithMinMaxVarsOp op) { if (max && !IsOfRankedFloatTensorType(max, 0)) return op.emitOpError("requires max to be a 0d float tensor"); - int64_t num_bits = op.num_bits().getSExtValue(); + int64_t num_bits = op.num_bits(); if (num_bits < 2 || num_bits > 16) { return op.emitOpError( "requires num_bits to be between 2 and 16, inclusive"); @@ -1683,7 +1804,7 @@ static LogicalResult Verify(FakeQuantWithMinMaxVarsPerChannelOp op) { if (!HasRankAtLeast(inputs, 1)) return op.emitError("requires inputs to be at least 1d float tensor"); - int64_t num_bits = op.num_bits().getSExtValue(); + int64_t num_bits = op.num_bits(); if (num_bits < 2 || num_bits > 16) { return op.emitOpError( "requires num_bits to be between 2 and 16, inclusive"); @@ -1886,7 +2007,7 @@ StringRef FusedBatchNormV3Op::GetOptimalLayout(const RuntimeDevices &devices) { //===----------------------------------------------------------------------===// static LogicalResult Verify(GatherV2Op op) { - int64_t batch_dims = op.batch_dims().getSExtValue(); + int64_t batch_dims = op.batch_dims(); if (auto ty = op.indices().getType().dyn_cast()) { int64_t rank = ty.getRank(); if (batch_dims > rank || batch_dims < -rank) @@ -1992,9 +2113,18 @@ void IfOp::getCanonicalizationPatterns(OwningRewritePatternList &results, //===----------------------------------------------------------------------===// static LogicalResult Verify(IfRegionOp op) { - if (failed(VerifyRegionResults(op, op.then_branch(), "then"))) + TypeRange then_types = + op.then_branch().front().getTerminator()->getOperandTypes(); + TypeRange else_types = + op.else_branch().front().getTerminator()->getOperandTypes(); + + TypeRangeWithDesc results{op.getResultTypes(), "result"}; + TypeRangeWithDesc then_results{then_types, "then result"}; + TypeRangeWithDesc else_results{else_types, "else result"}; + + if (failed(VerifyTypeRangesAreCompatible(op, then_results, results))) return failure(); - if (failed(VerifyRegionResults(op, op.else_branch(), "else"))) + if (failed(VerifyTypeRangesAreCompatible(op, else_results, results))) return 
failure(); return success(); } @@ -2051,7 +2181,8 @@ LogicalResult FoldConstantIfRegionOp::matchAndRewrite( void IfRegionOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); + results.insert>(context); } //===----------------------------------------------------------------------===// @@ -2102,6 +2233,15 @@ OpFoldResult LeakyReluOp::fold(ArrayRef operands) { return {}; } +Optional LeakyReluOp::GetContractionFusion() { + // Only f32 is supported for fusion. + if (!T().isF32()) return None; + + NamedAttribute alpha(Identifier::get("alpha", getContext()), alphaAttr()); + return ContractionFusion("LeakyRelu", /*additional_arguments=*/{}, + /*additional_attributes=*/{alpha}); +} + //===----------------------------------------------------------------------===// // LogOp //===----------------------------------------------------------------------===// @@ -2223,12 +2363,12 @@ OpFoldResult MulOp::fold(ArrayRef operands) { return IdentityArithmeticOpFolder(*this, operands); } +} // namespace TF +} // namespace mlir + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc.inc" - -} // namespace TF -} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h index 19a927a23d7..8d98632b198 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h @@ -43,6 +43,9 @@ namespace TF { class YieldOp; +} // namespace TF +} // namespace mlir + // TODO(b/131258166): TensorFlow's mutex.h defines a `mutex_lock` macro, whose // purpose is to catch bug on `tensorflow::mutex_lock`. We don't use // `tensorflow::mutex_lock` here but we have ops (`tf.MutexLock` and @@ -56,7 +59,4 @@ class YieldOp; #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h.inc" -} // namespace TF -} // namespace mlir - #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_A_M_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc index bb7d9a50521..72ca50b5c37 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc @@ -543,27 +543,27 @@ static LogicalResult VerifyReductionInputAndDims(Value input, Value dims, return success(); } -LogicalResult VerifyRegionResults(Operation *op, Region ®ion, - StringRef region_name) { - auto op_name = op->getName().getStringRef(); - // verify that op outputs match yield inputs - YieldOp yield = cast(region.front().getTerminator()); - unsigned expected_num_results = op->getNumResults(); - if (yield.getNumOperands() != expected_num_results) - return op->emitOpError() - << region_name + " should have same number (" << expected_num_results - << ") of results as " << op_name << " but has " - << yield.getNumOperands() << " results"; +// A type range with description (in singular form) attached to it. 
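+// For example (diagnostic wording approximated from the format string below):
+// comparing a "then result" range {tensor<2xf32>} against a "result" range
+// {tensor<3xf32>} would report
+//   then result type 'tensor<2xf32>' is incompatible with result type
+//   'tensor<3xf32>' at index 0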
+using TypeRangeWithDesc = std::pair; - for (int idx : llvm::seq(0, expected_num_results)) { - auto op_result_type = op->getResult(idx).getType().cast(); - auto region_result_type = - yield.getOperand(idx).getType().cast(); - if (!AreCastCompatible({region_result_type, op_result_type})) - return op->emitError(llvm::formatv( - "{0} result type {1} is incompatible with {2} " - "result type {3} at index {4}", - region_name, region_result_type, op_name, op_result_type, idx)); +LogicalResult VerifyTypeRangesAreCompatible(Operation *op, + TypeRangeWithDesc range0, + TypeRangeWithDesc range1) { + if (range0.first.size() != range1.first.size()) { + return op->emitOpError() + << range0.second << "s (size = " << range0.first.size() << ")" + << " should have the same number of values as " << range1.second + << "s (size = " << range1.first.size() << ")"; + } + + for (auto it : llvm::enumerate(llvm::zip(range0.first, range1.first))) { + int index = it.index(); + Type type0 = std::get<0>(it.value()); + Type type1 = std::get<1>(it.value()); + if (!AreCastCompatible({type0, type1})) + return op->emitOpError(llvm::formatv( + "{0} type {1} is incompatible with {2} type {3} at index {4}", + range0.second, type0, range1.second, type1, index)); } return success(); } @@ -587,3 +587,31 @@ struct DropAttributes : public OpRewritePattern { } }; +//===----------------------------------------------------------------------===// +// TF op helper functions for handling resource handles and ids. +//===----------------------------------------------------------------------===// + +// Returns device of op if present. If op has no device set, an empty string ref +// is returned instead. +llvm::StringRef GetDeviceOrEmpty(Operation *op) { + if (auto device_attr = op->getAttrOfType("device")) + return device_attr.getValue(); + return llvm::StringRef(); +} + +// Returns resource handle value and id for resource op based on attributes. If +// a resource handle is anonymous, a new id is always returned. +ResourceHandleValueAndId GetResourceHandleValueAndIdBase( + llvm::StringRef container, llvm::StringRef shared_name, + llvm::StringRef device, Value resource, + llvm::SmallDenseMap &resource_handle_id_map, + int64_t &next_id) { + // Always create a new ID for anonymous handle. + if (IsResourceHandleAnonymous(shared_name)) return {resource, next_id++}; + + ResourceHandle handle(container, shared_name, device, /*op=*/nullptr); + auto emplace_res = resource_handle_id_map.try_emplace(handle, next_id); + // New ID created, increment next_id. + if (emplace_res.second) ++next_id; + return {resource, emplace_res.first->second}; +} diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc index cbac03f80f8..b99c99029ed 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc @@ -27,6 +27,8 @@ limitations under the License. #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" @@ -34,6 +36,7 @@ limitations under the License. 
#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/Twine.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" @@ -109,7 +112,7 @@ void NotEqualOp::build(OpBuilder &builder, OperationState &result, Value x, //===----------------------------------------------------------------------===// static LogicalResult Verify(OneHotOp op) { - int64_t axis = op.axis().getSExtValue(); + int64_t axis = op.axis(); auto indices_ty = op.indices().getType().dyn_cast(); if (indices_ty && @@ -207,7 +210,7 @@ static LogicalResult Verify(PackOp op) { // the axis value range is [-(R+1), R+1). int64_t range_begin = -inputs_rank - 1; // Inclusive int64_t range_end = inputs_rank + 1; // Exclusive - int64_t axis = op.axis().getSExtValue(); + int64_t axis = op.axis(); if (axis < range_begin || axis >= range_end) { return op.emitError() << "attribute 'axis' should be within range [" << range_begin << ", " << range_end @@ -232,7 +235,7 @@ OpFoldResult PackOp::fold(ArrayRef operands) { if (values().size() < 2) return {}; // Dimensions packed along axis = 0 (pack scalars into vector). - if (axis().getSExtValue() != 0) return {}; + if (axis() != 0) return {}; // First packed value is defined by a strided slice operation. auto slice_op = dyn_cast_or_null(values()[0].getDefiningOp()); @@ -247,11 +250,9 @@ OpFoldResult PackOp::fold(ArrayRef operands) { // All masks are `0` except `shrink_axis_mask` which is equal to `1` (slicing // scalar value from input vector). - if (slice_op.begin_mask().getSExtValue() != 0 || - slice_op.ellipsis_mask().getSExtValue() != 0 || - slice_op.end_mask().getSExtValue() != 0 || - slice_op.new_axis_mask().getSExtValue() != 0 || - slice_op.shrink_axis_mask().getSExtValue() != 1) + if (slice_op.begin_mask() != 0 || slice_op.ellipsis_mask() != 0 || + slice_op.end_mask() != 0 || slice_op.new_axis_mask() != 0 || + slice_op.shrink_axis_mask() != 1) return {}; // Returns a value if the `value` is defined by a ConstOp with a single @@ -566,135 +567,158 @@ OpFoldResult RealDivOp::fold(ArrayRef operands) { return IdentityArithmeticOpFolder(*this, operands); } +//===----------------------------------------------------------------------===// +// ReluOp +//===----------------------------------------------------------------------===// + +Optional ReluOp::GetContractionFusion() { + // Only f32 is supported for fusion. + if (!T().isF32()) return None; + + return ContractionFusion("Relu", /*additional_arguments=*/{}); +} + //===----------------------------------------------------------------------===// // ReshapeOp //===----------------------------------------------------------------------===// -// TODO(b/128020684): Verify the output type. -static LogicalResult Verify(ReshapeOp op) { - auto shape_type = op.shape().getType().cast(); - if (!shape_type.hasRank()) return success(); - if (shape_type.getRank() != 1) - return op.emitOpError("shape must be 1D tensor"); - auto rank_by_shape = shape_type.getShape()[0]; - auto type_of_tensor = op.tensor().getType().cast(); - // No compile time verification for unknown sized shape. 
- if (rank_by_shape == -1 || !type_of_tensor.hasStaticShape()) return success(); - int64_t num_by_tensor = type_of_tensor.getNumElements(); +namespace { +using ReshapeErrorHandler = + llvm::function_ref; - auto out_ty = op.getType().dyn_cast(); - if (out_ty && out_ty.hasStaticShape()) { - int64_t num_output_elements = out_ty.getNumElements(); - if (num_by_tensor != num_output_elements) - return op.emitOpError() - << "number of output elements (" << num_output_elements - << ") does not match expected number of elements (" - << num_by_tensor << ")"; - } +LogicalResult GetReshapeOutputType(Value tensor, Value shape, + ReshapeErrorHandler error_handler, + TensorType &output_ty) { + auto tensor_ty = tensor.getType().cast(); + auto element_ty = tensor_ty.getElementType(); + output_ty = UnrankedTensorType::get(element_ty); - // Check values if constant shape. No compiling time verification for - // non-constant shape. - auto *shape_op = op.shape().getDefiningOp(); - if (!shape_op) return success(); - Attribute shape_cst; - if (!matchPattern(shape_op, m_Constant(&shape_cst))) return success(); - auto shape_cst_attr = shape_cst.dyn_cast(); - if (!shape_cst_attr) return op.emitOpError("shape must be a valid tensor"); + auto shape_ty = shape.getType().dyn_cast(); + if (!shape_ty) return success(); + if (shape_ty.getRank() != 1) + return error_handler(llvm::formatv( + "requires 'shape' to be rank 1, but got {0}", shape_ty.getRank())); - if (auto opaque_attr = shape_cst_attr.dyn_cast()) { - opaque_attr.decode(shape_cst_attr); - } - - // We know the shape is a 1-D Tensor, then let us get the number of - // elements it implies. - unsigned num_by_shape = 1; - unsigned unknown_dim_count = 0; - for (int i = 0, e = rank_by_shape; i != e; ++i) { - auto num = shape_cst_attr.getValue(i).getInt(); - // The dimension size value can be -1, and that the real size needs to - // be computed so that the total size remains constant. At most one - // component of shape can be -1. - if (num == -1) { - if (++unknown_dim_count > 1) { - return op.emitOpError("more than one component of shape are -1"); - } - } else { - num_by_shape *= num; + DenseIntElementsAttr shape_attr; + if (!matchPattern(shape, m_Constant(&shape_attr))) { + // If only shape of `shape` is known, return ranked but dynamic output + // shape. + if (shape_ty.hasStaticShape()) { + llvm::SmallVector dynamic_shape(shape_ty.getDimSize(0), + ShapedType::kDynamicSize); + output_ty = RankedTensorType::get(dynamic_shape, element_ty); } - } - // If there is one component of shape is -1, the dimension should be - // computed so that the total size remains constant. - if (unknown_dim_count == 1) { - if (num_by_tensor % num_by_shape != 0) - return op.emitOpError( - "one component of shape is -1 but couldn't infer the dimension"); return success(); } - // If the elements by the tensor and implies by the shape don't match, - // fail this static check. - if (num_by_tensor != num_by_shape) { - return op.emitOpError( - "mismatch in tensor elements and shape implied elements"); + + // Detect if reshape output shape is folded. + bool shape_ty_zero_dim = false; + int unknown_index = -1; + // The product of constant shape argument excluding unknown dimension. 
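+  // Worked example (illustrative shapes): reshaping a tensor<6x4xf32> with a
+  // constant shape of [3, -1] gives a known-entry product of 3; the tensor has
+  // 24 elements, so the -1 entry is inferred below as 24 / 3 = 8 and the
+  // resulting type is tensor<3x8xf32>.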
+ int64_t shape_ty_size = 1; + llvm::SmallVector output_ty_shape; + output_ty_shape.reserve(shape_attr.getNumElements()); + for (const auto &dim : llvm::enumerate(shape_attr.getIntValues())) { + const int64_t size = dim.value().getSExtValue(); + if (size == ShapedType::kDynamicSize) { + if (unknown_index != -1) + return error_handler(llvm::formatv( + "requires 'shape' to have at most one dynamic dimension, but got " + "multiple dynamic dimensions at indices {0} and {1}", + unknown_index, dim.index())); + + unknown_index = dim.index(); + } else if (size == 0) { + shape_ty_zero_dim = true; + } else if (size > 0) { + shape_ty_size *= size; + } else { + return error_handler( + llvm::formatv("requires 'shape' to have dimensions greater than -1, " + "but got {0} at index {1}", + size, dim.index())); + } + output_ty_shape.push_back(size); } + + if (!tensor_ty.hasStaticShape()) { + output_ty = RankedTensorType::get(output_ty_shape, element_ty); + return success(); + } + + // Compute the value of the unknown dimension. + if (unknown_index != -1) { + // Compute number of elements in tensor shape. + int64_t tensor_ty_size = 1; + bool tensor_ty_zero_dim = false; + for (const auto &dim : tensor_ty.getShape()) { + if (dim > 0 || !shape_ty_zero_dim) { + tensor_ty_size *= dim; + } else { + tensor_ty_zero_dim = true; + } + } + + const int64_t missing_dim = tensor_ty_size / shape_ty_size; + if (!tensor_ty_zero_dim && shape_ty_size * missing_dim != tensor_ty_size) + return error_handler( + llvm::formatv("requires 'tensor' number of elements be a multiple of " + "{0}, but got {1}", + shape_ty_size, tensor_ty_size)); + + // Set the unknown dimension such that total number of elements remain + // constant. + output_ty_shape[unknown_index] = missing_dim; + } + + output_ty = RankedTensorType::get(output_ty_shape, element_ty); + + return success(); +} +} // namespace + +static LogicalResult Verify(ReshapeOp op) { + auto error_handler = [&op](const llvm::Twine &message) -> LogicalResult { + return op.emitOpError() << message; + }; + TensorType expected_ty; + if (failed(GetReshapeOutputType(op.tensor(), op.shape(), error_handler, + expected_ty))) + return failure(); + + auto output_ty = op.getType().dyn_cast(); + if (!output_ty) return success(); + auto tensor_ty = op.tensor().getType().cast(); + if (output_ty.hasStaticShape() && tensor_ty.hasStaticShape()) { + const int64_t output_ty_size = output_ty.getNumElements(); + const int64_t tensor_ty_size = tensor_ty.getNumElements(); + if (tensor_ty_size != output_ty_size) + return op.emitOpError() << "requires 'output' number of elements to " + "match 'tensor' number of elements, but got " + << output_ty_size << " and " << tensor_ty_size; + } + + if (!AreCastCompatible({output_ty, expected_ty})) + return op.emitOpError() + << "requires 'output' type " << output_ty + << " to be cast compatible with expected type " << expected_ty; + return success(); } +// Currently there are use cases that rely on partial evaluation of the `shape` +// operand, so InferTypeOpInterface is not used (along with generated builder of +// the same signature). 
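+// Example of that partial evaluation (illustrative, assuming an f32 input):
+// when the shape operand is a non-constant tensor<2xi32>, GetReshapeOutputType
+// still produces a ranked result, tensor<?x?xf32>, because the rank of the new
+// shape is known even though its values are not.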
void ReshapeOp::build(OpBuilder &builder, OperationState &result, Value tensor, Value shape) { - auto ttype = tensor.getType().cast(); - auto etype = ttype.getElementType(); - - auto unranked = [&builder, etype, &result, shape, tensor]() { - return ReshapeOp::build(builder, result, UnrankedTensorType::get(etype), - tensor, shape); + auto error_handler = [&result](const llvm::Twine &message) { + return mlir::emitError(result.location) << message; }; + TensorType output_ty; + if (failed(GetReshapeOutputType(tensor, shape, error_handler, output_ty))) + return; - // If tensor is unranked then we have no info about output of shape. - if (!ttype.hasRank()) return unranked(); - - DenseIntElementsAttr attr_shape; - if (matchPattern(shape, m_Constant(&attr_shape))) { - llvm::SmallVector const_shape; - const_shape.reserve(attr_shape.getNumElements()); - - // Detect if reshape output shape is folded. - bool flatten = false; - int unknown_index = -1; - // The product of constant shape argument excluding unknown dimension. - int64_t product_cshape = 1; - for (auto e : llvm::enumerate(attr_shape)) { - int64_t val = e.value().getSExtValue(); - if (IsUnknownDimOrRank(val)) { - if (flatten) { - mlir::emitError(result.location) - << "only one unknown dimension allowed"; - return; - } - flatten = true; - unknown_index = e.index(); - } else { - product_cshape *= val; - } - const_shape.push_back(val); - } - - // Compute the value of the unknown dimension. - if (flatten) { - // Compute number of elements in tensor shape. - auto tshape = ttype.getShape(); - int64_t product_tshape = std::accumulate(tshape.begin(), tshape.end(), 1, - std::multiplies()); - // Set the unknown dimension such that total number of elements remain - // constant. - // Note: The case where the ratio is not integral, and so the total size - // of reshape not constant, is checked in verify function. - const_shape[unknown_index] = product_tshape / product_cshape; - } - return ReshapeOp::build(builder, result, - RankedTensorType::get(const_shape, etype), tensor, - shape); - } - return unranked(); + return ReshapeOp::build(builder, result, output_ty, tensor, shape); } void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, @@ -1023,6 +1047,7 @@ static LogicalResult Verify(SizeOp op) { OpFoldResult SizeOp::fold(ArrayRef operands) { ShapedType output_type = getType().cast(); + if (!output_type.hasRank()) return {}; ShapedType input_type = getOperand().getType().cast(); if (!input_type.hasStaticShape()) return {}; int size = input_type.getNumElements(); @@ -1042,8 +1067,11 @@ OpFoldResult SizeOp::fold(ArrayRef operands) { // of elements in operands begin and size. 
// - if begin are constants, that // 0 <= begin[i] <= begin[i] + size[i] <= input_ty.getShape()[i] +// and +// size[i] == output_ty.getShape()[i] // - if begins aren't constant but the input is a ranked tensor, that // size[i] <= input_ty.getShape()[i] +// - output rank is the same as input rank // static LogicalResult Verify(SliceOp op) { RankedTensorType begin_ty = GetRankedTensorTypeForOperand(op.begin()); @@ -1071,21 +1099,40 @@ static LogicalResult Verify(SliceOp op) { "are equal to input rank"; } + auto output_ty = op.output().getType().dyn_cast(); + if (output_ty && input_ty && output_ty.getRank() != input_ty.getRank()) { + return op.emitOpError() + << "requires output to have the same rank as input, but got input " + "rank " + << input_ty.getRank() << " and output rank " << output_ty.getRank(); + } + DenseIntElementsAttr begin_indices; if (matchPattern(op.begin(), m_Constant(&begin_indices))) { DenseIntElementsAttr slice_sizes; bool constant_slice_sizes = matchPattern(op.size(), m_Constant(&slice_sizes)); int dim = 0; + // TODO(jpienaar): Reformulate the shape verification below to not use magic + // constants. for (const APInt &raw_begin_index : begin_indices.getValues()) { int64_t begin_index = raw_begin_index.getSExtValue(); int64_t input_size = input_ty ? input_ty.getShape()[dim] : -1; int64_t slice_size = constant_slice_sizes ? slice_sizes.getValue(dim).getSExtValue() : 0; + int64_t output_size = output_ty ? output_ty.getShape()[dim] : -1; + if (slice_size == -1 && input_size != -1) { slice_size = input_size - begin_index; } + if (output_size != -1 && constant_slice_sizes && + output_size != slice_size) { + return op.emitOpError() + << "requires output size to have the same size of slice, got " + "slice size " + << slice_size << " and output size " << output_size; + } if (begin_index < 0 || (input_size != -1 && begin_index + slice_size > input_size)) { return op.emitOpError() @@ -1143,6 +1190,183 @@ static LogicalResult Verify(SoftmaxCrossEntropyWithLogitsOp op) { return success(); } +//===----------------------------------------------------------------------===// +// SpaceToBatchNDOp +//===----------------------------------------------------------------------===// + +int64_t SpaceToBatchNDBlockRank(const TensorType block_shape_type, + const TensorType paddings_type) { + if (block_shape_type.hasStaticShape()) { + return block_shape_type.getShape()[0]; + } else if (paddings_type.hasStaticShape()) { + return paddings_type.getShape()[0]; + } else { + return -1; + } +} + +static LogicalResult Verify(SpaceToBatchNDOp op) { + const auto input_type = op.input().getType().cast(); + const auto block_shape_type = op.block_shape().getType().cast(); + const auto paddings_type = op.paddings().getType().cast(); + + // Check that block_shape has rank 1. + if (!IsOfRankOrUnranked(op.block_shape(), 1)) { + return op.emitOpError() << "requires rank of block_shape = 1; got " + << block_shape_type.getRank(); + } + + // Check that paddings has rank 2. + if (!IsOfRankOrUnranked(op.paddings(), 2)) { + return op.emitOpError() + << "requires rank of paddings = 2; got " << paddings_type.getRank(); + } + + // Check that paddings.shape[1]=2. + if (paddings_type.hasStaticShape() && paddings_type.getShape()[1] != 2) { + return op.emitOpError() << "requires paddings.shape[1] to be 2; got " + << paddings_type.getShape()[1]; + } + + // Check that block_shape and paddings have consistent ranks. 
+ if (block_shape_type.hasStaticShape() && paddings_type.hasStaticShape() && + block_shape_type.getShape()[0] != paddings_type.getShape()[0]) { + return op.emitOpError() + << "requires block_shape.shape[0] must equal paddings.shape[0]"; + } + + const int64_t block_rank = + SpaceToBatchNDBlockRank(block_shape_type, paddings_type); + + // Further checks require block_rank to be known. + if (block_rank == -1) { + return success(); + } + + // check that rank of input_type >= block_rank + 1 + if (input_type.hasRank() && input_type.getRank() < 1 + block_rank) { + return op.emitOpError() << "requires rank of input >= 1 + rank of block"; + } + + ElementsAttr block_shape_attr = nullptr; + ElementsAttr paddings_attr = nullptr; + + // Check that block_shape[*] >= 1. + if (matchPattern(op.block_shape(), m_Constant(&block_shape_attr))) { + uint64_t i = 0; + for (auto block_len : block_shape_attr.getValues()) { + if (block_len.getSExtValue() < 1) { + return op.emitOpError() + << "requires all values of block_shape to be >= 1; " + "failed for dimension " + << i; + } + ++i; + } + } + + // Check that paddings[*] >= 0. + if (matchPattern(op.paddings(), m_Constant(&paddings_attr))) { + for (uint64_t i = 0; i < block_rank; ++i) { + const int64_t pad_start = + paddings_attr.getValue({i, 0}).cast().getInt(); + const int64_t pad_end = + paddings_attr.getValue({i, 1}).cast().getInt(); + if (pad_start < 0 || pad_end < 0) { + return op.emitOpError() + << "requires all values of paddings to be >= 0; " + "failed for dimension " + << i; + } + } + } + + // Check that block_shape divides the padded input. + if (input_type.hasStaticShape() && block_shape_attr && paddings_attr) { + for (uint64_t i = 0; i < block_rank; ++i) { + const int64_t input_len = input_type.getShape()[1 + i]; + const int64_t pad_start = + paddings_attr.getValue({i, 0}).cast().getInt(); + const int64_t pad_end = + paddings_attr.getValue({i, 1}).cast().getInt(); + const int64_t block_len = + block_shape_attr.getValue({i}).cast().getInt(); + if ((input_len + pad_start + pad_end) % block_len != 0) { + return op.emitOpError() + << "requires block_shape[i] divides " + "input_shape[i + 1] + paddings[i, 0] + paddings[i, 1]; " + "failed for i=" + << i; + } + } + } + + return success(); +} + +// Infers returned rank if possible. Further, infers returned dimension sizes +// when possible. For all dimensions sizes to be inferred, the arguments +// block_shape and paddings must be constant. +LogicalResult SpaceToBatchNDOp::inferReturnTypes( + MLIRContext *context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + const Value input = operands[0]; + const Value block_shape_val = operands[1]; + const Value paddings_val = operands[2]; + const auto input_type = input.getType().cast(); + const auto block_shape_type = block_shape_val.getType().cast(); + const auto paddings_type = paddings_val.getType().cast(); + + // The return is unranked when the input is unranked. + if (!input_type.hasRank()) { + inferredReturnTypes.assign( + {UnrankedTensorType::get(input_type.getElementType())}); + return success(); + } + + const int64_t input_rank = input_type.getRank(); + const ArrayRef input_shape = input_type.getShape(); + const int64_t block_rank = + SpaceToBatchNDBlockRank(block_shape_type, paddings_type); + SmallVector return_shape(input_rank, ShapedType::kDynamicSize); + + // The return has all dimension sizes unknown when block_rank is unknown. 
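+  // Worked example for the fully constant case computed further below
+  // (illustrative shapes): input tensor<1x4x4x3xf32>, block_shape = [2, 2],
+  // paddings = [[0, 0], [0, 0]] gives batch 1 * 2 * 2 = 4 and blocked spatial
+  // dims (0 + 4) / 2 = 2, with the trailing dimension kept, so the inferred
+  // type is tensor<4x2x2x3xf32>.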
+ if (block_rank == -1) { + inferredReturnTypes.assign( + {RankedTensorType::get(return_shape, input_type.getElementType())}); + return success(); + } + + // The return preserves the remaining dimensions after blocked dimensions. + for (uint64_t i = 1 + block_rank; i < input_rank; ++i) { + return_shape[i] = input_shape[i]; + } + + // The rest of the dimension sizes can be calculated when block_shape and + // paddings arguments are constant. + ElementsAttr block_shape_attr; + ElementsAttr paddings_attr; + if (matchPattern(block_shape_val, m_Constant(&block_shape_attr)) && + matchPattern(paddings_val, m_Constant(&paddings_attr))) { + int64_t return_batch = input_shape[0]; + for (uint64_t i = 0; i < block_rank; ++i) { + int64_t paddings_sum = + paddings_attr.getValue({i, 0}).cast().getInt() + + paddings_attr.getValue({i, 1}).cast().getInt(); + int64_t block_shape_i = + block_shape_attr.getValue({i}).cast().getInt(); + return_batch *= block_shape_i; + return_shape[1 + i] = (paddings_sum + input_shape[i + 1]) / block_shape_i; + } + return_shape[0] = return_batch; + } + + inferredReturnTypes.assign( + {RankedTensorType::get(return_shape, input_type.getElementType())}); + return success(); +} + //===----------------------------------------------------------------------===// // SparseSoftmaxCrossEntropyWithLogitsOp //===----------------------------------------------------------------------===// @@ -1237,7 +1461,8 @@ static LogicalResult Verify(SplitVOp op) { if (!split_sizes_type) return success(); if (split_sizes_type.getRank() != 1 || - split_sizes_type.getDimSize(0) != op.getNumResults()) + (split_sizes_type.getDimSize(0) != ShapedType::kDynamicSize && + split_sizes_type.getDimSize(0) != op.getNumResults())) return op.emitOpError("split sizes should be a 1D tensor of ") << op.getNumResults() << " elements"; @@ -1389,7 +1614,7 @@ static LogicalResult VerifyStridedSliceBase(OpTy op) { // Use bit compares to ensure ellipsis_mask is 0 or a power of 2, i.e. there // exists only no more than one ellipsis. 
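  // For instance (illustrative): ellipsis_mask = 0b00100 encodes a single
  // ellipsis and passes the power-of-two test below, whereas 0b00101 encodes
  // two ellipses and is rejected.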
- uint32_t ellipsis_mask = op.ellipsis_mask().getZExtValue(); + uint32_t ellipsis_mask = op.ellipsis_mask(); if (ellipsis_mask != 0 && !llvm::isPowerOf2_32(ellipsis_mask)) return op.emitOpError("cannot have multiple ellipses"); @@ -1645,10 +1870,9 @@ bool StridedSliceOp::GetSlicedBoundRanges( sparse_strides.push_back(stride.getSExtValue()); CalculateSlicedShapeFromSparseIndices( - input_shape, sparse_begin, sparse_end, sparse_strides, - begin_mask().getZExtValue(), end_mask().getZExtValue(), - ellipsis_mask().getZExtValue(), new_axis_mask().getZExtValue(), - shrink_axis_mask().getZExtValue(), slice_begin, slice_end, slice_stride); + input_shape, sparse_begin, sparse_end, sparse_strides, begin_mask(), + end_mask(), ellipsis_mask(), new_axis_mask(), shrink_axis_mask(), + slice_begin, slice_end, slice_stride); return true; } @@ -1699,13 +1923,25 @@ bool StridedSliceGradOp::GetSlicedShapeAndBoundRanges( sparse_strides.push_back(stride.getSExtValue()); CalculateSlicedShapeFromSparseIndices( - *input_shape, sparse_begin, sparse_end, sparse_strides, - begin_mask().getZExtValue(), end_mask().getZExtValue(), - ellipsis_mask().getZExtValue(), new_axis_mask().getZExtValue(), - shrink_axis_mask().getZExtValue(), slice_begin, slice_end, slice_stride); + *input_shape, sparse_begin, sparse_end, sparse_strides, begin_mask(), + end_mask(), ellipsis_mask(), new_axis_mask(), shrink_axis_mask(), + slice_begin, slice_end, slice_stride); return true; } +//===----------------------------------------------------------------------===// +// SummaryWriterOp +//===----------------------------------------------------------------------===// + +ResourceHandleValueAndId SummaryWriterOp::GetResourceHandleValueAndId( + llvm::SmallDenseMap &resource_handle_id_map, + int64_t &next_id) { + llvm::StringRef device = GetDeviceOrEmpty(getOperation()); + return GetResourceHandleValueAndIdBase(container(), shared_name(), device, + writer(), resource_handle_id_map, + next_id); +} + //===----------------------------------------------------------------------===// // TensorListReserveOp //===----------------------------------------------------------------------===// @@ -1776,6 +2012,87 @@ static LogicalResult Verify(TensorScatterUpdateOp op) { return success(); } +//===----------------------------------------------------------------------===// +// TileOp +//===----------------------------------------------------------------------===// + +// Verifies that, +// +// - input has at least rank 1 +// - multiples is rank 1 +// - multiples.size() == input.rank() +// - input.rank() == output.rank() +// - Elements in multiples are non-negative +// - input.shape[i] * multiples[i] == output.shape[i] +// for i in [0, input.rank() - 1] + +static LogicalResult Verify(TileOp op) { + auto input_type = op.input().getType().dyn_cast(); + auto multiples_type = op.multiples().getType().dyn_cast(); + auto output_type = op.output().getType().dyn_cast(); + + if (multiples_type && multiples_type.getRank() != 1) { + return op.emitOpError() << "expected multiples to be rank 1, got rank = " + << multiples_type.getRank(); + } + + if (input_type && multiples_type && multiples_type.hasStaticShape() && + (input_type.getRank() != multiples_type.getNumElements() || + (input_type.getRank() == 0 && multiples_type.getNumElements() == 1))) { + return op.emitOpError() + << "expected size of multiples equal to rank of input" + << ", got multiples of size " << multiples_type.getNumElements() + << ", and input of rank " << input_type.getRank(); + } + + if (input_type && 
output_type) { + if (input_type.getRank() != output_type.getRank()) { + return op.emitOpError() + << "expected rank of input to equal to rank of output" + << ", got input of rank " << input_type.getRank() + << ", and output of rank " << output_type.getRank(); + } + + DenseIntElementsAttr multiples_attr; + if (matchPattern(op.multiples(), m_Constant(&multiples_attr))) { + for (int32_t i = 0, e = input_type.getRank(); i < e; ++i) { + const int64_t input_dim = input_type.getDimSize(i); + const int64_t output_dim = output_type.getDimSize(i); + const int64_t m = multiples_attr.getValue(i).getSExtValue(); + + if (m < 0) { + return op.emitOpError() + << "expected multiples to be non-negative, got " + << "multiples[" << i << "] = " << m; + } + + if (!ShapedType::isDynamic(input_dim) && + !ShapedType::isDynamic(output_dim) && output_dim != input_dim * m) { + return op.emitOpError() + << "requires input.shape[" << i << "] (" << input_dim << ")" + << " * " << m << " to be equal to " + << "output.shape[" << i << "] (" << output_dim << ")"; + } + } + } + } + + return success(); +} + +OpFoldResult TileOp::fold(ArrayRef operands) { + DenseIntElementsAttr multiples_attr; + if (matchPattern(multiples(), m_Constant(&multiples_attr))) { + // Return input directly when multiples are all ones, + // regardless what input is. + if (multiples_attr.isSplat() && + multiples_attr.getSplatValue().getSExtValue() == 1) { + return input(); + } + } + return {}; +} + //===----------------------------------------------------------------------===// // TopKV2Op //===----------------------------------------------------------------------===// @@ -1993,6 +2310,80 @@ void TruncateDivOp::getCanonicalizationPatterns( results.insert(context); } +//===----------------------------------------------------------------------===// +// NonMaxSuppressionV3Op +//===----------------------------------------------------------------------===// + +namespace { + +// Canonicalize NonMaxSuppressionV3Op to NonMaxSuppressionV4Op. +class NMSV3ToNMSV4Op : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(NonMaxSuppressionV3Op nms_op, + PatternRewriter &rewriter) const override { + if (nms_op.getNumOperands() != 5) { + return failure(); + } + SmallVector new_result_types; + new_result_types.push_back(nms_op.getType()); + auto input_ty = nms_op.getType().template cast(); + // corresponds to the second result type of nmsv4 + RankedTensorType valid_output_type = + RankedTensorType::get({}, input_ty.getElementType()); + new_result_types.push_back(valid_output_type); + + auto nmsv4 = rewriter.create( + nms_op.getLoc(), new_result_types, nms_op.boxes(), nms_op.scores(), + nms_op.max_output_size(), nms_op.iou_threshold(), + nms_op.score_threshold()); + // Cannot replace the NMSv3 Op with NMSv4 since the outputs between the + // two are different (v4 expects two output values vs v3 requires only one. + nms_op.replaceAllUsesWith(nmsv4.getResult(0)); + return success(); + } +}; +} // namespace. 
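+// Net effect of the pattern above (illustrative; the value names are
+// hypothetical): a tf.NonMaxSuppressionV3 op such as
+//   %sel = "tf.NonMaxSuppressionV3"(%boxes, %scores, %max, %iou, %score_thr)
+// becomes a tf.NonMaxSuppressionV4 with the same five operands; only its first
+// result (the selected indices) replaces the original uses, and the extra
+// valid_outputs result is left unused.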
+ +void NonMaxSuppressionV3Op::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// FusedBatchNormOp +//===----------------------------------------------------------------------===// + +namespace { + +class ConvertFusedBatchNorm : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TF::FusedBatchNormOp tf_fused_batch_norm_op, + PatternRewriter &rewriter) const override { + auto new_result_types = + llvm::to_vector<6>(tf_fused_batch_norm_op.getResultTypes()); + // reserve_space_3 + new_result_types.push_back( + UnrankedTensorType::get(FloatType::getF32(rewriter.getContext()))); + + OperationState new_state(tf_fused_batch_norm_op.getLoc(), + TF::FusedBatchNormV3Op::getOperationName(), + tf_fused_batch_norm_op.getOperands(), + new_result_types, + tf_fused_batch_norm_op.getAttrs()); + Operation *tf_fused_batch_norm_op_v3 = rewriter.createOperation(new_state); + + rewriter.replaceOp(tf_fused_batch_norm_op, + tf_fused_batch_norm_op_v3->getResults().drop_back()); + return success(); + } +}; +} // namespace. + +void FusedBatchNormOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // UnpackOp //===----------------------------------------------------------------------===// @@ -2002,7 +2393,7 @@ static LogicalResult Verify(UnpackOp op) { if (!value_type) return success(); int64_t value_rank = value_type.getRank(); - int64_t axis = op.axis().getSExtValue(); + int64_t axis = op.axis(); if (axis < -value_rank || axis >= value_rank) return op.emitOpError("axis attribute must be in the range of [-") << value_rank << ", " << value_rank << ')'; @@ -2060,6 +2451,19 @@ static LogicalResult VerifyUnsortedSegmentReduction(Op op) { return success(); } +//===----------------------------------------------------------------------===// +// VarHandleOp +//===----------------------------------------------------------------------===// + +ResourceHandleValueAndId VarHandleOp::GetResourceHandleValueAndId( + llvm::SmallDenseMap &resource_handle_id_map, + int64_t &next_id) { + llvm::StringRef device = GetDeviceOrEmpty(getOperation()); + return GetResourceHandleValueAndIdBase(container(), shared_name(), device, + resource(), resource_handle_id_map, + next_id); +} + //===----------------------------------------------------------------------===// // VarIsInitializedOp //===----------------------------------------------------------------------===// @@ -2122,38 +2526,19 @@ OpFoldResult VariableShapeOp::fold(ArrayRef operands) { // WhileOp //===----------------------------------------------------------------------===// -static LogicalResult Verify(WhileOp op) { - auto cond_fn = op.cond_func(); - auto body_fn = op.body_func(); - if (!cond_fn) { - return op.emitOpError("cond refers to an undefined function : ") - << op.cond(); - } - if (!body_fn) { - return op.emitOpError("body refers to an undefined function : ") - << op.body(); - } - - auto cond_fn_type = cond_fn.getType(); - auto body_fn_type = body_fn.getType(); - - // Verify that the cond function has exactly one result. 
- if (cond_fn_type.getNumResults() != 1) - return op.emitOpError("requires cond function to have exactly one result"); - - SmallVector operands(op.getOperandTypes()); - +static LogicalResult VerifyWhileTypes(Operation *op, TypeRange cond_input, + TypeRange body_input, + TypeRange body_result) { // Collect all the type lists for the op so that different pairs of type lists // can be compared for the compatibility. constexpr int kNumTypeLists = 5; - const std::array>, kNumTypeLists> - type_lists = {{ - {"operand", operands}, - {"body function result", body_fn_type.getResults()}, - {"result", op.getResultTypes()}, - {"cond function input", cond_fn_type.getInputs()}, - {"body function input", body_fn_type.getInputs()}, - }}; + const std::array type_lists = {{ + {op->getOperandTypes(), "input"}, + {body_result, "body result"}, + {op->getResultTypes(), "result"}, + {cond_input, "condition input"}, + {body_input, "body input"}, + }}; // A pair of type lists should be cast compatible with each other if one is // converted to the another for a function call or assignment or there is a @@ -2183,28 +2568,38 @@ static LogicalResult Verify(WhileOp op) { for (int j = std::max(2, i + 1); j < kNumTypeLists; ++j) { auto &a = type_lists[i]; auto &b = type_lists[j]; - - int a_size = a.second.size(); - if (a_size != b.second.size()) - return op.emitOpError( - llvm::formatv("requires the number of {0}s to be equal to the " - "number of {1}s. Found {2} and {3}, respectively", - a.first, b.first, a_size, b.second.size())); - - for (int idx = 0; idx < a_size; ++idx) { - auto a_type = a.second[idx]; - auto b_type = b.second[idx]; - - if (!AreCastCompatible({a_type, b_type})) - return op.emitError(llvm::formatv( - "{0} type {1} is incompatible with {2} type {3} at index {4}", - a.first, a_type, b.first, b_type, idx)); - } + if (failed(VerifyTypeRangesAreCompatible(op, a, b))) return failure(); } } return success(); } +static LogicalResult Verify(WhileOp op) { + auto cond_fn = op.cond_function(); + auto body_fn = op.body_function(); + if (!cond_fn) { + return op.emitOpError("cond refers to an undefined function : ") + << op.cond(); + } + if (!body_fn) { + return op.emitOpError("body refers to an undefined function : ") + << op.body(); + } + + auto cond_fn_type = cond_fn.getType(); + auto body_fn_type = body_fn.getType(); + + // Verify that the cond function has exactly one result. + if (cond_fn_type.getNumResults() != 1) + return op.emitOpError("requires cond function to have exactly one result"); + + if (failed(VerifyWhileTypes(op, /*cond_input=*/cond_fn_type.getInputs(), + /*body_input=*/body_fn_type.getInputs(), + /*body_result=*/body_fn_type.getResults()))) + return failure(); + return success(); +} + //===----------------------------------------------------------------------===// // WhileOp canonicalization. //===----------------------------------------------------------------------===// @@ -2218,50 +2613,23 @@ void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results, //===----------------------------------------------------------------------===// static LogicalResult Verify(WhileRegionOp op) { // Verify that the condition generates a single tensor result. 
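  // (That is, the condition region must yield exactly one 0-d i1 value, i.e. a
  // tensor<i1>; the checks below enforce both the operand count and the type.)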
- YieldOp yield = cast(op.cond().front().getTerminator()); - if (yield.getNumOperands() != 1) + Operation *cond_yield = op.cond().front().getTerminator(); + if (cond_yield->getNumOperands() != 1) return op.emitOpError() << "condition should have a single tensor result"; - auto cond_type = yield.getOperand(0).getType().dyn_cast(); + auto cond_type = + cond_yield->getOperand(0).getType().dyn_cast(); if (!cond_type || !cond_type.getShape().equals({}) || !cond_type.getElementType().isInteger(/*width=*/1)) return op.emitOpError() << "condition should have a single tensor result"; - // The body result types should match while op result types. - if (failed(VerifyRegionResults(op, op.body(), "body"))) return failure(); - - // Both condition and body should have same number and type of operands as - // the WhileRegion inputs. - const int num_inputs = op.getNumOperands(); - auto block_inputs_match_op_inputs = [&](Region ®ion, - StringRef name) -> LogicalResult { - Block &block = region.front(); - if (block.getNumArguments() != num_inputs) - return op.emitOpError() - << name << " should have same number of inputs (" << num_inputs - << ") as " << WhileRegionOp::getOperationName() << " but has " - << block.getNumArguments() << " inputs"; - - for (auto types_idx : llvm::enumerate( - llvm::zip(op.getOperandTypes(), block.getArgumentTypes()))) { - auto op_input_type = std::get<0>(types_idx.value()); - auto block_input_type = std::get<1>(types_idx.value()); - if (!AreCastCompatible({block_input_type, op_input_type})) - return op.emitOpError(llvm::formatv( - "{0} input type {1} is incompatible with {2} " - "input type {3} at index {4}", - name, block_input_type, WhileRegionOp::getOperationName(), - op_input_type, types_idx.index())); - } - return success(); - }; - - if (failed(block_inputs_match_op_inputs(op.cond(), "condition")) || - failed(block_inputs_match_op_inputs(op.body(), "body"))) + Operation *body_yield = op.body().front().getTerminator(); + if (failed(VerifyWhileTypes(op, /*cond_input=*/op.cond().getArgumentTypes(), + /*body_input=*/op.body().getArgumentTypes(), + /*body_result=*/body_yield->getOperandTypes()))) return failure(); - return success(); } @@ -2373,7 +2741,8 @@ struct WhileRegionEliminatePassThrough auto &new_body_block = new_while_op.body().front(); auto &new_yield = *new_body_block.getTerminator(); - // Build a vector of new results. Also patch up the region bodies and yield. + // Build a vector of new results. Also patch up the region bodies and + // yield. SmallVector new_results; next_idx = 0; for (int op_idx : llvm::seq(0, old_num_operands)) { @@ -2408,12 +2777,12 @@ void XdivyOp::getCanonicalizationPatterns(OwningRewritePatternList &results, results.insert(context); } +} // namespace TF +} // namespace mlir + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc.inc" - -} // namespace TF -} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h index 761c06a475c..9b06d855b01 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h @@ -38,15 +38,9 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h" -namespace mlir { -namespace TF { - #define GET_OP_FWD_DEFINES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_all_ops.h.inc" #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h.inc" -} // namespace TF -} // namespace mlir - #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_N_Z_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc index e87cc494a4a..38f9175a500 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc @@ -70,11 +70,12 @@ limitations under the License. namespace mlir { namespace TF { - namespace { #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc" #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc" } // namespace +} // namespace TF +} // namespace mlir //===----------------------------------------------------------------------===// // TableGen'd op method definitions @@ -82,6 +83,3 @@ namespace { #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc.inc" - -} // namespace TF -} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h index 8586515edee..589e0e91615 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h @@ -36,15 +36,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h" -namespace mlir { -namespace TF { - #define GET_OP_FWD_DEFINES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_all_ops.h.inc" #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h.inc" -} // namespace TF -} // namespace mlir - #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_REMAINING_OPS_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc index 6883d0358ec..1eaf997ab69 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc @@ -105,9 +105,15 @@ static LogicalResult Verify(SessionInitializerOp session_initializer) { return success(); } +} // namespace tf_saved_model +} // namespace mlir + #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc" +namespace mlir { +namespace tf_saved_model { + //===----------------------------------------------------------------------===// // TensorFlowSavedModelDialect Dialect //===----------------------------------------------------------------------===// @@ -115,6 +121,11 @@ static LogicalResult Verify(SessionInitializerOp session_initializer) { TensorFlowSavedModelDialect::TensorFlowSavedModelDialect(MLIRContext *context) : Dialect(/*name=*/"tf_saved_model", context, TypeID::get()) { + // The TensorFlow Dialect is needed in the verifier and other routines + // associated to this dialect. It makes little sense anyway to use the + // SavedModel dialect without the TensorFlow Dialect. 
+ context->loadDialect(); + addOperations< #define GET_OP_LIST #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc" diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h index 02b7f0b75f4..c8518a9ca02 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h @@ -40,10 +40,16 @@ class TensorFlowSavedModelDialect : public Dialect { static StringRef getDialectNamespace() { return "tf_saved_model"; } }; +} // namespace tf_saved_model +} // namespace mlir + // Declares the operations for this dialect using the generated header. #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h.inc" +namespace mlir { +namespace tf_saved_model { + // Returns the list of exported names for `op`. // An empty list means `op` is not exported. SmallVector GetExportedNames(Operation *op); diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td index a22a684953b..753e2368d6e 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td @@ -82,7 +82,7 @@ def TfSavedModel_Dialect : Dialect { with "get_global @some_global_tensor" in the function body. }]; - let cppNamespace = "tf_saved_model"; + let cppNamespace = "::mlir::tf_saved_model"; } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h index 9be61b1db39..3c8ec1d38af 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h @@ -35,6 +35,28 @@ struct TensorArray : ::mlir::SideEffects::Resource::Base { StringRef getName() final { return "TensorArray"; } }; +struct Summary : ::mlir::SideEffects::Resource::Base { + StringRef getName() final { return "Summary"; } +}; + +struct LookupTable : ::mlir::SideEffects::Resource::Base { + StringRef getName() final { return "LookupTable"; } +}; + +struct DatasetSeedGenerator + : ::mlir::SideEffects::Resource::Base { + StringRef getName() final { return "DatasetSeedGenerator"; } +}; + +struct DatasetMemoryCache + : ::mlir::SideEffects::Resource::Base { + StringRef getName() final { return "DatasetMemoryCache"; } +}; + +struct DatasetIterator : ::mlir::SideEffects::Resource::Base { + StringRef getName() final { return "DatasetIterator"; } +}; + } // namespace ResourceEffects } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.cc index 6c5485c16dd..9d8f25c6633 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.cc @@ -15,11 +15,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" -namespace mlir { - -// NOLINTNEXTLINE #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.cc.inc" +namespace mlir { namespace TF { void RuntimeDevices::AddDevice(const ParsedName& device) { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h index b1f39ad1d28..b90bf2d47a8 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h @@ -26,10 +26,9 @@ limitations under the License. 
#include "mlir/IR/Types.h" // from @llvm-project #include "tensorflow/core/util/device_name_utils.h" -namespace mlir { - #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h.inc" +namespace mlir { namespace TF { // Tensorflow devices available at runtime with corresponding metadata if it is diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h index 412bf113a0f..aef3c538bc8 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h @@ -124,6 +124,10 @@ class CannotDuplicate : public TraitBase { } }; +// Trait to indicate an operation cannot be constant folded. +template +class NoConstantFold : public TraitBase {}; + // Coefficient-wise binary operation with implicit broadcasting support, for // example tf.Sub operation. template diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc index 2ec73824f6c..86369b993be 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc @@ -62,101 +62,6 @@ bool GetCastCompatibleShape(llvm::ArrayRef a_shape, return true; } -// Given two types `a` and `b`, returns a refined type which is cast compatible -// with both `a` and `b` and is equal to or more precise than both of them. It -// returns empty Type if the input types are not cast compatible. -// -// The two types are considered cast compatible if they have dynamically equal -// shapes and element type. For element types that do not have subtypes, they -// must be equal. However for TensorFlow types such as Resource and Variant, -// that also have subtypes, we recursively check for subtype compatibilty for -// Resource types and assume all variant types are cast compatible. If either -// one of `a` or `b` have empty subtypes, they are considered cast compatible. -// -// The returned type is same or more precise than the input types. For example, -// if `a` and `b` are cast compatible types tensor<2x?x?xf32> and -// tensor respectively, the returned type is tensor<2x4x?xf32>. -// -// Provides option to ignore ref types on 'a'. This is useful for TF ops that -// might allow operands to either be same as result type or be a ref type -// corresponding to it. -mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b, - bool may_ignore_ref_type_a) { - // Fast path if everything is equal. - if (a == b) return b; - - auto a_tt = a.dyn_cast(); - auto b_tt = b.dyn_cast(); - - // If only one of a or b is a tensor type, they are incompatible. - if (static_cast(a_tt) ^ static_cast(b_tt)) return nullptr; - - // For non-tensor types, we do not need to worry about shape and can return - // early. - if (!a_tt && !b_tt) { - // Remove ref types. - if (may_ignore_ref_type_a) { - if (auto ref_type = a.dyn_cast()) { - a = ref_type.RemoveRef(); - if (a == b) return a; - } - } - if (a.getTypeID() != b.getTypeID()) return nullptr; - - // If either is not a type that contain subtypes then the types are not cast - // compatible. - auto a_wst = a.dyn_cast(); - auto b_wst = b.dyn_cast(); - if (!a_wst || !b_wst) return nullptr; - - // For Variant types we are more permissive right now and accept all pairs - // of Variant types. If we are more constrainted and check compatibility of - // subtypes, we might reject valid graphs. - // TODO(prakalps): Variant doesn't have a subtype, we assign it - // one, so we should only assign it one when we know the subtype. 
Then we - // can be more constrained and check subtypes for cast compatibility as - // well. - if (a.isa()) return a; - - // For Resource types, we recursively check the subtypes for cast - // compatibility, if possible. Otherwise treat them as compatible. - auto a_wst_st = a_wst.GetSubtypes(); - auto b_wst_st = b_wst.GetSubtypes(); - if (a_wst_st.empty() || b_wst_st.empty()) return a; - if (a_wst_st.size() != b_wst_st.size()) return nullptr; - llvm::SmallVector refined_subtypes; - for (auto subtypes : llvm::zip(a_wst_st, b_wst_st)) { - mlir::Type refined_st = - GetCastCompatibleType(std::get<0>(subtypes), std::get<1>(subtypes), - /*may_ignore_ref_type_a=*/false); - if (!refined_st) return nullptr; - refined_subtypes.push_back(refined_st.cast()); - } - - return mlir::TF::ResourceType::get(refined_subtypes, a.getContext()); - } - - // For tensor types, check compatibility of both element type and shape. - mlir::Type refined_element_ty = GetCastCompatibleType( - a_tt.getElementType(), b_tt.getElementType(), may_ignore_ref_type_a); - if (!refined_element_ty) return nullptr; - - if (!a_tt.hasRank() && !b_tt.hasRank()) { - return mlir::UnrankedTensorType::get(refined_element_ty); - } - if (!a_tt.hasRank()) { - return mlir::RankedTensorType::get(b_tt.getShape(), refined_element_ty); - } - if (!b_tt.hasRank()) { - return mlir::RankedTensorType::get(a_tt.getShape(), refined_element_ty); - } - - llvm::SmallVector refined_shape; - if (!GetCastCompatibleShape(a_tt.getShape(), b_tt.getShape(), &refined_shape)) - return nullptr; - - return mlir::RankedTensorType::get(refined_shape, refined_element_ty); -} } // namespace namespace mlir { @@ -343,6 +248,102 @@ bool BroadcastCompatible(ArrayRef lhs, ArrayRef rhs) { return true; } +// Given two types `a` and `b`, returns a refined type which is cast compatible +// with both `a` and `b` and is equal to or more precise than both of them. It +// returns empty Type if the input types are not cast compatible. +// +// The two types are considered cast compatible if they have dynamically equal +// shapes and element type. For element types that do not have subtypes, they +// must be equal. However for TensorFlow types such as Resource and Variant, +// that also have subtypes, we recursively check for subtype compatibilty for +// Resource types and assume all variant types are cast compatible. If either +// one of `a` or `b` have empty subtypes, they are considered cast compatible. +// +// The returned type is same or more precise than the input types. For example, +// if `a` and `b` are cast compatible types tensor<2x?x?xf32> and +// tensor respectively, the returned type is tensor<2x4x?xf32>. +// +// Provides option to ignore ref types on 'a'. This is useful for TF ops that +// might allow operands to either be same as result type or be a ref type +// corresponding to it. +mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b, + bool may_ignore_ref_type_a) { + // Fast path if everything is equal. + if (a == b) return b; + + auto a_tt = a.dyn_cast(); + auto b_tt = b.dyn_cast(); + + // If only one of a or b is a tensor type, they are incompatible. + if (static_cast(a_tt) ^ static_cast(b_tt)) return nullptr; + + // For non-tensor types, we do not need to worry about shape and can return + // early. + if (!a_tt && !b_tt) { + // Remove ref types. 
+ if (may_ignore_ref_type_a) { + if (auto ref_type = a.dyn_cast()) { + a = ref_type.RemoveRef(); + if (a == b) return a; + } + } + if (a.getTypeID() != b.getTypeID()) return nullptr; + + // If either is not a type that contain subtypes then the types are not cast + // compatible. + auto a_wst = a.dyn_cast(); + auto b_wst = b.dyn_cast(); + if (!a_wst || !b_wst) return nullptr; + + // For Variant types we are more permissive right now and accept all pairs + // of Variant types. If we are more constrainted and check compatibility of + // subtypes, we might reject valid graphs. + // TODO(prakalps): Variant doesn't have a subtype, we assign it + // one, so we should only assign it one when we know the subtype. Then we + // can be more constrained and check subtypes for cast compatibility as + // well. + if (a.isa()) return a; + + // For Resource types, we recursively check the subtypes for cast + // compatibility, if possible. Otherwise treat them as compatible. + auto a_wst_st = a_wst.GetSubtypes(); + auto b_wst_st = b_wst.GetSubtypes(); + if (a_wst_st.empty() || b_wst_st.empty()) return a; + if (a_wst_st.size() != b_wst_st.size()) return nullptr; + llvm::SmallVector refined_subtypes; + for (auto subtypes : llvm::zip(a_wst_st, b_wst_st)) { + mlir::Type refined_st = + GetCastCompatibleType(std::get<0>(subtypes), std::get<1>(subtypes), + /*may_ignore_ref_type_a=*/false); + if (!refined_st) return nullptr; + refined_subtypes.push_back(refined_st.cast()); + } + + return mlir::TF::ResourceType::get(refined_subtypes, a.getContext()); + } + + // For tensor types, check compatibility of both element type and shape. + mlir::Type refined_element_ty = GetCastCompatibleType( + a_tt.getElementType(), b_tt.getElementType(), may_ignore_ref_type_a); + if (!refined_element_ty) return nullptr; + + if (!a_tt.hasRank() && !b_tt.hasRank()) { + return mlir::UnrankedTensorType::get(refined_element_ty); + } + if (!a_tt.hasRank()) { + return mlir::RankedTensorType::get(b_tt.getShape(), refined_element_ty); + } + if (!b_tt.hasRank()) { + return mlir::RankedTensorType::get(a_tt.getShape(), refined_element_ty); + } + + llvm::SmallVector refined_shape; + if (!GetCastCompatibleShape(a_tt.getShape(), b_tt.getShape(), &refined_shape)) + return nullptr; + + return mlir::RankedTensorType::get(refined_shape, refined_element_ty); +} + bool HasCompatibleElementTypes(Type lhs, Type rhs, bool may_ignore_ref_type_lhs) { return GetCastCompatibleType(lhs, rhs, may_ignore_ref_type_lhs) != nullptr; @@ -359,6 +360,16 @@ bool AreCastCompatible(ArrayRef types) { return true; } +bool ArraysAreCastCompatible(ArrayRef lhs, ArrayRef rhs) { + if (lhs.size() != rhs.size()) return false; + for (auto pair : llvm::zip(lhs, rhs)) { + auto lhs_i = std::get<0>(pair); + auto rhs_i = std::get<1>(pair); + if (!AreCastCompatible({lhs_i, rhs_i})) return false; + } + return true; +} + // Assumes a function `GetDefaultTypeOf(ComposedType)` that returns the default // type for a composed type (such as a ref type or a type with subtypes). 
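The comment on GetCastCompatibleType above states the refinement rule in words: ranks must match, dimensions must agree wherever both are static, and each refined dimension takes the more precise of the two. A standalone sketch of that rule, using -1 for a dynamic dimension (a simplified model, not the MLIR implementation):

// shape_refine_sketch.cc
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

using Shape = std::vector<int64_t>;
constexpr int64_t kDynamic = -1;  // stands in for '?' in tensor<2x?x?xf32>

std::optional<Shape> RefineShapes(const Shape &a, const Shape &b) {
  if (a.size() != b.size()) return std::nullopt;    // ranks must match
  Shape refined(a.size());
  for (size_t i = 0; i < a.size(); ++i) {
    if (a[i] != kDynamic && b[i] != kDynamic && a[i] != b[i])
      return std::nullopt;                          // static dimensions disagree
    refined[i] = (a[i] == kDynamic) ? b[i] : a[i];  // keep the more precise dim
  }
  return refined;
}

int main() {
  // e.g. tensor<2x?x?xf32> vs tensor<?x4x?xf32>  ->  tensor<2x4x?xf32>
  auto refined = RefineShapes({2, kDynamic, kDynamic}, {kDynamic, 4, kDynamic});
  if (refined)
    for (int64_t d : *refined) std::cout << d << " ";  // prints: 2 4 -1
  std::cout << "\n";
}
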
template diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h index f93f6b657da..1d3ca0c4a60 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h @@ -272,6 +272,15 @@ class VariantType : public detail::TypeWithSubtypeImpl { static std::string getTypeName() { return "VariantType"; } }; +// Given two types `a` and `b`, returns a refined type which is cast compatible +// with both `a` and `b` and is equal to or more precise than both of them. It +// returns empty Type if the input types are not cast compatible. +// Provides option to ignore ref types on 'a'. This is useful for TF ops that +// might allow operands to either be same as result type or be a ref type +// corresponding to it. +mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b, + bool may_ignore_ref_type_a); + // Returns whether two arrays of Type are broadcast compatible. bool BroadcastCompatible(ArrayRef lhs, ArrayRef rhs); @@ -293,6 +302,10 @@ bool HasCompatibleElementTypes(Type lhs, Type rhs, // compatible. bool AreCastCompatible(ArrayRef types); +// Returns true if corresponding elements of lhs and rhs AreCastCompatible and +// lhs and rhs are the same length. +bool ArraysAreCastCompatible(ArrayRef lhs, ArrayRef rhs); + // If `ty` is a tensor type and its element type has subtypes, then returns a // new type of same shape but dropped subtypes for the element type. // Otherwise, if `ty` has subtypes, then returns corresponding type with dropped diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc new file mode 100644 index 00000000000..6a6a7574f29 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc @@ -0,0 +1,27 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.h" + +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc.inc" diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.h new file mode 100644 index 00000000000..039f211533c --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.h @@ -0,0 +1,26 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TFRT_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TFRT_OPS_H_ + +#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TFRT_OPS_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.td new file mode 100644 index 00000000000..fea9500b638 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.td @@ -0,0 +1,61 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is the operation definition file for TensorFlow operations with +// implementation available only in TFRT. + +#ifndef TFRT_OPS +#define TFRT_OPS + +include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td" +include "mlir/IR/OpBase.td" + +def TF__JitFusedMatMulOp : TF_Op<"_JitFusedMatMul", [NoSideEffect, SameOperandsAndResultElementType]> { + let summary = [{ + MatMul operation with an output fusion compiled at runtime via MLIR codegen. + }]; + + let description = [{ +The inputs to the MatMul are specified by `a` and `b`. The series of operations +that follows is specified by the `fusion` attribute, which is a list of output +kernel names specified as strings (e.g. "BiasAdd"). They are performed in order, +where the (first) input to each op is the output of the preceding op. The first +input and the output of each fused_op must be of type T. + +Supported list of fusions is defined by ContractionOutputKernelBuilder +implementations. + +*WARN*: This is a TFRT only operations, and it does not exist in TF. This +operation is only added by the ContractionFusion pass. 
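A scalar model of the fusion attribute semantics described above, where each named output kernel is applied in order to the contraction result; the helper name and the kernel set below are illustrative only, not TFRT's actual ContractionOutputKernelBuilder API:

// fusion_semantics_sketch.cc
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Apply fused output kernels, in order, to one element of the MatMul result.
float ApplyFusion(float matmul_out, float bias, float leaky_alpha,
                  const std::vector<std::string> &fusion) {
  float v = matmul_out;
  for (const std::string &kernel : fusion) {
    if (kernel == "BiasAdd")        v += bias;
    else if (kernel == "Relu")      v = std::max(v, 0.0f);
    else if (kernel == "LeakyRelu") v = v > 0.0f ? v : leaky_alpha * v;
  }
  return v;
}

int main() {
  // fusion = ["BiasAdd", "LeakyRelu"] with alpha = 0.2:
  // (-2.0 + 0.5) = -1.5, then LeakyRelu gives -1.5 * 0.2 = -0.3.
  std::cout << ApplyFusion(-2.0f, 0.5f, 0.2f, {"BiasAdd", "LeakyRelu"}) << "\n";
}
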
+ }]; + + let arguments = (ins + TensorOf<[F32]>:$a, + TensorOf<[F32]>:$b, + Variadic>:$additional_args, + + DefaultValuedAttr:$transpose_a, + DefaultValuedAttr:$transpose_b, + DefaultValuedAttr:$fusion + ); + + let results = (outs + TensorOf<[F32]>:$product + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +#endif // TFRT_OPS diff --git a/tensorflow/compiler/mlir/tensorflow/ops/mlir_local_var_op.cc b/tensorflow/compiler/mlir/tensorflow/ops/mlir_local_var_op.cc index 211866900aa..d2c2cecdfdd 100644 --- a/tensorflow/compiler/mlir/tensorflow/ops/mlir_local_var_op.cc +++ b/tensorflow/compiler/mlir/tensorflow/ops/mlir_local_var_op.cc @@ -21,7 +21,7 @@ namespace tensorflow { REGISTER_OP("MlirLocalVarOp") .Output("resource: resource") .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"(Creates a handle to a in-scope variable. + .Doc(R"(Creates a handle to an in-scope variable. Used by internal passes for temporary representation of local state, which will be eventually removed.)"); diff --git a/tensorflow/compiler/mlir/tensorflow/tests/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/BUILD index daa583bed0e..63d01bf355e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow:tensorflow.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package(licenses = ["notice"]) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir b/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir index 05d34eb0755..6654341ab42 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir @@ -285,7 +285,7 @@ func @empty_island_multiple_data_results(%arg0: tensor<*xf32>, %arg1: tensor<*xi // and certain tf_executor ops are added correctly. 
// CHECK: %[[CONTROL:[^ ,]*]] = tf_executor.island wraps "tf.Print" -// CHECK: tf_executor.NextIteration.Sink [{{.*}}] {{.*}}, %[[CONTROL]] +// CHECK: tf_executor.NextIteration.Sink[{{.*}}] {{.*}}, %[[CONTROL]] func @next_iteration_sink_control_input() { tf_executor.graph { %source:3 = tf_executor.NextIteration.Source : tensor<*xi32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index 50486909694..e77dd365abf 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -568,6 +568,14 @@ func @testSelectElseUnranked(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2: return %0: tensor<*xf16> } +// CHECK-LABEL: testTileMultiplesAllOnes +func @testTileMultiplesAllOnes(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { + %cst = constant dense <[1, 1]> : tensor<2xi32> + // CHECK: return %arg0 + %0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x3xf32> + return %0: tensor<2x3xf32> +} + // CHECK-LABEL: testLogicalNotOfEqual func @testLogicalNotOfEqual(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xi1> { %0 = "tf.Equal"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1> @@ -967,6 +975,65 @@ func @foldIfRegionMismatchedTypes(%arg0: tensor, %arg1: tensor, %a return %0 : tensor<1xf32> } +// CHECK-LABEL: func @eliminatePassThroughIfRegion( +// CHECK-SAME: %[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor +func @eliminatePassThroughIfRegion(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor) { + // CHECK: %[[PRED:.*]] = "tf._SomeOp"() : () -> tensor + %pred = "tf._SomeOp"() : () -> tensor + // CHECK: %[[IF_OUTPUT:.*]] = "tf.IfRegion"(%[[PRED]]) ( { + // CHECK: %[[MUL:.*]] = "tf.Mul"(%[[ARG0]], %[[ARG1]]) + // CHECK: "tf.Yield"(%[[MUL]]) : (tensor) + // CHECK: }, { + // CHECK: %[[SUB:.*]] = "tf.Sub"(%[[ARG0]], %[[ARG1]]) + // CHECK: "tf.Yield"(%[[SUB]]) : (tensor) + // CHECK: }) {is_stateless = true} : (tensor) -> tensor + %0:4 = "tf.IfRegion"(%pred) ({ + %true_value = "tf.Mul"(%arg0, %arg1) : (tensor, tensor) -> tensor + "tf.Yield"(%arg1, %arg2, %true_value, %arg2) : (tensor, tensor, tensor, tensor) -> () + }, { + %false_value = "tf.Sub"(%arg0, %arg1) : (tensor, tensor) -> tensor + "tf.Yield"(%arg1, %arg2, %false_value, %arg2) : (tensor, tensor, tensor, tensor) -> () + }) { is_stateless = true}: (tensor) -> (tensor, tensor, tensor, tensor) + // CHECK: "tf._SomeOp"(%[[ARG2]], %[[ARG1]]) : (tensor, tensor) -> () + "tf._SomeOp"(%0#1, %0#0) : (tensor, tensor) -> () + // CHECK: "tf._SomeOp"(%[[ARG2]], %[[IF_OUTPUT]]) : (tensor, tensor) -> () + "tf._SomeOp"(%0#3, %0#2) : (tensor, tensor) -> () + // CHECK: return %[[IF_OUTPUT]] : tensor + return %0#2 : tensor +} + +// CHECK-LABEL: func @eliminatePassThroughCaseRegion( +// CHECK-SAME: %[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor +func @eliminatePassThroughCaseRegion(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor) { + // CHECK: %[[INDEX:.*]] = "tf._SomeOp"() : () -> tensor + %index = "tf._SomeOp"() : () -> tensor + // CHECK: %[[CASE_OUTPUT:.*]] = "tf.CaseRegion"(%[[INDEX]]) ( { + // CHECK: %[[MUL:.*]] = "tf.Mul"(%[[ARG0]], %[[ARG1]]) + // CHECK: "tf.Yield"(%[[MUL]]) : (tensor) + // CHECK: }, { + // CHECK: %[[SUB:.*]] = "tf.Sub"(%[[ARG0]], %[[ARG1]]) + // CHECK: "tf.Yield"(%[[SUB]]) : (tensor) + // CHECK: }, { + // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[ARG0]], %[[ARG1]]) + // CHECK: 
"tf.Yield"(%[[ADD]]) : (tensor) + // CHECK: }) {is_stateless = true} : (tensor) -> tensor + %0:3 = "tf.CaseRegion"(%index) ({ + %mul = "tf.Mul"(%arg0, %arg1) : (tensor, tensor) -> tensor + "tf.Yield"(%arg1, %mul, %arg2) : (tensor, tensor, tensor) -> () + }, { + %sub = "tf.Sub"(%arg0, %arg1) : (tensor, tensor) -> tensor + "tf.Yield"(%arg1, %sub, %arg2) : (tensor, tensor, tensor) -> () + }, { + %add = "tf.AddV2"(%arg0, %arg1) : (tensor, tensor) -> tensor + "tf.Yield"(%arg1, %add, %arg2) : (tensor, tensor, tensor) -> () + }) { is_stateless = true}: (tensor) -> (tensor, tensor, tensor) + // CHECK: "tf._SomeOp"(%[[ARG2]], %[[ARG1]]) : (tensor, tensor) -> () + "tf._SomeOp"(%0#2, %0#0) : (tensor, tensor) -> () + // CHECK: return %[[CASE_OUTPUT]] : tensor + return %0#1 : tensor +} + + // CHECK-LABEL: foldCase func @foldCase(%arg0: tensor, %arg1: tensor) -> (tensor) { %2 = constant dense<1> : tensor @@ -1209,3 +1276,18 @@ func @testWhileDropOutputShapes(tensor<*xf32>) -> (tensor<*xf32>) { return %1 : tensor<*xf32> } + +// CHECK-LABEL: testNMSV3ToNMSV4 +func @testNMSV3ToNMSV4(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor, %arg3: tensor) -> tensor<2xi32> { + %max_size = constant dense<2> : tensor + // CHECK: "tf.NonMaxSuppressionV4" + %0 = "tf.NonMaxSuppressionV3"(%arg0, %arg1, %max_size, %arg2, %arg3): (tensor<3x4xf32>, tensor<3xf32>, tensor, tensor, tensor) -> (tensor<2xi32>) + return %0 : tensor<2xi32> +} + +// CHECK-LABEL: testFusedBatchNormToBatchNormV3 +func @testFusedBatchNormToBatchNormV3(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK: "tf.FusedBatchNormV3" + %0:5 = "tf.FusedBatchNorm"(%arg0, %arg1, %arg2, %arg3, %arg4): (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32> ) + return %0#0 : tensor<8x8x8x8xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/BUILD new file mode 100644 index 00000000000..b8ab6ffeeb9 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/BUILD @@ -0,0 +1,26 @@ +load("//tensorflow:tensorflow.bzl", "filegroup") +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") + +licenses(["notice"]) + +glob_lit_tests( + data = [ + ":test_utilities", + ], + driver = "@llvm-project//mlir:run_lit.sh", + test_file_exts = [ + "mlir", + "pbtxt", + ], +) + +# Bundle together all of the test utilities that are used by tests. 
+filegroup( + name = "test_utilities", + testonly = True, + data = [ + "//tensorflow/compiler/mlir:tf-mlir-translate", + "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:not", + ], +) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/add.mlir b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/add.mlir new file mode 100644 index 00000000000..84e3f528a5c --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/add.mlir @@ -0,0 +1,38 @@ +// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=: -emit-return-tuple | FileCheck %s +// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=: -emit-use-tuple-args -emit-return-tuple | FileCheck -check-prefix=TUPLE-ARGS %s + +module attributes {tf.versions = {producer = 179 : i32}} { + func @main(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "tf.AddV2"(%arg0, %arg1) : (tensor, tensor) -> tensor + return %0 : tensor + } +} + +// CHECK-LABEL: HloModule main +// CHECK: ENTRY %main.{{[0-9]+}} ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> (f32[]) { +// CHECK-NEXT: %[[ARG0]] = f32[] parameter(0) +// CHECK-NEXT: %[[ARG1]] = f32[] parameter(1) +// CHECK-NEXT: [[ADD:%.*]] = f32[] add(f32[] %[[ARG0]], f32[] %[[ARG1]]) +// CHECK-NEXT: ROOT %tuple.{{[0-9]+}} = (f32[]) tuple(f32[] [[ADD]]) +// CHECK-NEXT: } + +// CHECK: // InputMapping {0, 1} +// CHECK-NEXT: // XlaInputShape f32[] +// CHECK-NEXT: // XlaInputShape f32[] +// CHECK-NEXT: // XlaOutputShape (f32[]) +// CHECK-NEXT: // XlaOutputDescription type=float shape=() + + +// TUPLE-ARGS-LABEL: HloModule main +// TUPLE-ARGS: ENTRY %main.{{[0-9]+}} ([[ARG_TUPLE:.*]]: (f32[], f32[])) -> (f32[]) { +// TUPLE-ARGS: %[[ARG_TUPLE]] = (f32[], f32[]) parameter(0) +// TUPLE-ARGS: [[ARG0:%.*]] = f32[] get-tuple-element((f32[], f32[]) %[[ARG_TUPLE]]), index=0 +// TUPLE-ARGS: [[ARG1:%.*]] = f32[] get-tuple-element((f32[], f32[]) %[[ARG_TUPLE]]), index=1 +// TUPLE-ARGS: [[ADD:%.*]] = f32[] add(f32[] [[ARG0]], f32[] [[ARG1]]) +// TUPLE-ARGS: ROOT %tuple.{{[0-9]+}} = (f32[]) tuple(f32[] [[ADD]]) +// TUPLE-ARGS: } + +// TUPLE-ARGS: // InputMapping {0, 1} +// TUPLE-ARGS-NEXT: // XlaInputShape (f32[], f32[]) +// TUPLE-ARGS-NEXT: // XlaOutputShape (f32[]) +// TUPLE-ARGS-NEXT: // XlaOutputDescription type=float shape=() diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/argument-sharding-invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/argument-sharding-invalid.mlir new file mode 100644 index 00000000000..5347037d7cf --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/argument-sharding-invalid.mlir @@ -0,0 +1,9 @@ +// RUN: not tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=128,10 -emit-use-tuple-args -emit-return-tuple 2>&1 | FileCheck %s + +module attributes {tf.versions = {producer = 179 : i32}} { + func @main(%arg0: tensor<128x8xf32> {mhlo.sharding = "bad_sharding"}) { + return + } +} + +// CHECK: failed to parse argument sharding 0 'bad_sharding' diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/argument-sharding.mlir b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/argument-sharding.mlir new file mode 100644 index 00000000000..7154919c3d1 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/argument-sharding.mlir @@ -0,0 +1,38 @@ +// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=128,10:10,1024:128,1024 -emit-use-tuple-args -emit-return-tuple | FileCheck %s + +module attributes 
{tf.versions = {producer = 179 : i32}} { + func @main(%arg0: tensor<128x10xf32> {mhlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, %arg1: tensor<10x1024xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<128x1024xf32> {mhlo.sharding = ""}) { + return + } +} + +// The following xla::OpSharding protos are used: +// Serialized string: +// "\08\03\1A\02\01\02\22\02\00\01" +// Proto debug string: +// type: OTHER +// tile_assignment_dimensions: 1 +// tile_assignment_dimensions: 2 +// tile_assignment_devices: 0 +// tile_assignment_devices: 1 +// +// Serialized string: +// "\08\01\1A\01\01\22\01\00" +// Proto debug string: +// type: MAXIMAL +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 0 +// +// Serialized string: +// "" +// Proto debug string (empty but would equivalent to): +// type: REPLICATED + +// CHECK-LABEL: HloModule main +// CHECK: ENTRY %main.{{[0-9]+}} ([[ARG_TUPLE:.*]]: (f32[128,10], f32[10,1024], f32[128,1024])) -> () { +// CHECK: %[[ARG_TUPLE]] = (f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) parameter(0) +// CHECK-SAME: sharding={ +// CHECK-SAME: {devices=[1,2]0,1} +// CHECK-SAME: {maximal device=0} +// CHECK-SAME: {replicated} +// CHECK-SAME: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/constant-folding-hook.mlir b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/constant-folding-hook.mlir new file mode 100644 index 00000000000..c745fbc0744 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/constant-folding-hook.mlir @@ -0,0 +1,16 @@ +// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=: -emit-use-tuple-args -emit-return-tuple | FileCheck %s + +module attributes {tf.versions = {producer = 179 : i32}} { + func @main() -> (tensor<0xi32>, tensor<0xi32>) { + %0 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + %r0, %r1 = "tf.BroadcastGradientArgs"(%0, %0) {T = i32} : (tensor<0xi32>, tensor<0xi32>) -> (tensor<0xi32>, tensor<0xi32>) + return %r0, %r1 : tensor<0xi32>, tensor<0xi32> + } +} + +// CHECK-LABEL: HloModule main +// CHECK: ENTRY %main.{{[0-9+]}} ([[ARG_TUPLE:.*]]: ()) -> (s32[0], s32[0]) { +// CHECK: %[[ARG_TUPLE]] = () parameter(0) +// CHECK: [[CONSTANT:%.*]] = s32[0]{0} constant({}) +// CHECK: ROOT %tuple.{{[0-9]+}} = (s32[0]{0}, s32[0]{0}) tuple(s32[0]{0} [[CONSTANT]], s32[0]{0} [[CONSTANT]]) +// CHECK: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/constant-folding.mlir b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/constant-folding.mlir new file mode 100644 index 00000000000..e54ff79e5e4 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/constant-folding.mlir @@ -0,0 +1,23 @@ +// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=10,19:19,10 -emit-use-tuple-args -emit-return-tuple | FileCheck %s + +module attributes {tf.versions = {producer = 179 : i32}} { + func @main(%arg0: tensor<10x19xf32>, %arg1: tensor<19x10xf32> {mhlo.is_same_data_across_replicas}) -> tensor<10x19xf32> { + %0 = "tf.Shape"(%arg0) : (tensor<10x19xf32>) -> tensor<2xi64> + %1 = "tf.Reshape"(%arg1, %0) : (tensor<19x10xf32>, tensor<2xi64>) -> tensor<10x19xf32> + return %1 : tensor<10x19xf32> + } +} + +// Tests that foldable ops are constant-folded to enable legalization of ops +// that require compile time constant operand. +// "tf.Shape" can only be folded away after shape inference. tf.Reshape can only +// be lowered when tf.Shape is folded into a constant. 
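The comment above describes an ordering constraint: tf.Shape only folds once shape inference has made its operand's shape static, and only then does tf.Reshape have the compile-time-constant shape operand it needs. A toy standalone version of that folding step (not the MLIR constant-folding hook):

// shape_fold_sketch.cc
#include <cstdint>
#include <iostream>
#include <vector>

struct Tensor {
  std::vector<int64_t> shape;  // fully static after shape inference
};

// A tf.Shape-like query over a statically shaped value is itself a constant.
std::vector<int64_t> FoldShape(const Tensor &t) { return t.shape; }

int main() {
  Tensor arg0{{10, 19}};
  std::vector<int64_t> target = FoldShape(arg0);  // folds to the constant [10, 19]
  Tensor reshaped{target};                        // reshape target is now static
  std::cout << reshaped.shape[0] << "x" << reshaped.shape[1] << "\n";  // 10x19
}
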
+ +// CHECK-LABEL: HloModule main +// CHECK: ENTRY %main.{{[0-9]+}} ([[ARG_TUPLE:.*]]: (f32[10,19], f32[19,10])) -> (f32[10,19]) { +// CHECK: %[[ARG_TUPLE]] = (f32[10,19]{1,0}, f32[19,10]{1,0}) parameter(0), parameter_replication={false,true} +// CHECK: [[ARG0:%.*]] = f32[10,19]{1,0} get-tuple-element((f32[10,19]{1,0}, f32[19,10]{1,0}) %[[ARG_TUPLE]]), index=0 +// CHECK: [[ARG1:%.*]] = f32[19,10]{1,0} get-tuple-element((f32[10,19]{1,0}, f32[19,10]{1,0}) %[[ARG_TUPLE]]), index=1 +// CHECK: [[RESHAPE:%.*]] = f32[10,19]{1,0} reshape(f32[19,10]{1,0} [[ARG1]]) +// CHECK: ROOT %tuple.{{[0-9]+}} = (f32[10,19]{1,0}) tuple(f32[10,19]{1,0} [[RESHAPE]]) +// CHECK: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/graph-resource.mlir b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/graph-resource.mlir new file mode 100644 index 00000000000..3d1a34b932d --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/graph-resource.mlir @@ -0,0 +1,27 @@ +// RUN: tf-mlir-translate -mlir-tf-graph-to-hlo-text %s -tf-input-shapes=2:2 -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-xla-input-types=parameter,resource -emit-return-tuple | FileCheck %s + +module attributes {tf.versions = {producer = 511 : i32}} { + func @main(%arg0: tensor<*xf32>, %arg1: tensor<*x!tf.resource>) { + tf_executor.graph { + %control = tf_executor.island wraps "tf.AssignVariableOp"(%arg1, %arg0) : (tensor<*x!tf.resource>, tensor<*xf32>) -> () + tf_executor.fetch %control : !tf_executor.control + } + return + } +} + +// Tests a conversion from Graph (tf_executor dialect MLIR) to MLIR with +// resource arguments. + +// CHECK-LABEL: HloModule main.{{[0-9]+}}, input_output_alias={ {0}: (1, {}, may-alias) } +// CHECK: ENTRY %main.{{[0-9]+}} ([[ARG0:.*]]: f32[2], [[ARG1:.*]]: f32[2]) -> (f32[2]) { +// CHECK-NEXT: %[[ARG1]] = f32[2]{0} parameter(1) +// CHECK-NEXT: %[[ARG0]] = f32[2]{0} parameter(0) +// CHECK-NEXT: ROOT %tuple.{{[0-9]+}} = (f32[2]{0}) tuple(f32[2]{0} %[[ARG0]]) +// CHECK-NEXT: } + +// CHECK: // InputMapping {0, 1} +// CHECK-NEXT: // XlaInputShape f32[2] +// CHECK-NEXT: // XlaInputShape f32[2] +// CHECK-NEXT: // XlaOutputShape (f32[2]) +// CHECK-NEXT: // ResourceUpdate input_index=1 type=float shape=(2) modified diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/graph-resource.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/graph-resource.pbtxt new file mode 100644 index 00000000000..5fb90b1bce0 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/graph-resource.pbtxt @@ -0,0 +1,66 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-graph-as-function | tf-mlir-translate -mlir-tf-graph-to-hlo-text -tf-input-shapes=2:2 -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-xla-input-types=parameter,resource -emit-return-tuple | FileCheck %s + +node { + name: "arg0" + op: "_Arg" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "index" + value { + i: 0 + } + } +} +node { + name: "arg1" + op: "_Arg" + attr { + key: "T" + value { + type: DT_RESOURCE + } + } + attr { + key: "index" + value { + i: 1 + } + } +} +node { + name: "assign_variable" + op: "AssignVariableOp" + input: "arg1" + input: "arg0" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } +} +library { +} +versions { + producer: 511 +} + +# Tests a conversion from Graph to MLIR with resource arguments. 
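A rough standalone model of the aliasing in the HLO text above: the variable arrives as parameter 1, the assigned value as parameter 0, and the compiled module reports the write as output 0 may-aliasing input 1 plus a ResourceUpdate entry. The struct below is only an illustration, not XLA's API:

// resource_alias_sketch.cc
#include <iostream>
#include <vector>

struct ResourceUpdate {
  int input_index;  // parameter that holds the variable
  bool modified;    // whether the computation wrote to it
};

int main() {
  std::vector<float> new_value = {1.0f, 2.0f};  // parameter 0
  std::vector<float> variable = {0.0f, 0.0f};   // parameter 1 (the resource)
  variable = new_value;                         // AssignVariableOp(arg1, arg0)

  // Mirrors "ResourceUpdate input_index=1 type=float shape=(2) modified".
  ResourceUpdate update{/*input_index=*/1, /*modified=*/true};
  std::cout << "input_index=" << update.input_index
            << " modified=" << std::boolalpha << update.modified << "\n";
}
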
+ +# CHECK-LABEL: HloModule main.{{[0-9]+}}, input_output_alias={ {0}: (1, {}, may-alias) } +# CHECK: ENTRY %main.{{[0-9]+}} ([[ARG0:.*]]: f32[2], [[ARG1:.*]]: f32[2]) -> (f32[2]) { +# CHECK-NEXT: %[[ARG1]] = f32[2]{0} parameter(1) +# CHECK-NEXT: %[[ARG0]] = f32[2]{0} parameter(0) +# CHECK-NEXT: ROOT %tuple.{{[0-9]+}} = (f32[2]{0}) tuple(f32[2]{0} %[[ARG0]]) +# CHECK-NEXT: } + +# CHECK: // InputMapping {0, 1} +# CHECK-NEXT: // XlaInputShape f32[2] +# CHECK-NEXT: // XlaInputShape f32[2] +# CHECK-NEXT: // XlaOutputShape (f32[2]) +# CHECK-NEXT: // ResourceUpdate input_index=1 type=float shape=(2) modified diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/graph.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/graph.pbtxt new file mode 100644 index 00000000000..f1f7c6434eb --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/graph.pbtxt @@ -0,0 +1,47 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-graph-as-function | tf-mlir-translate -mlir-tf-graph-to-hlo-text -tf-input-shapes='' -tf-input-data-types=DT_FLOAT -emit-return-tuple | FileCheck %s + +node { + name: "arg" + op: "_Arg" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "index" + value { + i: 0 + } + } +} +node { + name: "retval" + op: "_Retval" + input: "arg" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "index" + value { + i: 0 + } + } +} +versions { + producer: 511 +} + +# Verify that conversion from Graph to MLIR and empty shape representation +# function is successful. + +# CHECK-LABEL: HloModule main +# CHECK: ENTRY %main.{{[0-9]+}} ([[ARG0:.*]]: f32[]) -> (f32[]) { +# CHECK-NEXT: %[[ARG0]] = f32[] parameter(0) +# CHECK-NEXT: ROOT %tuple.{{[0-9]+}} = (f32[]) tuple(f32[] %[[ARG0]]) +# CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/mlir-module-serialized-str-attr.mlir b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/mlir-module-serialized-str-attr.mlir new file mode 100644 index 00000000000..b68f177b183 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/mlir-module-serialized-str-attr.mlir @@ -0,0 +1,10 @@ +// RUN: tf-mlir-translate -mlir-tf-mlir-to-str-attr %s | FileCheck %s + +module attributes {tf.versions = {producer = 888 : i32}} { + func @main(%arg0: tensor) -> tensor { + %0 = "tf.Identity"(%arg0) : (tensor) -> tensor loc(unknown) + return %0 : tensor loc(unknown) + } loc(unknown) +} loc(unknown) + +// CHECK: "\0A\0Amodule attributes {tf.versions = {producer = 888 : i32}} {\0A func @main(%arg0: tensor) -> tensor {\0A %0 = \22tf.Identity\22(%arg0) : (tensor) -> tensor loc(unknown)\0A return %0 : tensor loc(unknown)\0A } loc(unknown)\0A} loc(unknown)" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/result-sharding.mlir b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/result-sharding.mlir new file mode 100644 index 00000000000..c9c02ba2588 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/result-sharding.mlir @@ -0,0 +1,39 @@ +// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=128,10:10,1024:128,1024 -emit-use-tuple-args -emit-return-tuple | FileCheck %s + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 351 : i32}} { + func @main(%arg0: tensor<128x10xf32>, %arg1: tensor<10x1024xf32>, %arg2: tensor<128x1024xf32>) -> (tensor<128x10xf32> {mhlo.sharding = 
"\08\03\1A\02\01\02\22\02\00\01"}, tensor<10x1024xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<128x1024xf32> {mhlo.sharding = ""}) { + return %arg0, %arg1, %arg2 : tensor<128x10xf32>, tensor<10x1024xf32>, tensor<128x1024xf32> + } +} + +// The following xla::OpSharding protos are used: +// Serialized string: +// "\08\03\1A\02\01\02\22\02\00\01" +// Proto debug string: +// type: OTHER +// tile_assignment_dimensions: 1 +// tile_assignment_dimensions: 2 +// tile_assignment_devices: 0 +// tile_assignment_devices: 1 +// +// Serialized string: +// "\08\01\1A\01\01\22\01\00" +// Proto debug string: +// type: MAXIMAL +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 0 +// +// Serialized string: +// "" +// Proto debug string (empty but would equivalent to): +// type: REPLICATED + +// CHECK-LABEL: HloModule main +// CHECK: ENTRY %main.{{[0-9]+}} +// CHECK-SAME: (arg_tuple.{{[0-9]+}}: (f32[128,10], f32[10,1024], f32[128,1024])) -> (f32[128,10], f32[10,1024], f32[128,1024]) { +// CHECK: ROOT %tuple.{{[0-9]+}} +// CHECK-SAME: sharding={ +// CHECK-SAME: {devices=[1,2]0,1} +// CHECK-SAME: {maximal device=0} +// CHECK-SAME: {replicated} +// CHECK-SAME: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/serialized-mlir-module-str-attr-invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/serialized-mlir-module-str-attr-invalid.mlir new file mode 100644 index 00000000000..ced11f3a083 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/serialized-mlir-module-str-attr-invalid.mlir @@ -0,0 +1,5 @@ +// RUN: not tf-mlir-translate -mlir-tf-str-attr-to-mlir %s 2>&1 | FileCheck %s + +"totally @invalid MLIR module {here} <-" + +// CHECK: Invalid argument: could not parse MLIR module-:1:1: error: custom op 'totally' is unknown diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/serialized-mlir-module-str-attr.mlir b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/serialized-mlir-module-str-attr.mlir new file mode 100644 index 00000000000..9a0e1dc38c8 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/serialized-mlir-module-str-attr.mlir @@ -0,0 +1,15 @@ +// RUN: tf-mlir-translate -mlir-tf-str-attr-to-mlir %s -mlir-print-debuginfo | FileCheck %s + +"\0A\0Amodule attributes {tf.versions = {producer = 888 : i32}} {\0A func @main(%arg0: tensor) -> tensor {\0A %0 = \22tf.Identity\22(%arg0) : (tensor) -> tensor loc(unknown)\0A return %0 : tensor loc(unknown)\0A } loc(unknown)\0A} loc(unknown)" + +// Test simple serialized computation consisting of a function named `main` +// with a tf.Identity op forwarding the function single argument to the function +// single result. 
+ +// CHECK-LABEL: module +// CHECK-SAME: attributes {tf.versions = {producer = 888 : i32}} { +// CHECK-NEXT: func @main([[ARG0:%.+]]: tensor) -> tensor { +// CHECK-NEXT: [[IDENTITY:%.+]] = "tf.Identity"([[ARG0]]) : (tensor) -> tensor loc(unknown) +// CHECK-NEXT: return [[IDENTITY]] : tensor loc(unknown) +// CHECK-NEXT: } loc(unknown) +// CHECK-NEXT: } loc(unknown) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/shape-inference-after-legalization.mlir b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/shape-inference-after-legalization.mlir new file mode 100644 index 00000000000..55bdea5dd36 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/shape-inference-after-legalization.mlir @@ -0,0 +1,11 @@ +// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=8,16,16,64:64 -emit-use-tuple-args -emit-return-tuple | FileCheck %s + +module attributes {tf.versions = {producer = 179 : i32}} { + func @main(%arg0: tensor<8x16x16x64xbf16>, %arg1: tensor<64xf32>) -> (tensor<8x16x16x64xbf16>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<*xf32>) { + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg1, %arg1, %arg1) {data_format = "NHWC", device = "", epsilon = 9.99999974E-5 : f32, exponential_avg_factor = 1.000000e+00 : f32, is_training = false} : (tensor<8x16x16x64xbf16>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> (tensor<8x16x16x64xbf16>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<*xf32>) + return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<8x16x16x64xbf16>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<*xf32> + } +} + +// CHECK-LABEL: HloModule main +// CHECK: -> (bf16[8,16,16,64], f32[64], f32[64], f32[64], f32[64], f32[0]) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/shape-inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/shape-inference.mlir new file mode 100644 index 00000000000..f9eca514da3 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/shape-inference.mlir @@ -0,0 +1,11 @@ +// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=10,17:17,19 -emit-use-tuple-args -emit-return-tuple | FileCheck %s + +module attributes {tf.versions = {producer = 179 : i32}} { + func @main(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor { + %0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", transpose_a = false, transpose_b = false} : (tensor<*xf32>, tensor) -> tensor + return %0 : tensor + } +} + +// CHECK-LABEL: HloModule main +// CHECK: (arg_tuple.{{[0-9]+}}: (f32[10,17], f32[17,19])) -> (f32[10,19]) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir index f114d1724f2..779065b94d5 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir @@ -492,3 +492,22 @@ func @DontFoldTile() -> (tensor<8x10000xi32>) { return %3 : tensor<8x10000xi32> } // LINT.ThenChange(../transforms/constant_fold.cc:folding-policy) + +func @fold_conv() -> tensor<1x520x520x1xf32> { + %0 = "tf.Const"() {value = dense<0.111111112> : tensor<3x3x1x1xf32>} : () -> tensor<3x3x1x1xf32> + %1 = "tf.Const"() {value = dense<1.000000e+00> : tensor<1x520x520x1xf32>} : () -> tensor<1x520x520x1xf32> + %2 = "tf.DepthwiseConv2dNative"(%1, %0) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = 
[], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x520x520x1xf32>, tensor<3x3x1x1xf32>) -> tensor<1x520x520x1xf32> + return %2 : tensor<1x520x520x1xf32> + + // CHECK: tf.Const + // CHECK-NOT: tf.DepthwiseConv2dNative +} + +// CHECK-LABEL: DontFoldNoConstantFold +func @DontFoldNoConstantFold() -> tensor<8xf32> { + %0 = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<[2, 1]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK: tf.StatelessRandomUniform + %2 = "tf.StatelessRandomUniform"(%0, %1) : (tensor<1xi32>, tensor<2xi32>) -> tensor<8xf32> + return %2 : tensor<8xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/contraction_fusion.mlir b/tensorflow/compiler/mlir/tensorflow/tests/contraction_fusion.mlir new file mode 100644 index 00000000000..b12f50ad525 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/contraction_fusion.mlir @@ -0,0 +1,37 @@ +// RUN: tf-opt %s -tf-contraction-fusion | FileCheck %s + +// CHECK-LABEL: matmulBiasAdd +func @matmulBiasAdd(%arg0: tensor<64xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<32x64xf32>) -> tensor<8x64xf32> { + // CHECK: %[[FUSED:.*]] = "tf._JitFusedMatMul"(%arg1, %arg2, %arg0) + // CHECK-SAME: fusion = ["BiasAdd"] + // CHECK-SAME: transpose_a = false, transpose_b = false + %3 = "tf.MatMul"(%arg1, %arg2) {transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>) -> tensor<8x64xf32> + %4 = "tf.BiasAdd"(%3, %arg0) {data_format = "NHWC"} : (tensor<8x64xf32>, tensor<64xf32>) -> tensor<8x64xf32> + // CHECK: return %[[FUSED]] + return %4 : tensor<8x64xf32> +} + +// CHECK-LABEL: matmulBiasAddRelu +func @matmulBiasAddRelu(%arg0: tensor<64xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<32x64xf32>) -> tensor<8x64xf32> { + // CHECK: %[[FUSED:.*]] = "tf._JitFusedMatMul"(%arg1, %arg2, %arg0) + // CHECK-SAME: fusion = ["BiasAdd", "Relu"] + // CHECK-SAME: transpose_a = false, transpose_b = false + %3 = "tf.MatMul"(%arg1, %arg2) {transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>) -> tensor<8x64xf32> + %4 = "tf.BiasAdd"(%3, %arg0) {data_format = "NHWC"} : (tensor<8x64xf32>, tensor<64xf32>) -> tensor<8x64xf32> + %5 = "tf.Relu"(%4) : (tensor<8x64xf32>) -> tensor<8x64xf32> + // CHECK: return %[[FUSED]] + return %5 : tensor<8x64xf32> +} + +// CHECK-LABEL: matmulBiasAddLeakyRelu +func @matmulBiasAddLeakyRelu(%arg0: tensor<64xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<32x64xf32>) -> tensor<8x64xf32> { + // CHECK: %[[FUSED:.*]] = "tf._JitFusedMatMul"(%arg1, %arg2, %arg0) + // CHECK-SAME: alpha = 2.000000e-01 : f32 + // CHECK-SAME: fusion = ["BiasAdd", "LeakyRelu"] + // CHECK-SAME: transpose_a = false, transpose_b = false + %3 = "tf.MatMul"(%arg1, %arg2) {transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>) -> tensor<8x64xf32> + %4 = "tf.BiasAdd"(%3, %arg0) {data_format = "NHWC"} : (tensor<8x64xf32>, tensor<64xf32>) -> tensor<8x64xf32> + %5 = "tf.LeakyRelu"(%4) { alpha = 0.2 : f32 } : (tensor<8x64xf32>) -> tensor<8x64xf32> + // CHECK: return %[[FUSED]] + return %5 : tensor<8x64xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir index ff4dbf41221..e6a92a520f0 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir @@ -101,7 +101,7 @@ func @decompose_resource_apply_momentum_non_nesterov(%arg0: 
tensor, %arg1: // CHECK: [[ACCUM:%.*]] = "tf.ReadVariableOp"([[ACCUM_HANDLE]]) // CHECK: [[ACCUM_MOMENTUM:%.*]] = "tf.Mul"([[ACCUM]], [[MOMENTUM]]) - // CHECK: [[ACCUM_NEW:%.*]] = "tf.Add"([[ACCUM_MOMENTUM]], [[GRAD]]) + // CHECK: [[ACCUM_NEW:%.*]] = "tf.AddV2"([[ACCUM_MOMENTUM]], [[GRAD]]) // CHECK: "tf.AssignVariableOp"([[ACCUM_HANDLE]], [[ACCUM_NEW]]) // CHECK: [[ACCUM_NEW_LR:%.*]] = "tf.Mul"([[ACCUM_NEW]], [[LR]]) // CHECK: [[VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) @@ -127,12 +127,12 @@ func @decompose_resource_apply_momentum_nesterov(%arg0: tensor, %arg1: tens // CHECK: [[ACCUM:%.*]] = "tf.ReadVariableOp"([[ACCUM_HANDLE]]) // CHECK: [[ACCUM_MOMENTUM:%.*]] = "tf.Mul"([[ACCUM]], [[MOMENTUM]]) - // CHECK: [[ACCUM_NEW:%.*]] = "tf.Add"([[ACCUM_MOMENTUM]], [[GRAD]]) + // CHECK: [[ACCUM_NEW:%.*]] = "tf.AddV2"([[ACCUM_MOMENTUM]], [[GRAD]]) // CHECK: "tf.AssignVariableOp"([[ACCUM_HANDLE]], [[ACCUM_NEW]]) // CHECK: [[GRAD_LR:%.*]] = "tf.Mul"([[GRAD]], [[LR]]) // CHECK: [[MOMENTUM_LR:%.*]] = "tf.Mul"([[MOMENTUM]], [[LR]]) // CHECK: [[ACCUM_NEW_MOMENTUM_LR:%.*]] = "tf.Mul"([[ACCUM_NEW]], [[MOMENTUM_LR]]) - // CHECK: [[DELTA:%.*]] = "tf.Add"([[GRAD_LR]], [[ACCUM_NEW_MOMENTUM_LR]]) + // CHECK: [[DELTA:%.*]] = "tf.AddV2"([[GRAD_LR]], [[ACCUM_NEW_MOMENTUM_LR]]) // CHECK: [[VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) // CHECK: [[VAR_NEW:%.*]] = "tf.Sub"([[VAR]], [[DELTA]]) // CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[VAR_NEW]]) @@ -231,6 +231,31 @@ func @decompose_resource_apply_adagradv2(%arg0: tensor, %arg1: tensor, return } +// ----- +// CHECK-LABEL: func @decompose_resource_apply_adagrad +// CHECK-SAME: (%[[LR:.*]]: tensor, %[[GRAD:.*]]: tensor) +func @decompose_resource_apply_adagrad(%arg0: tensor, %arg1: tensor) -> () { + + // CHECK: %[[VAR_HANDLE:.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> + // CHECK: %[[ACCUM_HANDLE:.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> + // CHECK: %[[GRAD_SQUARE:.*]] = "tf.Mul"(%[[GRAD]], %[[GRAD]]) : (tensor, tensor) -> tensor + // CHECK: %[[ACCUM:.*]] = "tf.ReadVariableOp"(%[[ACCUM_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32> + // CHECK: %[[ACCUM_NEW:.*]] = "tf.AddV2"(%[[ACCUM]], %[[GRAD_SQUARE]]) : (tensor<*xf32>, tensor) -> tensor<*xf32> + // CHECK: %[[LR_MULTIPLY:.*]] = "tf.Mul"(%[[LR]], %[[GRAD]]) : (tensor, tensor) -> tensor + // CHECK: %[[SQRT:.*]] = "tf.Sqrt"(%[[ACCUM_NEW]]) : (tensor<*xf32>) -> tensor<*xf32> + // CHECK: %[[DIV:.*]] = "tf.Div"(%[[LR_MULTIPLY]], %[[SQRT]]) : (tensor, tensor<*xf32>) -> tensor<*xf32> + // CHECK: %[[VAR:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32> + // CHECK: %[[VAR_NEW:.*]] = "tf.Sub"(%[[VAR]], %[[DIV]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + // CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[VAR_NEW]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> () + // CHECK: "tf.AssignVariableOp"(%[[ACCUM_HANDLE]], %[[ACCUM_NEW]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> () + %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> + %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> + + "tf.ResourceApplyAdagrad"(%0, %1, %arg0, %arg1) {update_slots = true, use_locking = true} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor, tensor) -> () + + return +} + // ----- // Tests that composite tf.ResourceApplyAdam (non-Nesterov) operation is @@ -388,14 +413,14 @@ func 
@decompose_resource_apply_centered_RMS_prop(%arg0: tensor, %arg1: tens // CHECK: [[GRAD_SUB:%.*]] = "tf.Mul"([[GRADSQ]], [[SB]]) // CHECK: [[MS:%.*]] = "tf.ReadVariableOp"([[MS_HANDLE]]) // CHECK: [[MS_RHO:%.*]] = "tf.Mul"([[MS]], [[RHO]]) - // CHECK: [[MS_NEW:%.*]] = "tf.Add"([[GRAD_SUB]], [[MS_RHO]]) + // CHECK: [[MS_NEW:%.*]] = "tf.AddV2"([[GRAD_SUB]], [[MS_RHO]]) // CHECK: "tf.AssignVariableOp"([[MS_HANDLE]], [[MS_NEW]]) // CHECK: [[SUB_RHO:%.*]] = "tf.Sub"([[ONE]], [[RHO]]) // CHECK: [[SUB_GRAD:%.*]] = "tf.Mul"([[GRAD]], [[SUB_RHO]]) // CHECK: [[MG:%.*]] = "tf.ReadVariableOp"([[MG_HANDLE]]) // CHECK: [[MG_RHO:%.*]] = "tf.Mul"([[MG]], [[RHO]]) - // CHECK: [[MG_NEW:%.*]] = "tf.Add"([[SUB_GRAD]], [[MG_RHO]]) + // CHECK: [[MG_NEW:%.*]] = "tf.AddV2"([[SUB_GRAD]], [[MG_RHO]]) // CHECK: "tf.AssignVariableOp"([[MG_HANDLE]], [[MG_NEW]]) // CHECK: [[MOM:%.*]] = "tf.ReadVariableOp"([[MOM_HANDLE]]) @@ -403,11 +428,11 @@ func @decompose_resource_apply_centered_RMS_prop(%arg0: tensor, %arg1: tens // CHECK: [[LR_GRAD:%.*]] = "tf.Mul"([[LR]], [[GRAD]]) // CHECK: [[MG_MG:%.*]] = "tf.Mul"([[MG_NEW]], [[MG_NEW]]) - // CHECK: [[MG_NEW:%.*]] = "tf.Add"([[MG_MG]], [[EPSILON]]) + // CHECK: [[MG_NEW:%.*]] = "tf.AddV2"([[MG_MG]], [[EPSILON]]) // CHECK: [[MG_SUB:%.*]] = "tf.Sub"([[MS_NEW]], [[MG_NEW]]) // CHECK: [[MG_SQRT:%.*]] = "tf.Sqrt"([[MG_SUB]]) // CHECK: [[MOM_DIV:%.*]] = "tf.Div"([[LR_GRAD]], [[MG_SQRT]]) - // CHECK: [[MOM_NEW:%.*]] = "tf.Add"([[MOM_MOM]], [[MOM_DIV]]) + // CHECK: [[MOM_NEW:%.*]] = "tf.AddV2"([[MOM_MOM]], [[MOM_DIV]]) // CHECK: [[VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) // CHECK: [[VAR_NEW:%.*]] = "tf.Sub"([[VAR]], [[MOM_NEW]]) @@ -416,6 +441,33 @@ func @decompose_resource_apply_centered_RMS_prop(%arg0: tensor, %arg1: tens "tf.ResourceApplyCenteredRMSProp"(%0, %1, %2, %3, %arg4, %arg5, %arg6, %arg7, %arg8) {use_locking = false} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor, tensor, tensor, tensor, tensor) -> () return } +// ----- +// CHECK-LABEL: func @decompose_resource_apply_RMS_prop +// CHECK-SAME: (%[[VAR_HANDLE:.*]]: tensor<*x!tf.resource>, %[[MS_HANDLE:.*]]: tensor<*x!tf.resource>, %[[MOM_HANDLE:.*]]: tensor<*x!tf.resource>, +// CHECK-SAME: %[[LR:.*]]: tensor, %[[RHO:.*]]: tensor, %[[MOMENTUM:.*]]: tensor, %[[EPSILON:.*]]: tensor, %[[GRAD:.*]]: tensor) +func @decompose_resource_apply_RMS_prop(%arg0: tensor<*x!tf.resource>, %arg1: tensor<*x!tf.resource>, %arg2: tensor<*x!tf.resource>, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor) -> () { +// CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor +// CHECK: %[[MS:.*]] = "tf.ReadVariableOp"(%[[MS_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32> +// CHECK: %[[MS_RHO:.*]] = "tf.Mul"(%[[MS]], %[[RHO]]) : (tensor<*xf32>, tensor) -> tensor<*xf32> +// CHECK: %[[GRAD_SQUARE:.*]] = "tf.Square"(%[[GRAD]]) : (tensor) -> tensor +// CHECK: %[[ONE_RHO:.*]] = "tf.Sub"(%[[ONE]], %[[RHO]]) : (tensor, tensor) -> tensor +// CHECK: %[[MUL:.*]] = "tf.Mul"(%[[GRAD_SQUARE]], %[[ONE_RHO]]) : (tensor, tensor) -> tensor +// CHECK: %[[MS_NEW:.*]] = "tf.AddV2"(%[[MS_RHO]], %[[MUL]]) : (tensor<*xf32>, tensor) -> tensor<*xf32> +// CHECK: "tf.AssignVariableOp"(%[[MS_HANDLE]], %[[MS_NEW]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> () +// CHECK: %[[MOM:.*]] = "tf.ReadVariableOp"(%[[MOM_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32> +// CHECK: %[[MOMENTUM_MOM:.*]] = "tf.Mul"(%[[MOMENTUM]], %[[MOM]]) : (tensor, 
tensor<*xf32>) -> tensor<*xf32> +// CHECK: %[[LR_GRAD:.*]] = "tf.Mul"(%[[LR]], %[[GRAD]]) : (tensor, tensor) -> tensor +// CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[MS_NEW]], %[[EPSILON]]) : (tensor<*xf32>, tensor) -> tensor<*xf32> +// CHECK: %[[SQRT:.*]] = "tf.Sqrt"(%[[ADD]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: %[[DIV:.*]] = "tf.Div"(%[[LR_GRAD]], %[[SQRT]]) : (tensor, tensor<*xf32>) -> tensor<*xf32> +// CHECK: %[[MOM_NEW:.*]] = "tf.AddV2"(%[[MOMENTUM_MOM]], %[[DIV]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> +// CHECK: "tf.AssignVariableOp"(%[[MOM_HANDLE]], %[[MOM_NEW]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> () +// CHECK: %[[VAR:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32> +// CHECK: %[[VAR_NEW:.*]] = "tf.Sub"(%[[VAR]], %[[MOM_NEW]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> +// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[VAR_NEW]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> () + "tf.ResourceApplyRMSProp"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) {use_locking = false} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor, tensor, tensor, tensor, tensor) -> () + return +} // ----- diff --git a/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir b/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir index e7430993755..c963147b855 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir @@ -7,6 +7,14 @@ func @einsum_basic(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor // CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32> } +func @einsum_matmul(%arg0: tensor<7x9xf32>, %arg1: tensor<9x5xf32>) -> tensor<7x5xf32> { + %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ae,ed->ad"}: (tensor<7x9xf32>, tensor<9x5xf32>) -> tensor<7x5xf32> + return %0 : tensor<7x5xf32> + // CHECK-LABEL: einsum_matmul + // CHECK: %[[v0:.*]] = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<7x9xf32>, tensor<9x5xf32>) -> tensor<7x5xf32> + // CHECK: return %[[v0]] : tensor<7x5xf32> +} + func @einsum_broadcast(%arg0: tensor<3x4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<3x4x6xf32> { %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ijk,km->ijm"}: (tensor<3x4x5xf32>, tensor<5x6xf32>) -> tensor<3x4x6xf32> return %0 : tensor<3x4x6xf32> @@ -14,18 +22,27 @@ func @einsum_broadcast(%arg0: tensor<3x4x5xf32>, %arg1: tensor<5x6xf32>) -> tens // CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<3x4x5xf32>, tensor<5x6xf32>) -> tensor<3x4x6xf32> } +func @einsum_broadcast4(%arg0: tensor<3x4x5x6x7xf32>, %arg1: tensor<7x8xf32>) -> tensor<3x4x5x6x8xf32> { + %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "abcdh,hg->abcdg"}: (tensor<3x4x5x6x7xf32>, tensor<7x8xf32>) -> tensor<3x4x5x6x8xf32> + return %0 : tensor<3x4x5x6x8xf32> + // CHECK-LABEL: einsum_broadcast4 + // CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<3x4x5x6x7xf32>, tensor<7x8xf32>) -> tensor<3x4x5x6x8xf32> +} + func @einsum_reducesum(%arg0: tensor<2x5x7xf32>, %arg1: tensor<5x2xf32>) -> tensor<5x7xf32> { %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "lbh,bl->bh"}: (tensor<2x5x7xf32>, tensor<5x2xf32>) -> tensor<5x7xf32> return %0 : tensor<5x7xf32> // CHECK-LABEL: einsum_reducesum // CHECK: %[[cst:.*]] = constant dense<[1, 2, 0]> : 
tensor<3xi32> - // CHECK: %[[cst_1:.*]] = constant dense<[5, 1, 2]> : tensor<3xi64> - // CHECK: %[[cst_2:.*]] = constant dense<2> : tensor<1xi32> + // CHECK: %[[cst_1:.*]] = constant dense<[5, 2, 1]> : tensor<3xi64> + // CHECK: %[[cst_2:.*]] = constant dense<[5, 7]> : tensor<2xi64> // CHECK: %[[v0:.*]] = "tf.Transpose"(%arg0, %[[cst]]) : (tensor<2x5x7xf32>, tensor<3xi32>) -> tensor<5x7x2xf32> - // CHECK: %[[v1:.*]] = "tf.Reshape"(%arg1, %[[cst_1]]) : (tensor<5x2xf32>, tensor<3xi64>) -> tensor<5x1x2xf32> - // CHECK: %[[v2:.*]] = "tf.Mul"(%[[v0]], %[[v1]]) : (tensor<5x7x2xf32>, tensor<5x1x2xf32>) -> tensor<5x7x2xf32> - // CHECK: "tf.Sum"(%[[v2]], %[[cst_2]]) {keep_dims = false} : (tensor<5x7x2xf32>, tensor<1xi32>) -> tensor<5x7xf32> + // CHECK: %[[v1:.*]] = "tf.Reshape"(%arg1, %[[cst_1]]) : (tensor<5x2xf32>, tensor<3xi64>) -> tensor<5x2x1xf32> + // CHECK: %[[v2:.*]] = "tf.BatchMatMulV2"(%[[v0]], %[[v1]]) {adj_x = false, adj_y = false} : (tensor<5x7x2xf32>, tensor<5x2x1xf32>) -> tensor<5x7x1xf32> + // CHECK: %[[v3:.*]] = "tf.Reshape"(%[[v2]], %[[cst_2]]) : (tensor<5x7x1xf32>, tensor<2xi64>) -> tensor<5x7xf32> + // CHECK: return %[[v3:.*]] : tensor<5x7xf32> } + func @einsum_transpose_matmul(%arg0: tensor<2x5x7xf32>, %arg1: tensor<5x3x2xf32>) -> tensor<5x3x7xf32> { %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "lbh,bkl->bkh"}: (tensor<2x5x7xf32>, tensor<5x3x2xf32>) -> tensor<5x3x7xf32> return %0 : tensor<5x3x7xf32> @@ -88,12 +105,12 @@ func @einsum_transposereduceddim(%arg0: tensor<2x5x7xf32>, %arg1: tensor<2x5x3x7 %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "bij,binj->bin"}: (tensor<2x5x7xf32>, tensor<2x5x3x7xf32>) -> tensor<2x5x3xf32> return %0 : tensor<2x5x3xf32> // CHECK-LABEL: einsum_transposereduceddim - // CHECK: %[[cst:.*]] = constant dense<[2, 5, 1, 7]> : tensor<4xi64> - // CHECK: %[[cst_1:.*]] = constant dense<[0, 1, 3, 2]> : tensor<4xi32> + // CHECK: %[[cst:.*]] = constant dense<[0, 1, 3, 2]> : tensor<4xi32> + // CHECK: %[[cst_1:.*]] = constant dense<[2, 5, 1, 7]> : tensor<4xi64> // CHECK: %[[cst_2:.*]] = constant dense<[2, 5, 3]> : tensor<3xi64> - // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x5x7xf32>, tensor<4xi64>) -> tensor<2x5x1x7xf32> - // CHECK: %[[v1:.*]] = "tf.Transpose"(%arg1, %[[cst_1]]) : (tensor<2x5x3x7xf32>, tensor<4xi32>) -> tensor<2x5x7x3xf32> - // CHECK: %[[v2:.*]] = "tf.BatchMatMulV2"(%[[v0]], %[[v1]]) {adj_x = false, adj_y = false} : (tensor<2x5x1x7xf32>, tensor<2x5x7x3xf32>) -> tensor<2x5x1x3xf32> + // CHECK: %[[v0:.*]] = "tf.Transpose"(%arg1, %[[cst]]) : (tensor<2x5x3x7xf32>, tensor<4xi32>) -> tensor<2x5x7x3xf32> + // CHECK: %[[v1:.*]] = "tf.Reshape"(%arg0, %[[cst_1]]) : (tensor<2x5x7xf32>, tensor<4xi64>) -> tensor<2x5x1x7xf32> + // CHECK: %[[v2:.*]] = "tf.BatchMatMulV2"(%[[v1]], %[[v0]]) {adj_x = false, adj_y = false} : (tensor<2x5x1x7xf32>, tensor<2x5x7x3xf32>) -> tensor<2x5x1x3xf32> // CHECK: %[[v3:.*]] = "tf.Reshape"(%[[v2]], %[[cst_2]]) : (tensor<2x5x1x3xf32>, tensor<3xi64>) -> tensor<2x5x3xf32> // CHECK: return %[[v3]] : tensor<2x5x3xf32> } @@ -123,13 +140,26 @@ func @einsum_fourdtransposeall(%arg0: tensor<2x5x7x3xf32>, %arg1: tensor<2x11x7x // CHECK: return %[[v3]] : tensor<2x7x11x5xf32> } -func @einsum_no_match(%arg0: tensor<4x5xf32>, %arg1: tensor<5xf32>) -> tensor<4xf32> { - %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ij,j->i"}: (tensor<4x5xf32>, tensor<5xf32>) -> tensor<4xf32> +func @einsum_4d_1(%arg0: tensor<3x4x5x6xf32>, %arg1: tensor<3x7x5x6xf32>) -> 
tensor<3x5x4x7xf32> { + %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "jbki,jfki->jkbf"}: (tensor<3x4x5x6xf32>, tensor<3x7x5x6xf32>) -> tensor<3x5x4x7xf32> + return %0 : tensor<3x5x4x7xf32> + // CHECK-LABEL: einsum_4d_1 + // CHECK: %[[cst:.*]] = constant dense<[0, 2, 1, 3]> : tensor<4xi32> + // CHECK: %[[cst_1:.*]] = constant dense<[0, 2, 3, 1]> : tensor<4xi32> + // CHECK: %[[v0:.*]] = "tf.Transpose"(%arg0, %[[cst:.*]]) : (tensor<3x4x5x6xf32>, tensor<4xi32>) -> tensor<3x5x4x6xf32> + // CHECK: %[[v1:.*]] = "tf.Transpose"(%arg1, %[[cst_1]]) : (tensor<3x7x5x6xf32>, tensor<4xi32>) -> tensor<3x5x6x7xf32> + // CHECK: %[[v2:.*]] = "tf.BatchMatMulV2"(%[[v0]], %[[v1]]) {adj_x = false, adj_y = false} : (tensor<3x5x4x6xf32>, tensor<3x5x6x7xf32>) -> tensor<3x5x4x7xf32> + // CHECK: return %[[v2]] : tensor<3x5x4x7xf32> +} + +func @einsum_no_match(%arg0: tensor<4x5x6xf32>, %arg1: tensor<5xf32>) -> tensor<4xf32> { + %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ijk,j->i"}: (tensor<4x5x6xf32>, tensor<5xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> // CHECK-LABEL: einsum_no_match -// CHECK: %[[v0:.*]] = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ij,j->i"} : (tensor<4x5xf32>, tensor<5xf32>) -> tensor<4xf32> +// CHECK: %[[v0:.*]] = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ijk,j->i"} : (tensor<4x5x6xf32>, tensor<5xf32>) -> tensor<4xf32> // CHECK: return %[[v0]] } + func @einsum_illegal_no_match(%arg0: tensor<4x5xf32>, %arg1: tensor<5xf32>) -> tensor<4xf32> { %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ij,?zw->kq->i"}: (tensor<4x5xf32>, tensor<5xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> @@ -137,10 +167,15 @@ func @einsum_illegal_no_match(%arg0: tensor<4x5xf32>, %arg1: tensor<5xf32>) -> t // CHECK: %[[v0:.*]] = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ij,?zw->kq->i"} : (tensor<4x5xf32>, tensor<5xf32>) -> tensor<4xf32> // CHECK: return %[[v0]] } -func @einsum_no_match5D(%arg0: tensor<4x5xf32>, %arg1: tensor<2x4x7x3x5xf32>) -> tensor<4xf32> { - %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ij,j->i"}: (tensor<4x5xf32>, tensor<2x4x7x3x5xf32>) -> tensor<4xf32> - return %0 : tensor<4xf32> -// CHECK-LABEL: einsum_no_match5D -// CHECK: %[[v0:.*]] = "tf.Einsum" -// CHECK: return %[[v0]] + +func @batch_multilhs_einsum(%arg0: tensor<2x1x1x11xf32>, %arg1: tensor<2x11x2xf32>) -> tensor<2x1x1x2xf32> { + %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "BiNj,BjS->BiNS"} : (tensor<2x1x1x11xf32>, tensor<2x11x2xf32>) -> tensor<2x1x1x2xf32> + return %0 : tensor<2x1x1x2xf32> +// CHECK-LABEL: batch_multilhs_einsum +// CHECK: %[[cst:.*]] = constant dense<[2, 1, 11]> : tensor<3xi64> +// CHECK: %[[cst_1:.*]] = constant dense<[2, 1, 1, 2]> : tensor<4xi64> +// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x1x1x11xf32>, tensor<3xi64>) -> tensor<2x1x11xf32> +// CHECK: %[[v1:.*]] = "tf.BatchMatMulV2"(%[[v0]], %arg1) {adj_x = false, adj_y = false} : (tensor<2x1x11xf32>, tensor<2x11x2xf32>) -> tensor<2x1x2xf32> +// CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<2x1x2xf32>, tensor<4xi64>) -> tensor<2x1x1x2xf32> +// CHECK: return %[[v2]] : tensor<2x1x1x2xf32> } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_island_coarsening.mlir b/tensorflow/compiler/mlir/tensorflow/tests/executor_island_coarsening.mlir index bec48181b3b..726495f1fbc 100644 --- 
a/tensorflow/compiler/mlir/tensorflow/tests/executor_island_coarsening.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/executor_island_coarsening.mlir @@ -220,7 +220,7 @@ func @merge_islands_only() { %11:2 = tf_executor.island(%10#1) wraps "tf.opF"() : () -> tensor %12:2 = tf_executor.island wraps "tf.opG"(%10#0, %11#0) : (tensor<*xi32>, tensor) -> tensor<*xi32> %13 = tf_executor.ControlTrigger %2, %12#1, %9#1 - tf_executor.NextIteration.Sink [%3#1] %12#0, %13 : tensor<*xi32> + tf_executor.NextIteration.Sink[%3#1] %12#0, %13 : tensor<*xi32> tf_executor.fetch } return @@ -244,7 +244,7 @@ func @merge_islands_only() { // CHECK-NEXT: %[[OP_G:[0-9]*]] = "tf.opG"(%[[OP_E]], %[[OP_F]]) // CHECK-NEXT: tf_executor.yield %[[OP_G]] : tensor<*xi32> // CHECK: %[[CT:.*]] = tf_executor.ControlTrigger %[[ISLAND_1]], %[[ISLAND_3_control]], %[[EXIT_control]] -// CHECK-NEXT: tf_executor.NextIteration.Sink [%[[NEXTIT_SRC_token]]] %[[ISLAND_3]], %[[CT]] +// CHECK-NEXT: tf_executor.NextIteration.Sink[%[[NEXTIT_SRC_token]]] %[[ISLAND_3]], %[[CT]] // Test no merging took place as cycle would be formed otherwise. diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/BUILD index 1544d27009f..81cb0ed7c73 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow:tensorflow.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") licenses(["notice"]) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/force_shared_name_for_resource_ops.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/force_shared_name_for_resource_ops.pbtxt new file mode 100644 index 00000000000..05302ed430c --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/force_shared_name_for_resource_ops.pbtxt @@ -0,0 +1,95 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-upgrade-legacy %s -tf-output-arrays=hash_table_node -o - | FileCheck %s + +node: { + name: "hash_table_node" + op: "HashTableV2" + attr: { + key: "key_dtype" + value: { + type: DT_INT32 + } + } + attr: { + key: "shared_name" + value: { + s: "" + } + } + attr: { + key: "value_dtype" + value: { + type: DT_FLOAT + } + } +} +node { + name: "Call" + op: "PartitionedCall" + attr { + key: "Tin" + value { + list { + } + } + } + attr { + key: "Tout" + value { + list { + type: DT_RESOURCE + } + } + } + attr { + key: "f" + value { + func { + name: "create_resource" + } + } + } +} +library { + function { + signature { + name: "create_resource" + output_arg { + name: "handle" + type: DT_RESOURCE + } + } + node_def: { + name: "hash_table_node" + op: "HashTableV2" + attr: { + key: "key_dtype" + value: { + type: DT_INT32 + } + } + attr: { + key: "shared_name" + value: { + s: "" + } + } + attr: { + key: "value_dtype" + value: { + type: DT_FLOAT + } + } + } + ret { + key: "handle" + value: "hash_table_node:table_handle:0" + } + } +} + +# CHECK: tf.HashTableV2 +# CHECK-SAME: shared_name = "hash_table_node" + +# CHECK: func @create_resource +# CHECK: tf.HashTableV2 +# CHECK-SAME: shared_name = "hash_table_node@create_resource" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt index e21fd901a9e..a6b1979ee26 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt +++ 
b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt @@ -7,7 +7,7 @@ # CHECK: %[[NEXTITERATION:[a-z0-9]+]], %[[NEXTITERATION_token:[a-z0-9]+]], {{.*}} = tf_executor.NextIteration.Source # CHECK: tf_executor.Merge {{.*}} %[[NEXTITERATION]] -# CHECK: tf_executor.NextIteration.Sink [%[[NEXTITERATION_token]]] +# CHECK: tf_executor.NextIteration.Sink[%[[NEXTITERATION_token]]] node { name: "Const" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir index 30599b2e437..9bb05a75877 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir @@ -7,7 +7,7 @@ // CHECK-LABEL: func @transposeConv2D func @transposeConv2D(%input: tensor<1x32x32x3xf32>, %filter: tensor<1x1x3x8xf32>) -> tensor<1x32x32x8xf32> { - // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} + // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]]) // CHECK: %[[CONV2D:[0-9]*]] = "tf.Conv2D"(%[[ARG_TRANSPOSE]], %arg1) @@ -18,7 +18,7 @@ func @transposeConv2D(%input: tensor<1x32x32x3xf32>, %filter: tensor<1x1x3x8xf32 // CHECK-SAME: strides = [5, 8, 6, 7] // CHECK-SAME: (tensor<1x3x32x32xf32>, tensor<1x1x3x8xf32>) -> tensor<1x8x32x32xf32> - // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} + // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[CONV2D]], %[[RES_PERM]]) // CHECK: return %[[RES_TRANSPOSE]] @@ -38,7 +38,7 @@ func @transposeConv2D(%input: tensor<1x32x32x3xf32>, %filter: tensor<1x1x3x8xf32 func @transposeConv2DWithDefaultAttr(%input: tensor<1x32x32x3xf32>, %filter: tensor<1x1x3x8xf32>) -> tensor<*xf32> { - // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} + // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]]) // CHECK: %[[CONV2D:[0-9]*]] = "tf.Conv2D"(%[[ARG_TRANSPOSE]], %arg1) @@ -49,7 +49,7 @@ func @transposeConv2DWithDefaultAttr(%input: tensor<1x32x32x3xf32>, %filter: ten // CHECK-SAME: strides = [5, 8, 6, 7] // CHECK-SAME: (tensor<1x3x32x32xf32>, tensor<1x1x3x8xf32>) -> tensor<*xf32> - // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} + // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[CONV2D]], %[[RES_PERM]]) // CHECK: return %[[RES_TRANSPOSE]] @@ -77,7 +77,7 @@ func @transposeConv2DBackpropFilter( // CHECK-SAME: dst_format = "NCHW" // CHECK-SAME: src_format = "NHWC" - // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} + // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} // CHECK: %[[IN_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]]) // CHECK: %[[OUT_BP_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg2, %[[ARG_PERM]]) @@ -117,7 +117,7 @@ func @transposeConv2DBackpropInput( // CHECK-SAME: dst_format = "NCHW" // CHECK-SAME: src_format = "NHWC" - // CHECK: 
%[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} + // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} // CHECK: %[[OUT_BP_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg2, %[[ARG_PERM]]) // CHECK: %[[CONV2D_BACKPROP:[0-9]*]] = "tf.Conv2DBackpropInput" @@ -130,7 +130,7 @@ func @transposeConv2DBackpropInput( // CHECK-SAME: (tensor<4xi32>, tensor<1x1x3x8xf32>, tensor<1x8x32x32xf32>) // CHECK-SAME: -> tensor<1x3x32x32xf32> - // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} + // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[CONV2D_BACKPROP]], %[[RES_PERM]]) // CHECK: return %[[RES_TRANSPOSE]] @@ -154,7 +154,7 @@ func @transposeFusedBatchNormV3( ) -> tensor<1x28x28x64xf32> { // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() - // CHECK-SAME: {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} + // CHECK-SAME: {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]]) // CHECK: "tf.FusedBatchNormV3" @@ -164,7 +164,7 @@ func @transposeFusedBatchNormV3( // CHECK-SAME: -> (tensor<1x64x28x28xf32>, tensor<64xf32>, // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() - // CHECK-SAME: {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} + // CHECK-SAME: {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%y, %[[RES_PERM]]) // CHECK: return %[[RES_TRANSPOSE]] @@ -192,7 +192,7 @@ func @transposeFusedBatchNormGradV3( ) -> tensor<1x28x28x64xf32> { // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() - // CHECK-SAME: {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} + // CHECK-SAME: {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} // CHECK: %[[ARG0_TPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]]) // CHECK: %[[ARG1_TPOSE:[0-9]*]] = "tf.Transpose"(%arg1, %[[ARG_PERM]]) @@ -204,7 +204,7 @@ func @transposeFusedBatchNormGradV3( // CHECK-SAME: -> (tensor<1x64x28x28xf32>, // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() - // CHECK-SAME: {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} + // CHECK-SAME: {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} // CHECK: %[[RES_TPOSE:[0-9]*]] = "tf.Transpose" // CHECK-SAME: (%x_backprop, %[[RES_PERM]]) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir index e6b3bf08394..c71d8ef2850 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir @@ -7,7 +7,7 @@ // CHECK-LABEL: func @transposeConv2D func @transposeConv2D(%input: tensor<1x3x32x32xf32>, %filter: tensor<1x1x3x8xf32>) -> tensor<1x8x32x32xf32> { - // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} + // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]]) // CHECK: %[[CONV2D:[0-9]*]] = "tf.Conv2D"(%[[ARG_TRANSPOSE]], %arg1) @@ -18,7 +18,7 @@ func @transposeConv2D(%input: tensor<1x3x32x32xf32>, %filter: tensor<1x1x3x8xf32 // CHECK-SAME: strides = [5, 7, 8, 6] // CHECK-SAME: (tensor<1x32x32x3xf32>, tensor<1x1x3x8xf32>) -> tensor<1x32x32x8xf32> - // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 
2]> : tensor<4xi32>} + // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[CONV2D]], %[[RES_PERM]]) // CHECK: return %[[RES_TRANSPOSE]] @@ -41,7 +41,7 @@ func @transposeFusedBatchNormV3( ) -> tensor<1x64x28x28xf32> { // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() - // CHECK-SAME: {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} + // CHECK-SAME: {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]]) // CHECK: "tf.FusedBatchNormV3" @@ -51,7 +51,7 @@ func @transposeFusedBatchNormV3( // CHECK-SAME: -> (tensor<1x28x28x64xf32>, tensor<64xf32>, // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() - // CHECK-SAME: {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} + // CHECK-SAME: {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%y, %[[RES_PERM]]) // CHECK: return %[[RES_TRANSPOSE]] diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_move_transposes_begin.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_move_transposes_begin.mlir index 0b1e27733eb..bacfeea2dc9 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_move_transposes_begin.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_move_transposes_begin.mlir @@ -65,3 +65,40 @@ func @move_with_multiple_uses(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> return %3 : tensor<1x8x4x4xf32> } + +// CHECK-LABEL: move_transpose_handle_broadcast +func @move_transpose_handle_broadcast(%arg0:tensor<8x64xf32>, %arg1:tensor<8x64x64xf32>) -> tensor<512x64xf32> { + %cst = "tf.Const"() {value = dense<3> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32> + %cst_2 = "tf.Const"() {value = dense<[512, 64]> : tensor<2xi32>} : () -> tensor<2xi32> + %0 = "tf.ExpandDims"(%arg0, %cst) {device = ""} : (tensor<8x64xf32>, tensor) -> tensor<8x64x1xf32> + %1 = "tf.AddV2"(%0, %arg1) {device = ""} : (tensor<8x64x1xf32>, tensor<8x64x64xf32>) -> tensor<8x64x64xf32> + %2 = "tf.Transpose"(%1, %cst_1) {device = ""} : (tensor<8x64x64xf32>, tensor<3xi32>) -> tensor<64x8x64xf32> + %3 = "tf.Reshape"(%2, %cst_2) {device = ""} : (tensor<64x8x64xf32>, tensor<2xi32>) -> tensor<512x64xf32> + + return %3 : tensor<512x64xf32> + + // CHECK: %[[CST_0:.*]] = "tf.Const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32> + // CHECK: %[[CST_1:.*]] = "tf.Const"() {value = dense<3> : tensor} : () -> tensor + // CHECK: %[[CST_2:.*]] = "tf.Const"() {value = dense<[512, 64]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK: %[[EXPAND_DIMS:.*]] = "tf.ExpandDims"(%arg0, %[[CST_1]]) {device = ""} : (tensor<8x64xf32>, tensor) -> tensor<8x64x1xf32> + // CHECK: %[[TRANSPOSE_1:.*]] = "tf.Transpose"(%[[EXPAND_DIMS]], %[[CST_0]]) : (tensor<8x64x1xf32>, tensor<3xi32>) -> tensor<1x8x64xf32> + // CHECK: %[[TRANSPOSE_2:.*]] = "tf.Transpose"(%arg1, %[[CST_0]]) : (tensor<8x64x64xf32>, tensor<3xi32>) -> tensor<64x8x64xf32> + // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[TRANSPOSE_1]], %[[TRANSPOSE_2]]) {device = ""} : (tensor<1x8x64xf32>, tensor<64x8x64xf32>) -> tensor<64x8x64xf32> + // CHECK: %[[RESHAPE:.*]] = "tf.Reshape"(%[[ADD]], %[[CST_2]]) {device = ""} : (tensor<64x8x64xf32>, tensor<2xi32>) -> tensor<512x64xf32> + // CHECK: return %[[RESHAPE]] : tensor<512x64xf32> +} + +// CHECK-LABEL: dont_move_transpose_different_ranks +func 
@dont_move_transpose_different_ranks(%arg0:tensor<1x1x2x3xf32>, %arg1:tensor<2x3xf32>) -> tensor<1x2x1x3xf32> { + %cst = "tf.Const"() {value = dense<[0, 2, 1, 3]> : tensor<4xi32>} : () -> tensor<4xi32> + %0 = "tf.AddV2"(%arg0, %arg1) {device = ""} : (tensor<1x1x2x3xf32>, tensor<2x3xf32>) -> tensor<1x1x2x3xf32> + %1 = "tf.Transpose"(%0, %cst) {device = ""} : (tensor<1x1x2x3xf32>, tensor<4xi32>) -> tensor<1x2x1x3xf32> + + return %1 : tensor<1x2x1x3xf32> + + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[0, 2, 1, 3]> : tensor<4xi32>} : () -> tensor<4xi32> + // CHECK: %[[ADD:.*]] = "tf.AddV2"(%arg0, %arg1) {device = ""} : (tensor<1x1x2x3xf32>, tensor<2x3xf32>) -> tensor<1x1x2x3xf32> + // CHECK: %[[TRANSPOSE:.*]] = "tf.Transpose"(%[[ADD]], %[[CST]]) {device = ""} : (tensor<1x1x2x3xf32>, tensor<4xi32>) -> tensor<1x2x1x3xf32> + // CHECK: return %[[TRANSPOSE]] : tensor<1x2x1x3xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir index 4f044cd5eff..cc923070077 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir @@ -1,177 +1,396 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py // RUN: tf-opt -tf-legalize-hlo %s | FileCheck %s +// CHECK-LABEL: func @biasAdd_NHWC( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x32x10x32xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xi32>) -> tensor<1x32x10x32xi32> { +// CHECK: %[[VAL_2:.*]] = "tf.AddV2"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> +// CHECK: return %[[VAL_2]] : tensor<1x32x10x32xi32> +// CHECK: } func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> return %0 : tensor<1x32x10x32xi32> } +// CHECK-LABEL: func @biasAdd_NCHW( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x32x10x32xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xi32>) -> tensor<1x32x10x32xi32> { +// CHECK: %[[VAL_2:.*]] = "tf.AddV2"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> +// CHECK: return %[[VAL_2]] : tensor<1x32x10x32xi32> +// CHECK: } func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> return %0 : tensor<1x32x10x32xi32> } +// CHECK-LABEL: func @biasAdd_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor, +// CHECK-SAME: %[[VAL_1:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_2:.*]] = "tf.AddV2"(%[[VAL_0]], %[[VAL_1]]) : (tensor, tensor) -> tensor +// CHECK: return %[[VAL_2]] : tensor +// CHECK: } func @biasAdd_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor, tensor) -> tensor return %0 : tensor } +// CHECK-LABEL: func @add( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>) -> tensor<2xi32> { +// CHECK: %[[VAL_1:.*]] = "tf.AddV2"(%[[VAL_0]], %[[VAL_0]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: %[[VAL_2:.*]] = "tf.AddV2"(%[[VAL_1]], %[[VAL_0]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: return %[[VAL_2]] : tensor<2xi32> +// CHECK: } func @add(%arg0: 
tensor<2xi32>) -> tensor<2xi32> { %0 = mhlo.add %arg0, %arg0 : tensor<2xi32> %1 = mhlo.add %0, %arg0 : tensor<2xi32> return %1 : tensor<2xi32> } +// CHECK-LABEL: func @broadcast_add( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { +// CHECK: %[[VAL_2:.*]] = "tf.AddV2"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> +// CHECK: return %[[VAL_2]] : tensor<1x2xi32> +// CHECK: } func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } +// CHECK-LABEL: func @broadcast_multi_dim_add( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x1x1xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { +// CHECK: %[[VAL_2:.*]] = "tf.AddV2"(%[[VAL_0]], %[[VAL_1]]) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> +// CHECK: return %[[VAL_2]] : tensor<4x4x4x4xi32> +// CHECK: } func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> return %0 : tensor<4x4x4x4xi32> } +// CHECK-LABEL: func @div( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>) -> tensor<2xi32> { +// CHECK: %[[VAL_1:.*]] = "tf.Div"(%[[VAL_0]], %[[VAL_0]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: return %[[VAL_1]] : tensor<2xi32> +// CHECK: } func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> { %0 = mhlo.divide %arg0, %arg0 : tensor<2xi32> return %0 : tensor<2xi32> } +// CHECK-LABEL: func @broadcast_div( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { +// CHECK: %[[VAL_2:.*]] = "tf.Div"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> +// CHECK: return %[[VAL_2]] : tensor<1x2xi32> +// CHECK: } func @broadcast_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { %0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } +// CHECK-LABEL: func @shift_left( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<4xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<4xi32>) -> tensor<4xi32> { +// CHECK: %[[VAL_2:.*]] = "tf.LeftShift"(%[[VAL_0]], %[[VAL_1]]) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> +// CHECK: return %[[VAL_2]] : tensor<4xi32> +// CHECK: } func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { %0 = mhlo.shift_left %arg0, %arg1 : tensor<4xi32> return %0 : tensor<4xi32> } +// CHECK-LABEL: func @div_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor, +// CHECK-SAME: %[[VAL_1:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_2:.*]] = "tf.Div"(%[[VAL_0]], %[[VAL_1]]) : (tensor, tensor) -> tensor +// CHECK: return %[[VAL_2]] : tensor +// CHECK: } func @div_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { %0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor) -> tensor return %0 : tensor } +// CHECK-LABEL: func @maximum( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<4xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<4xf32>) -> tensor<4xf32> { +// CHECK: %[[VAL_2:.*]] = "tf.Maximum"(%[[VAL_0]], %[[VAL_1]]) : 
(tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> +// CHECK: return %[[VAL_2]] : tensor<4xf32> +// CHECK: } func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { %0 = mhlo.maximum %arg0, %arg1 : tensor<4xf32> return %0 : tensor<4xf32> } +// CHECK-LABEL: func @minimum( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<4xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<4xf32>) -> tensor<4xf32> { +// CHECK: %[[VAL_2:.*]] = "tf.Minimum"(%[[VAL_0]], %[[VAL_1]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> +// CHECK: return %[[VAL_2]] : tensor<4xf32> +// CHECK: } func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { %0 = mhlo.minimum %arg0, %arg1 : tensor<4xf32> return %0 : tensor<4xf32> } +// CHECK-LABEL: func @mul( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>) -> tensor<2xi32> { +// CHECK: %[[VAL_1:.*]] = "tf.Mul"(%[[VAL_0]], %[[VAL_0]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: return %[[VAL_1]] : tensor<2xi32> +// CHECK: } func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> { %0 = mhlo.multiply %arg0, %arg0 : tensor<2xi32> return %0 : tensor<2xi32> } +// CHECK-LABEL: func @broadcast_mul( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { +// CHECK: %[[VAL_2:.*]] = "tf.Mul"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> +// CHECK: return %[[VAL_2]] : tensor<1x2xi32> +// CHECK: } func @broadcast_mul(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { %0 = "chlo.broadcast_multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } +// CHECK-LABEL: func @real_div( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>) -> tensor<2xi32> { +// CHECK: %[[VAL_1:.*]] = "tf.Div"(%[[VAL_0]], %[[VAL_0]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: return %[[VAL_1]] : tensor<2xi32> +// CHECK: } func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> { %0 = mhlo.divide %arg0, %arg0 : tensor<2xi32> return %0 : tensor<2xi32> } +// CHECK-LABEL: func @broadcast_real_div( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { +// CHECK: %[[VAL_2:.*]] = "tf.Div"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> +// CHECK: return %[[VAL_2]] : tensor<1x2xi32> +// CHECK: } func @broadcast_real_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { %0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } +// CHECK-LABEL: func @sub( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>) -> tensor<2xi32> { +// CHECK: %[[VAL_1:.*]] = "tf.Sub"(%[[VAL_0]], %[[VAL_0]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: return %[[VAL_1]] : tensor<2xi32> +// CHECK: } func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> { %0 = mhlo.subtract %arg0, %arg0 : tensor<2xi32> return %0 : tensor<2xi32> } +// CHECK-LABEL: func @broadcast_sub( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { +// CHECK: %[[VAL_2:.*]] = "tf.Sub"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> +// CHECK: return %[[VAL_2]] : tensor<1x2xi32> +// CHECK: } func @broadcast_sub(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { %0 = "chlo.broadcast_subtract"(%arg0, %arg1) 
{broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } +// CHECK-LABEL: func @shift_right( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<4xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<4xi32>) -> tensor<4xi32> { +// CHECK: %[[VAL_2:.*]] = "tf.RightShift"(%[[VAL_0]], %[[VAL_1]]) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> +// CHECK: return %[[VAL_2]] : tensor<4xi32> +// CHECK: } func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { %0 = mhlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32> return %0 : tensor<4xi32> } +// CHECK-LABEL: func @broadcast_shift_right( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<4xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<2x4xi32>) -> tensor<2x4xi32> { +// CHECK: %[[VAL_2:.*]] = "tf.RightShift"(%[[VAL_0]], %[[VAL_1]]) : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> +// CHECK: return %[[VAL_2]] : tensor<2x4xi32> +// CHECK: } func @broadcast_shift_right(%arg0: tensor<4xi32>, %arg1: tensor<2x4xi32>) -> tensor<2x4xi32> { %0 = "chlo.broadcast_shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> return %0 : tensor<2x4xi32> } -func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> { - %0 = mhlo.and %arg0, %arg0 : tensor<2xi1> +// CHECK-LABEL: func @and( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi1>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xi1>) -> tensor<2xi1> { +// CHECK: %[[VAL_2:.*]] = "tf.LogicalAnd"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> +// CHECK: return %[[VAL_2]] : tensor<2xi1> +// CHECK: } +func @and(%arg0: tensor<2xi1>, %arg1: tensor<2xi1>) -> tensor<2xi1> { + %0 = mhlo.and %arg0, %arg1 : tensor<2xi1> return %0 : tensor<2xi1> } +// CHECK-LABEL: func @and_broadcast( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi1>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2xi1>) -> tensor<1x2xi1> { +// CHECK: %[[VAL_2:.*]] = "tf.LogicalAnd"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> +// CHECK: return %[[VAL_2]] : tensor<1x2xi1> +// CHECK: } func @and_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> { %0 = "chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } +// CHECK-LABEL: func @and_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1xi1>) -> tensor { +// CHECK: %[[VAL_2:.*]] = "tf.LogicalAnd"(%[[VAL_0]], %[[VAL_1]]) : (tensor, tensor<1xi1>) -> tensor +// CHECK: return %[[VAL_2]] : tensor +// CHECK: } func @and_dynamic(%arg0: tensor, %arg1: tensor<1xi1>) -> tensor { %0 = "chlo.broadcast_and"(%arg0, %arg1) : (tensor, tensor<1xi1>) -> tensor return %0 : tensor } -func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> { - %0 = mhlo.or %arg0, %arg0 : tensor<2xi1> +// CHECK-LABEL: func @or( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi1>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xi1>) -> tensor<2xi1> { +// CHECK: %[[VAL_2:.*]] = "tf.LogicalOr"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> +// CHECK: return %[[VAL_2]] : tensor<2xi1> +// CHECK: } +func @or(%arg0: tensor<2xi1>, %arg1: tensor<2xi1>) -> tensor<2xi1> { + %0 = mhlo.or %arg0, %arg1 : tensor<2xi1> return %0 : tensor<2xi1> } +// CHECK-LABEL: func @or_broadcast( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi1>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2xi1>) -> tensor<1x2xi1> { +// CHECK: %[[VAL_2:.*]] = 
"tf.LogicalOr"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> +// CHECK: return %[[VAL_2]] : tensor<1x2xi1> +// CHECK: } func @or_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> { %0 = "chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } +// CHECK-LABEL: func @or_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1xi1>) -> tensor { +// CHECK: %[[VAL_2:.*]] = "tf.LogicalOr"(%[[VAL_0]], %[[VAL_1]]) : (tensor, tensor<1xi1>) -> tensor +// CHECK: return %[[VAL_2]] : tensor +// CHECK: } func @or_dynamic(%arg0: tensor, %arg1: tensor<1xi1>) -> tensor { %0 = "chlo.broadcast_or"(%arg0, %arg1) : (tensor, tensor<1xi1>) -> tensor return %0 : tensor } +// CHECK-LABEL: func @bitwise_or( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<4xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<4xi32>) -> tensor<4xi32> { +// CHECK: %[[VAL_2:.*]] = "tf.BitwiseOr"(%[[VAL_0]], %[[VAL_1]]) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> +// CHECK: return %[[VAL_2]] : tensor<4xi32> +// CHECK: } func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { %0 = mhlo.or %arg0, %arg1 : tensor<4xi32> return %0 : tensor<4xi32> } +// CHECK-LABEL: func @bitwise_or_broadcast( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi8>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x4xi8>) -> tensor<1x4xi8> { +// CHECK: %[[VAL_2:.*]] = "tf.BitwiseOr"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> +// CHECK: return %[[VAL_2]] : tensor<1x4xi8> +// CHECK: } func @bitwise_or_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> { %0 = "chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> return %0 : tensor<1x4xi8> } +// CHECK-LABEL: func @bitwise_or_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1xi32>) -> tensor { +// CHECK: %[[VAL_2:.*]] = "tf.BitwiseOr"(%[[VAL_0]], %[[VAL_1]]) : (tensor, tensor<1xi32>) -> tensor +// CHECK: return %[[VAL_2]] : tensor +// CHECK: } func @bitwise_or_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { %0 = "chlo.broadcast_or"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } +// CHECK-LABEL: func @bitwise_and( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<4xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<4xi32>) -> tensor<4xi32> { +// CHECK: %[[VAL_2:.*]] = "tf.BitwiseAnd"(%[[VAL_0]], %[[VAL_1]]) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> +// CHECK: return %[[VAL_2]] : tensor<4xi32> +// CHECK: } func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { %0 = mhlo.and %arg0, %arg1 : tensor<4xi32> return %0 : tensor<4xi32> } +// CHECK-LABEL: func @bitwise_and_broadcast( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi8>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x4xi8>) -> tensor<1x4xi8> { +// CHECK: %[[VAL_2:.*]] = "tf.BitwiseAnd"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> +// CHECK: return %[[VAL_2]] : tensor<1x4xi8> +// CHECK: } func @bitwise_and_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> { %0 = "chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> return %0 : tensor<1x4xi8> } +// CHECK-LABEL: func @bitwise_and_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor, +// CHECK-SAME: %[[VAL_1:.*]]: 
tensor<1xi32>) -> tensor { +// CHECK: %[[VAL_2:.*]] = "tf.BitwiseAnd"(%[[VAL_0]], %[[VAL_1]]) : (tensor, tensor<1xi32>) -> tensor +// CHECK: return %[[VAL_2]] : tensor +// CHECK: } func @bitwise_and_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { %0 = "chlo.broadcast_and"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } +// CHECK-LABEL: func @pow( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Pow"(%[[VAL_0]], %[[VAL_0]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[VAL_1]] : tensor<2xf32> +// CHECK: } func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = mhlo.power %arg0, %arg0 : tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: func @pow_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = "tf.Pow"(%[[VAL_0]], %[[VAL_0]]) : (tensor, tensor) -> tensor +// CHECK: return %[[VAL_1]] : tensor +// CHECK: } func @pow_dynamic(%arg0: tensor) -> tensor { %0 = mhlo.power %arg0, %arg0 : tensor return %0 : tensor } +// CHECK-LABEL: func @floordiv_broadcast_i32( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x3xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<3xi32>) -> tensor<2x3xi32> { +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<0> : tensor<2x3xi32>} : () -> tensor<2x3xi32> +// CHECK: %[[VAL_3:.*]] = "tf.Less"(%[[VAL_0]], %[[VAL_2]]) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> +// CHECK: %[[VAL_4:.*]] = "tf.Const"() {value = dense<0> : tensor<3xi32>} : () -> tensor<3xi32> +// CHECK: %[[VAL_5:.*]] = "tf.Less"(%[[VAL_1]], %[[VAL_4]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> +// CHECK: %[[VAL_6:.*]] = "tf.Equal"(%[[VAL_3]], %[[VAL_5]]) {incompatible_shape_error = true} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1> +// CHECK: %[[VAL_7:.*]] = "tf.Div"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> +// CHECK: %[[VAL_8:.*]] = "tf.Abs"(%[[VAL_0]]) : (tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: %[[VAL_9:.*]] = "tf.Abs"(%[[VAL_1]]) : (tensor<3xi32>) -> tensor<3xi32> +// CHECK: %[[VAL_10:.*]] = "tf.Const"() {value = dense<1> : tensor<3xi32>} : () -> tensor<3xi32> +// CHECK: %[[VAL_11:.*]] = "tf.Sub"(%[[VAL_9]], %[[VAL_10]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> +// CHECK: %[[VAL_12:.*]] = "tf.AddV2"(%[[VAL_8]], %[[VAL_11]]) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> +// CHECK: %[[VAL_13:.*]] = "tf.Neg"(%[[VAL_12]]) : (tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: %[[VAL_14:.*]] = "tf.Abs"(%[[VAL_1]]) : (tensor<3xi32>) -> tensor<3xi32> +// CHECK: %[[VAL_15:.*]] = "tf.Div"(%[[VAL_13]], %[[VAL_14]]) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> +// CHECK: %[[VAL_16:.*]] = "tf.Select"(%[[VAL_6]], %[[VAL_7]], %[[VAL_15]]) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: return %[[VAL_16]] : tensor<2x3xi32> +// CHECK: } func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { %0 = mhlo.constant dense<0> : tensor<2x3xi32> %1 = "chlo.broadcast_compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> @@ -191,6 +410,26 @@ func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> te return %14 : tensor<2x3xi32> } +// CHECK-LABEL: func @floordiv_reverse_broadcast_i32( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<3xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<2x3xi32>) -> tensor<2x3xi32> { +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<0> : 
tensor<3xi32>} : () -> tensor<3xi32> +// CHECK: %[[VAL_3:.*]] = "tf.Less"(%[[VAL_0]], %[[VAL_2]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> +// CHECK: %[[VAL_4:.*]] = "tf.Const"() {value = dense<0> : tensor<2x3xi32>} : () -> tensor<2x3xi32> +// CHECK: %[[VAL_5:.*]] = "tf.Less"(%[[VAL_1]], %[[VAL_4]]) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> +// CHECK: %[[VAL_6:.*]] = "tf.Equal"(%[[VAL_3]], %[[VAL_5]]) {incompatible_shape_error = true} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1> +// CHECK: %[[VAL_7:.*]] = "tf.Div"(%[[VAL_0]], %[[VAL_1]]) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: %[[VAL_8:.*]] = "tf.Abs"(%[[VAL_0]]) : (tensor<3xi32>) -> tensor<3xi32> +// CHECK: %[[VAL_9:.*]] = "tf.Abs"(%[[VAL_1]]) : (tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: %[[VAL_10:.*]] = "tf.Const"() {value = dense<1> : tensor<2x3xi32>} : () -> tensor<2x3xi32> +// CHECK: %[[VAL_11:.*]] = "tf.Sub"(%[[VAL_9]], %[[VAL_10]]) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: %[[VAL_12:.*]] = "tf.AddV2"(%[[VAL_8]], %[[VAL_11]]) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: %[[VAL_13:.*]] = "tf.Neg"(%[[VAL_12]]) : (tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: %[[VAL_14:.*]] = "tf.Abs"(%[[VAL_1]]) : (tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: %[[VAL_15:.*]] = "tf.Div"(%[[VAL_13]], %[[VAL_14]]) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: %[[VAL_16:.*]] = "tf.Select"(%[[VAL_6]], %[[VAL_7]], %[[VAL_15]]) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: return %[[VAL_16]] : tensor<2x3xi32> +// CHECK: } func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> { %0 = mhlo.constant dense<0> : tensor<3xi32> %1 = "mhlo.compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> @@ -210,6 +449,13 @@ func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32 return %14 : tensor<2x3xi32> } +// CHECK-LABEL: func @floordiv_f32( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Div"(%[[VAL_0]], %[[VAL_0]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> +// CHECK: %[[VAL_2:.*]] = "tf.Div"(%[[VAL_0]], %[[VAL_0]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> +// CHECK: %[[VAL_3:.*]] = "tf.FloorDiv"(%[[VAL_0]], %[[VAL_0]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[VAL_3]] : tensor<2xf32> +// CHECK: } func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = mhlo.divide %arg0, %arg0 : tensor<2xf32> %1 = mhlo.divide %arg0, %arg0 : tensor<2xf32> @@ -217,6 +463,14 @@ func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> { return %2 : tensor<2xf32> } +// CHECK-LABEL: func @floordiv_f16_broadcast( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x3xf16>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<3xf16>) -> tensor<2x3xf16> { +// CHECK: %[[VAL_2:.*]] = "tf.Div"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> +// CHECK: %[[VAL_3:.*]] = "tf.Div"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> +// CHECK: %[[VAL_4:.*]] = "tf.FloorDiv"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> +// CHECK: return %[[VAL_4]] : tensor<2x3xf16> +// CHECK: } func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> { %0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : 
(tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> %1 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> @@ -224,118 +478,258 @@ func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> te return %2 : tensor<2x3xf16> } -func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +// CHECK-LABEL: func @equal( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xi32>) -> tensor<2xi1> { +// CHECK: %[[VAL_2:.*]] = "tf.Equal"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +// CHECK: return %[[VAL_2]] : tensor<2xi1> +// CHECK: } +func @equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { + %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0 : tensor<2xi1> } +// CHECK-LABEL: func @equal_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1xi32>) -> tensor { +// CHECK: %[[VAL_2:.*]] = "tf.Equal"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor, tensor<1xi32>) -> tensor +// CHECK: return %[[VAL_2]] : tensor +// CHECK: } func @equal_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { %0 = "chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } +// CHECK-LABEL: func @equal_broadcast( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { +// CHECK: %[[VAL_2:.*]] = "tf.Equal"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> +// CHECK: return %[[VAL_2]] : tensor<1x2xi1> +// CHECK: } func @equal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } +// CHECK-LABEL: func @equal_broadcast_no_incompatible_shapes_error( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { +// CHECK: %[[VAL_2:.*]] = "tf.Equal"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> +// CHECK: return %[[VAL_2]] : tensor<1x2xi1> +// CHECK: } func @equal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } +// CHECK-LABEL: func @equal_incompatible_shape_broadcastable( +// CHECK-SAME: %[[VAL_0:.*]]: tensor, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1xi32>) -> tensor { +// CHECK: %[[VAL_2:.*]] = "tf.Equal"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor, tensor<1xi32>) -> tensor +// CHECK: return %[[VAL_2]] : tensor +// CHECK: } func @equal_incompatible_shape_broadcastable(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { %0 = "chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } -func @notequal(%arg0: tensor<2xi32>) 
-> tensor<2xi1> { - %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +// CHECK-LABEL: func @notequal( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xi32>) -> tensor<2xi1> { +// CHECK: %[[VAL_2:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +// CHECK: return %[[VAL_2]] : tensor<2xi1> +// CHECK: } +func @notequal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { + %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0 : tensor<2xi1> } +// CHECK-LABEL: func @notequal_broadcast( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { +// CHECK: %[[VAL_2:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> +// CHECK: return %[[VAL_2]] : tensor<1x2xi1> +// CHECK: } func @notequal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } +// CHECK-LABEL: func @notequal_broadcast_no_incompatible_shapes_error( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { +// CHECK: %[[VAL_2:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> +// CHECK: return %[[VAL_2]] : tensor<1x2xi1> +// CHECK: } func @notequal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } +// CHECK-LABEL: func @notequal_incompatible_shape_broadcastable( +// CHECK-SAME: %[[VAL_0:.*]]: tensor, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1xi32>) -> tensor { +// CHECK: %[[VAL_2:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor, tensor<1xi32>) -> tensor +// CHECK: return %[[VAL_2]] : tensor +// CHECK: } func @notequal_incompatible_shape_broadcastable(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { %0 = "chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } -func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> { - %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +// CHECK-LABEL: func @greater( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xi32>) -> tensor<2xi1> { +// CHECK: %[[VAL_2:.*]] = "tf.Greater"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +// CHECK: return %[[VAL_2]] : tensor<2xi1> +// CHECK: } +func @greater(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { + %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0 : tensor<2xi1> } +// CHECK-LABEL: func @broadcast_greater( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { +// CHECK: %[[VAL_2:.*]] = 
"tf.Greater"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> +// CHECK: return %[[VAL_2]] : tensor<1x2xi1> +// CHECK: } func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } -func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +// CHECK-LABEL: func @greater_equal( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xi32>) -> tensor<2xi1> { +// CHECK: %[[VAL_2:.*]] = "tf.GreaterEqual"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +// CHECK: return %[[VAL_2]] : tensor<2xi1> +// CHECK: } +func @greater_equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { + %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0 : tensor<2xi1> } +// CHECK-LABEL: func @broadcast_greater_equal( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { +// CHECK: %[[VAL_2:.*]] = "tf.GreaterEqual"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> +// CHECK: return %[[VAL_2]] : tensor<1x2xi1> +// CHECK: } func @broadcast_greater_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } -func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> { - %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +// CHECK-LABEL: func @less( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xi32>) -> tensor<2xi1> { +// CHECK: %[[VAL_2:.*]] = "tf.Less"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +// CHECK: return %[[VAL_2]] : tensor<2xi1> +// CHECK: } +func @less(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { + %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0 : tensor<2xi1> } +// CHECK-LABEL: func @broadcast_less( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { +// CHECK: %[[VAL_2:.*]] = "tf.Less"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> +// CHECK: return %[[VAL_2]] : tensor<1x2xi1> +// CHECK: } func @broadcast_less(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } -func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +// CHECK-LABEL: func @less_equal( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xi32>) -> tensor<2xi1> { +// CHECK: %[[VAL_2:.*]] = "tf.LessEqual"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> 
+// CHECK: return %[[VAL_2]] : tensor<2xi1> +// CHECK: } +func @less_equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { + %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0 : tensor<2xi1> } +// CHECK-LABEL: func @broadcast_less_equal( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { +// CHECK: %[[VAL_2:.*]] = "tf.LessEqual"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> +// CHECK: return %[[VAL_2]] : tensor<1x2xi1> +// CHECK: } func @broadcast_less_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } +// CHECK-LABEL: func @concat_v2( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x3xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x3xf32>) -> tensor<6x3xf32> { +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: %[[VAL_3:.*]] = "tf.ConcatV2"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor) -> tensor<6x3xf32> +// CHECK: return %[[VAL_3]] : tensor<6x3xf32> +// CHECK: } func @concat_v2(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> { %2 = "mhlo.concatenate"(%arg0, %arg1) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32> return %2 : tensor<6x3xf32> } +// CHECK-LABEL: func @concat_v2_1d_axis( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x3xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x3xf32>) -> tensor<3x6xf32> { +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<1> : tensor} : () -> tensor +// CHECK: %[[VAL_3:.*]] = "tf.ConcatV2"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor) -> tensor<3x6xf32> +// CHECK: return %[[VAL_3]] : tensor<3x6xf32> +// CHECK: } func @concat_v2_1d_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x6xf32> { %2 = "mhlo.concatenate"(%arg0, %arg1) {dimension = 1 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x6xf32> return %2 : tensor<3x6xf32> } +// CHECK-LABEL: func @const() -> tensor<2xi32> { +// CHECK: %[[VAL_0:.*]] = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> tensor<2xi32> +// CHECK: return %[[VAL_0]] : tensor<2xi32> +// CHECK: } func @const() -> tensor<2xi32> { %0 = mhlo.constant dense<0> : tensor<2xi32> return %0 : tensor<2xi32> } +// CHECK-LABEL: func @relu( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi32>) -> tensor<1xi32> { +// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: %[[VAL_2:.*]] = "tf.Maximum"(%[[VAL_1]], %[[VAL_0]]) : (tensor, tensor<1xi32>) -> tensor<1xi32> +// CHECK: return %[[VAL_2]] : tensor<1xi32> +// CHECK: } func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { %0 = mhlo.constant dense<0> : tensor %1 = "chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xi32>) -> tensor<1xi32> return %1 : tensor<1xi32> } +// CHECK-LABEL: func @relu_unranked( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: %[[VAL_2:.*]] = "tf.Maximum"(%[[VAL_1]], %[[VAL_0]]) : (tensor, tensor) -> tensor +// CHECK: return %[[VAL_2]] : tensor +// CHECK: } func @relu_unranked(%arg0: tensor) -> tensor { %0 = mhlo.constant 
dense<0> : tensor %1 = "chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor return %1 : tensor } +// CHECK-LABEL: func @relu6( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi32>) -> tensor<1xi32> { +// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<6> : tensor} : () -> tensor +// CHECK: %[[VAL_3:.*]] = "tf.Minimum"(%[[VAL_0]], %[[VAL_2]]) : (tensor<1xi32>, tensor) -> tensor<1xi32> +// CHECK: %[[VAL_4:.*]] = "tf.Maximum"(%[[VAL_3]], %[[VAL_1]]) : (tensor<1xi32>, tensor) -> tensor<1xi32> +// CHECK: return %[[VAL_4]] : tensor<1xi32> +// CHECK: } func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> { %0 = mhlo.constant dense<0> : tensor %1 = mhlo.constant dense<6> : tensor @@ -344,6 +738,14 @@ func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> { return %3 : tensor<1xi32> } +// CHECK-LABEL: func @relu6_unranked( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<6> : tensor} : () -> tensor +// CHECK: %[[VAL_3:.*]] = "tf.Minimum"(%[[VAL_0]], %[[VAL_2]]) : (tensor, tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = "tf.Maximum"(%[[VAL_3]], %[[VAL_1]]) : (tensor, tensor) -> tensor +// CHECK: return %[[VAL_4]] : tensor +// CHECK: } func @relu6_unranked(%arg0: tensor) -> tensor { %0 = mhlo.constant dense<0> : tensor %1 = mhlo.constant dense<6> : tensor @@ -352,6 +754,15 @@ func @relu6_unranked(%arg0: tensor) -> tensor { return %3 : tensor } +// CHECK-LABEL: func @relu_grad( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x8xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor) -> tensor<4x8xf32> { +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor} : () -> tensor +// CHECK: %[[VAL_3:.*]] = "tf.Greater"(%[[VAL_1]], %[[VAL_2]]) : (tensor, tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<4x8xf32>} : () -> tensor<4x8xf32> +// CHECK: %[[VAL_5:.*]] = "tf.Select"(%[[VAL_3]], %[[VAL_0]], %[[VAL_4]]) : (tensor, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> +// CHECK: return %[[VAL_5]] : tensor<4x8xf32> +// CHECK: } func @relu_grad(%arg0: tensor<4x8xf32>, %arg1: tensor) -> tensor<4x8xf32> { %0 = mhlo.constant dense<0.000000e+00> : tensor %1 = "chlo.broadcast_compare"(%arg1, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor, tensor) -> tensor @@ -360,31 +771,74 @@ func @relu_grad(%arg0: tensor<4x8xf32>, %arg1: tensor) -> tensor<4x8xf3 return %3 : tensor<4x8xf32> } +// CHECK-LABEL: func @select( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi1>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xi32>, +// CHECK-SAME: %[[VAL_2:.*]]: tensor<2xi32>) -> tensor<2xi32> { +// CHECK: %[[VAL_3:.*]] = "tf.Select"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: return %[[VAL_3]] : tensor<2xi32> +// CHECK: } func @select(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0 : tensor<2xi32> } +// CHECK-LABEL: func @select_float( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi1>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xf32>, +// CHECK-SAME: %[[VAL_2:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_3:.*]] = "tf.Select"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : 
(tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[VAL_3]] : tensor<2xf32> +// CHECK: } func @select_float(%arg0: tensor<2xi1>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>) -> tensor<2xf32> { %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: func @select_multidimensional( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x2xi1>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x2xi32>, +// CHECK-SAME: %[[VAL_2:.*]]: tensor<3x2xi32>) -> tensor<3x2xi32> { +// CHECK: %[[VAL_3:.*]] = "tf.Select"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<3x2xi1>, tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> +// CHECK: return %[[VAL_3]] : tensor<3x2xi32> +// CHECK: } func @select_multidimensional(%arg0: tensor<3x2xi1>, %arg1: tensor<3x2xi32>, %arg2: tensor<3x2xi32>) -> tensor<3x2xi32> { %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<3x2xi1>, tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> return %0 : tensor<3x2xi32> } +// CHECK-LABEL: func @selectv2( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi1>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xi32>, +// CHECK-SAME: %[[VAL_2:.*]]: tensor<2xi32>) -> tensor<2xi32> { +// CHECK: %[[VAL_3:.*]] = "tf.Select"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: return %[[VAL_3]] : tensor<2xi32> +// CHECK: } func @selectv2(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0 : tensor<2xi32> } +// CHECK-LABEL: func @selectv2_pred_scalar( +// CHECK-SAME: %[[VAL_0:.*]]: tensor, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xi32>, +// CHECK-SAME: %[[VAL_2:.*]]: tensor<2xi32>) -> tensor<2xi32> { +// CHECK: %[[VAL_3:.*]] = "tf.Select"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: return %[[VAL_3]] : tensor<2xi32> +// CHECK: } func @selectv2_pred_scalar(%arg0: tensor, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0 : tensor<2xi32> } +// CHECK-LABEL: func @transpose_2d( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x3xf32>) -> tensor<3x2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: %[[VAL_4:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_3]]) : (tensor<2x3xf32>, tensor<2xi64>) -> tensor<3x2xf32> +// CHECK: return %[[VAL_4]] : tensor<3x2xf32> +// CHECK: } func @transpose_2d(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { %0 = mhlo.constant dense<[1, 0]> : tensor<2xi64> %1 = mhlo.constant dense<[1, 0]> : tensor<2xi64> @@ -392,6 +846,14 @@ func @transpose_2d(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { return %2 : tensor<3x2xf32> } +// CHECK-LABEL: func @transpose_3d_int32( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi32>} : () -> tensor<3xi32> +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> tensor<3xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () 
-> tensor<3xi64> +// CHECK: %[[VAL_4:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_3]]) : (tensor<1x2x3xf32>, tensor<3xi64>) -> tensor<3x2x1xf32> +// CHECK: return %[[VAL_4]] : tensor<3x2x1xf32> +// CHECK: } func @transpose_3d_int32(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { %0 = mhlo.constant dense<[2, 1, 0]> : tensor<3xi32> %1 = mhlo.constant dense<[2, 1, 0]> : tensor<3xi64> @@ -399,6 +861,14 @@ func @transpose_3d_int32(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { return %2 : tensor<3x2x1xf32> } +// CHECK-LABEL: func @transpose_3d( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> tensor<3xi64> +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> tensor<3xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> tensor<3xi64> +// CHECK: %[[VAL_4:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_3]]) : (tensor<1x2x3xf32>, tensor<3xi64>) -> tensor<3x2x1xf32> +// CHECK: return %[[VAL_4]] : tensor<3x2x1xf32> +// CHECK: } func @transpose_3d(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { %0 = mhlo.constant dense<[2, 1, 0]> : tensor<3xi64> %1 = mhlo.constant dense<[2, 1, 0]> : tensor<3xi64> @@ -406,6 +876,14 @@ func @transpose_3d(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { return %2 : tensor<3x2x1xf32> } +// CHECK-LABEL: func @transpose_dynamic_2d( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor<4x?xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: %[[VAL_4:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_3]]) : (tensor, tensor<2xi64>) -> tensor<4x?xf32> +// CHECK: return %[[VAL_4]] : tensor<4x?xf32> +// CHECK: } func @transpose_dynamic_2d(%arg0: tensor) -> tensor<4x?xf32> { %0 = mhlo.constant dense<[1, 0]> : tensor<2xi64> %1 = mhlo.constant dense<[1, 0]> : tensor<2xi64> @@ -413,6 +891,14 @@ func @transpose_dynamic_2d(%arg0: tensor) -> tensor<4x?xf32> { return %2 : tensor<4x?xf32> } +// CHECK-LABEL: func @transpose_unranked_2d( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: %[[VAL_4:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_3]]) : (tensor<*xf32>, tensor<2xi64>) -> tensor<*xf32> +// CHECK: return %[[VAL_4]] : tensor<*xf32> +// CHECK: } func @transpose_unranked_2d(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = mhlo.constant dense<[1, 0]> : tensor<2xi64> %1 = mhlo.constant dense<[1, 0]> : tensor<2xi64> @@ -420,146 +906,297 @@ func @transpose_unranked_2d(%arg0: tensor<*xf32>) -> tensor<*xf32> { return %2 : tensor<*xf32> } +// CHECK-LABEL: func @abs( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Abs"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[VAL_1]] : tensor<2xf32> +// CHECK: } func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "mhlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: func @abs_dynamic( 
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?xf32>) -> tensor<?xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Abs"(%[[VAL_0]]) : (tensor<?xf32>) -> tensor<?xf32> +// CHECK: return %[[VAL_1]] : tensor<?xf32> +// CHECK: } func @abs_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { %0 = "mhlo.abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> return %0 : tensor<?xf32> } +// CHECK-LABEL: func @abs_unranked( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Abs"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return %[[VAL_1]] : tensor<*xf32> +// CHECK: } func @abs_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "mhlo.abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } +// CHECK-LABEL: func @ceil( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Ceil"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[VAL_1]] : tensor<2xf32> +// CHECK: } func @ceil(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "mhlo.ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: func @ceil_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<?xf32>) -> tensor<?xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Ceil"(%[[VAL_0]]) : (tensor<?xf32>) -> tensor<?xf32> +// CHECK: return %[[VAL_1]] : tensor<?xf32> +// CHECK: } func @ceil_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { %0 = "mhlo.ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> return %0 : tensor<?xf32> } +// CHECK-LABEL: func @ceil_unranked( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Ceil"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return %[[VAL_1]] : tensor<*xf32> +// CHECK: } func @ceil_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "mhlo.ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } +// CHECK-LABEL: func @complex_abs( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xcomplex<f32>>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.ComplexAbs"(%[[VAL_0]]) : (tensor<2xcomplex<f32>>) -> tensor<2xf32> +// CHECK: return %[[VAL_1]] : tensor<2xf32> +// CHECK: } func @complex_abs(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xf32> { %0 = "mhlo.abs"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: func @cos( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Cos"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[VAL_1]] : tensor<2xf32> +// CHECK: } func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "mhlo.cosine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: func @cos_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<?xf32>) -> tensor<?xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Cos"(%[[VAL_0]]) : (tensor<?xf32>) -> tensor<?xf32> +// CHECK: return %[[VAL_1]] : tensor<?xf32> +// CHECK: } func @cos_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { %0 = "mhlo.cosine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> return %0 : tensor<?xf32> } +// CHECK-LABEL: func @cos_unranked( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Cos"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return %[[VAL_1]] : tensor<*xf32> +// CHECK: } func @cos_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "mhlo.cosine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } +// CHECK-LABEL: func @exp( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Exp"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[VAL_1]] : tensor<2xf32> +// CHECK: } func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 =
"mhlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: func @exp_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = "tf.Exp"(%[[VAL_0]]) : (tensor) -> tensor +// CHECK: return %[[VAL_1]] : tensor +// CHECK: } func @exp_dynamic(%arg0: tensor) -> tensor { %0 = "mhlo.exponential"(%arg0) : (tensor) -> tensor return %0 : tensor } +// CHECK-LABEL: func @exp_unranked( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Exp"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return %[[VAL_1]] : tensor<*xf32> +// CHECK: } func @exp_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "mhlo.exponential"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } +// CHECK-LABEL: func @floor( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Floor"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[VAL_1]] : tensor<2xf32> +// CHECK: } func @floor(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "mhlo.floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: func @floor_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = "tf.Floor"(%[[VAL_0]]) : (tensor) -> tensor +// CHECK: return %[[VAL_1]] : tensor +// CHECK: } func @floor_dynamic(%arg0: tensor) -> tensor { %0 = "mhlo.floor"(%arg0) : (tensor) -> tensor return %0 : tensor } +// CHECK-LABEL: func @floor_unranked( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Floor"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return %[[VAL_1]] : tensor<*xf32> +// CHECK: } func @floor_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "mhlo.floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } +// CHECK-LABEL: func @is_finite( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xi1> { +// CHECK: %[[VAL_1:.*]] = "tf.IsFinite"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xi1> +// CHECK: return %[[VAL_1]] : tensor<2xi1> +// CHECK: } func @is_finite(%arg0: tensor<2xf32>) -> tensor<2xi1> { %0 = "mhlo.is_finite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1> return %0 : tensor<2xi1> } +// CHECK-LABEL: func @is_finite_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = "tf.IsFinite"(%[[VAL_0]]) : (tensor) -> tensor +// CHECK: return %[[VAL_1]] : tensor +// CHECK: } func @is_finite_dynamic(%arg0: tensor) -> tensor { %0 = "mhlo.is_finite"(%arg0) : (tensor) -> tensor return %0 : tensor } +// CHECK-LABEL: func @is_finite_unranked( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<*xf32>) -> tensor<*xi1> { +// CHECK: %[[VAL_1:.*]] = "tf.IsFinite"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xi1> +// CHECK: return %[[VAL_1]] : tensor<*xi1> +// CHECK: } func @is_finite_unranked(%arg0: tensor<*xf32>) -> tensor<*xi1> { %0 = "mhlo.is_finite"(%arg0) : (tensor<*xf32>) -> tensor<*xi1> return %0 : tensor<*xi1> } +// CHECK-LABEL: func @log( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Log"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[VAL_1]] : tensor<2xf32> +// CHECK: } func @log(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "mhlo.log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: func @log_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = "tf.Log"(%[[VAL_0]]) : 
(tensor) -> tensor +// CHECK: return %[[VAL_1]] : tensor +// CHECK: } func @log_dynamic(%arg0: tensor) -> tensor { %0 = "mhlo.log"(%arg0) : (tensor) -> tensor return %0 : tensor } +// CHECK-LABEL: func @log_unranked( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Log"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return %[[VAL_1]] : tensor<*xf32> +// CHECK: } func @log_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "mhlo.log"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } +// CHECK-LABEL: func @log1p( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Log1p"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[VAL_1]] : tensor<2xf32> +// CHECK: } func @log1p(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "mhlo.log_plus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: func @log1p_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = "tf.Log1p"(%[[VAL_0]]) : (tensor) -> tensor +// CHECK: return %[[VAL_1]] : tensor +// CHECK: } func @log1p_dynamic(%arg0: tensor) -> tensor { %0 = "mhlo.log_plus_one"(%arg0) : (tensor) -> tensor return %0 : tensor } +// CHECK-LABEL: func @log1p_unranked( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Log1p"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return %[[VAL_1]] : tensor<*xf32> +// CHECK: } func @log1p_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "mhlo.log_plus_one"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } +// CHECK-LABEL: func @neg( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Neg"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[VAL_1]] : tensor<2xf32> +// CHECK: } func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "mhlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: func @neg_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = "tf.Neg"(%[[VAL_0]]) : (tensor) -> tensor +// CHECK: return %[[VAL_1]] : tensor +// CHECK: } func @neg_dynamic(%arg0: tensor) -> tensor { %0 = "mhlo.negate"(%arg0) : (tensor) -> tensor return %0 : tensor } +// CHECK-LABEL: func @neg_unranked( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Neg"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return %[[VAL_1]] : tensor<*xf32> +// CHECK: } func @neg_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "mhlo.negate"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } +// CHECK-LABEL: func @sigmoid( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<5.000000e-01> : tensor} : () -> tensor +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<2> : tensor<1xi64>} : () -> tensor<1xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Const"() {value = dense<5.000000e-01> : tensor<2xf32>} : () -> tensor<2xf32> +// CHECK: %[[VAL_4:.*]] = "tf.Mul"(%[[VAL_0]], %[[VAL_3]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> +// CHECK: %[[VAL_5:.*]] = "tf.Tanh"(%[[VAL_4]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: %[[VAL_6:.*]] = "tf.Mul"(%[[VAL_5]], %[[VAL_3]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> +// CHECK: %[[VAL_7:.*]] = "tf.AddV2"(%[[VAL_6]], %[[VAL_3]]) : 
(tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[VAL_7]] : tensor<2xf32> +// CHECK: } func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = mhlo.constant dense<5.000000e-01> : tensor %1 = mhlo.constant dense<2> : tensor<1xi64> @@ -571,90 +1208,182 @@ func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> { return %6 : tensor<2xf32> } +// CHECK-LABEL: func @sin( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Sin"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[VAL_1]] : tensor<2xf32> +// CHECK: } func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "mhlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: func @sin_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = "tf.Sin"(%[[VAL_0]]) : (tensor) -> tensor +// CHECK: return %[[VAL_1]] : tensor +// CHECK: } func @sin_dynamic(%arg0: tensor) -> tensor { %0 = "mhlo.sine"(%arg0) : (tensor) -> tensor return %0 : tensor } +// CHECK-LABEL: func @sin_unranked( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Sin"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return %[[VAL_1]] : tensor<*xf32> +// CHECK: } func @sin_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "mhlo.sine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } +// CHECK-LABEL: func @rsqrt( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Rsqrt"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[VAL_1]] : tensor<2xf32> +// CHECK: } func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "mhlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: func @rsqrt_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = "tf.Rsqrt"(%[[VAL_0]]) : (tensor) -> tensor +// CHECK: return %[[VAL_1]] : tensor +// CHECK: } func @rsqrt_dynamic(%arg0: tensor) -> tensor { %0 = "mhlo.rsqrt"(%arg0) : (tensor) -> tensor return %0 : tensor } +// CHECK-LABEL: func @rsqrt_unranked( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Rsqrt"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return %[[VAL_1]] : tensor<*xf32> +// CHECK: } func @rsqrt_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "mhlo.rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } +// CHECK-LABEL: func @sqrt( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Sqrt"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[VAL_1]] : tensor<2xf32> +// CHECK: } func @sqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "mhlo.sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: func @sqrt_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = "tf.Sqrt"(%[[VAL_0]]) : (tensor) -> tensor +// CHECK: return %[[VAL_1]] : tensor +// CHECK: } func @sqrt_dynamic(%arg0: tensor) -> tensor { %0 = "mhlo.sqrt"(%arg0) : (tensor) -> tensor return %0 : tensor } +// CHECK-LABEL: func @sqrt_unranked( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Sqrt"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return %[[VAL_1]] : tensor<*xf32> +// CHECK: } func @sqrt_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = 
"mhlo.sqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } +// CHECK-LABEL: func @tanh( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Tanh"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[VAL_1]] : tensor<2xf32> +// CHECK: } func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "mhlo.tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: func @tanh_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = "tf.Tanh"(%[[VAL_0]]) : (tensor) -> tensor +// CHECK: return %[[VAL_1]] : tensor +// CHECK: } func @tanh_dynamic(%arg0: tensor) -> tensor { %0 = "mhlo.tanh"(%arg0) : (tensor) -> tensor return %0 : tensor } +// CHECK-LABEL: func @tanh_unranked( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Tanh"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return %[[VAL_1]] : tensor<*xf32> +// CHECK: } func @tanh_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "mhlo.tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } +// CHECK-LABEL: func @bitcast( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Bitcast"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[VAL_1]] : tensor<2xf32> +// CHECK: } func @bitcast(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: func @bitcast_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = "tf.Bitcast"(%[[VAL_0]]) : (tensor) -> tensor +// CHECK: return %[[VAL_1]] : tensor +// CHECK: } func @bitcast_dynamic(%arg0: tensor) -> tensor { %0 = "mhlo.bitcast_convert"(%arg0) : (tensor) -> tensor return %0 : tensor } +// CHECK-LABEL: func @bitcast_unranked( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Bitcast"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return %[[VAL_1]] : tensor<*xf32> +// CHECK: } func @bitcast_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "mhlo.bitcast_convert"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } +// CHECK-LABEL: func @bitcast_same_widths( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xi32> { +// CHECK: %[[VAL_1:.*]] = "tf.Bitcast"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xi32> +// CHECK: return %[[VAL_1]] : tensor<2xi32> +// CHECK: } func @bitcast_same_widths(%arg0: tensor<2xf32>) -> tensor<2xi32> { %0 = "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xi32> return %0 : tensor<2xi32> } -func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { - %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1> +// CHECK-LABEL: func @sign( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3x4xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { +// CHECK: %[[VAL_2:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1> +// CHECK: %[[VAL_3:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1x2x3x4xf32>} : () -> tensor<1x2x3x4xf32> +// CHECK: %[[VAL_4:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1> +// 
CHECK: %[[VAL_5:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1x2x3x4xf32>} : () -> tensor<1x2x3x4xf32> +// CHECK: %[[VAL_6:.*]] = "tf.Sign"(%[[VAL_0]]) : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> +// CHECK: %[[VAL_7:.*]] = "tf.Select"(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]]) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> +// CHECK: %[[VAL_8:.*]] = "tf.Select"(%[[VAL_2]], %[[VAL_3]], %[[VAL_7]]) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> +// CHECK: return %[[VAL_8]] : tensor<1x2x3x4xf32> +// CHECK: } +func @sign(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { + %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1> %1 = mhlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32> - %2 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1> + %2 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1> %3 = mhlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32> %4 = "mhlo.sign"(%arg0) : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> %5 = "mhlo.select"(%2, %3, %4) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> @@ -662,72 +1391,180 @@ func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { return %6 : tensor<1x2x3x4xf32> } +// CHECK-LABEL: func @size_rank_one_i32( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<f32>) -> tensor<i32> { +// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32> +// CHECK: return %[[VAL_1]] : tensor<i32> +// CHECK: } func @size_rank_one_i32(%arg0: tensor<f32>) -> tensor<i32> { %0 = mhlo.constant dense<1> : tensor<i32> return %0 : tensor<i32> } +// CHECK-LABEL: func @size_rank_one_i64( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<f32>) -> tensor<i64> { +// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64> +// CHECK: return %[[VAL_1]] : tensor<i64> +// CHECK: } func @size_rank_one_i64(%arg0: tensor<f32>) -> tensor<i64> { %0 = mhlo.constant dense<1> : tensor<i64> return %0 : tensor<i64> } +// CHECK-LABEL: func @complex( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<3xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<3xf32>) -> tensor<3xcomplex<f32>> { +// CHECK: %[[VAL_2:.*]] = "tf.Complex"(%[[VAL_0]], %[[VAL_1]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex<f32>> +// CHECK: return %[[VAL_2]] : tensor<3xcomplex<f32>> +// CHECK: } func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex<f32>> { %0 = "mhlo.complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex<f32>> return %0 : tensor<3xcomplex<f32>> } +// CHECK-LABEL: func @convert_i32_f32( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Cast"(%[[VAL_0]]) {Truncate = false} : (tensor<2xi32>) -> tensor<2xf32> +// CHECK: return %[[VAL_1]] : tensor<2xf32> +// CHECK: } func @convert_i32_f32(%arg0: tensor<2xi32>) -> tensor<2xf32> { %0 = "mhlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: func @convert_slice( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x4672xf32>) -> tensor<1x519xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<[0, 4153]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<[1, 519]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Slice"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) :
(tensor<1x4672xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x519xf32> +// CHECK: return %[[VAL_3]] : tensor<1x519xf32> +// CHECK: } func @convert_slice(%arg0: tensor<1x4672xf32>) -> tensor<1x519xf32> { %0 = "mhlo.slice"(%arg0) {limit_indices = dense<[1, 4672]> : tensor<2xi64>, start_indices = dense<[0, 4153]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x4672xf32>) -> tensor<1x519xf32> return %0 : tensor<1x519xf32> } +// CHECK-LABEL: func @reshape( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x6xf32>) -> tensor<2x2x6xf32> { +// CHECK: %[[VAL_1:.*]] = constant dense<[2, 2, 6]> : tensor<3xi64> +// CHECK: %[[VAL_2:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_1]]) : (tensor<4x6xf32>, tensor<3xi64>) -> tensor<2x2x6xf32> +// CHECK: return %[[VAL_2]] : tensor<2x2x6xf32> +// CHECK: } func @reshape(%arg0: tensor<4x6xf32>) -> tensor<2x2x6xf32> { %0 = "mhlo.reshape"(%arg0) : (tensor<4x6xf32>) -> tensor<2x2x6xf32> return %0 : tensor<2x2x6xf32> } +// CHECK-LABEL: func @convert_dot_1d_2d( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<256xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<256x1xf32>) -> tensor<1xf32> { +// CHECK: %[[VAL_2:.*]] = constant dense<[1, 256]> : tensor<2xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> +// CHECK: %[[VAL_4:.*]] = "tf.MatMul"(%[[VAL_3]], %[[VAL_1]]) {transpose_a = false, transpose_b = false} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_5:.*]] = constant dense<1> : tensor<1xi64> +// CHECK: %[[VAL_6:.*]] = "tf.Reshape"(%[[VAL_4]], %[[VAL_5]]) : (tensor<1x1xf32>, tensor<1xi64>) -> tensor<1xf32> +// CHECK: return %[[VAL_6]] : tensor<1xf32> +// CHECK: } func @convert_dot_1d_2d(%arg0: tensor<256xf32>, %arg1: tensor<256x1xf32>) -> tensor<1xf32> { %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<256xf32>, tensor<256x1xf32>) -> tensor<1xf32> return %0 : tensor<1xf32> } +// CHECK-LABEL: func @convert_dot_2d_1d( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x256xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<256xf32>) -> tensor<1xf32> { +// CHECK: %[[VAL_2:.*]] = constant dense<[1, 256]> : tensor<2xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_1]], %[[VAL_2]]) : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> +// CHECK: %[[VAL_4:.*]] = "tf.MatMul"(%[[VAL_0]], %[[VAL_3]]) {transpose_a = false, transpose_b = true} : (tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_5:.*]] = constant dense<1> : tensor<1xi64> +// CHECK: %[[VAL_6:.*]] = "tf.Reshape"(%[[VAL_4]], %[[VAL_5]]) : (tensor<1x1xf32>, tensor<1xi64>) -> tensor<1xf32> +// CHECK: return %[[VAL_6]] : tensor<1xf32> +// CHECK: } func @convert_dot_2d_1d(%arg0: tensor<1x256xf32>, %arg1: tensor<256xf32>) -> tensor<1xf32> { %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x256xf32>, tensor<256xf32>) -> tensor<1xf32> return %0 : tensor<1xf32> } +// CHECK-LABEL: func @convert_dot_1d_1d( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<256xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<256xf32>) -> tensor { +// CHECK: %[[VAL_2:.*]] = constant dense<[1, 256]> : tensor<2xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> +// CHECK: %[[VAL_4:.*]] = constant dense<[1, 256]> : tensor<2xi64> +// CHECK: %[[VAL_5:.*]] = "tf.Reshape"(%[[VAL_1]], %[[VAL_4]]) : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> +// CHECK: %[[VAL_6:.*]] = "tf.MatMul"(%[[VAL_3]], %[[VAL_5]]) 
{transpose_a = false, transpose_b = true} : (tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_7:.*]] = constant dense<> : tensor<0xi64> +// CHECK: %[[VAL_8:.*]] = "tf.Reshape"(%[[VAL_6]], %[[VAL_7]]) : (tensor<1x1xf32>, tensor<0xi64>) -> tensor +// CHECK: return %[[VAL_8]] : tensor +// CHECK: } func @convert_dot_1d_1d(%arg0: tensor<256xf32>, %arg1: tensor<256xf32>) -> tensor { %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<256xf32>, tensor<256xf32>) -> tensor return %0 : tensor } +// CHECK-LABEL: func @convert_dot_2d_2d( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x256xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<256x1xf32>) -> tensor<1x1xf32> { +// CHECK: %[[VAL_2:.*]] = "tf.MatMul"(%[[VAL_0]], %[[VAL_1]]) {transpose_a = false, transpose_b = false} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32> +// CHECK: return %[[VAL_2]] : tensor<1x1xf32> +// CHECK: } func @convert_dot_2d_2d(%arg0: tensor<1x256xf32>, %arg1: tensor<256x1xf32>) -> tensor<1x1xf32> { %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32> return %0 : tensor<1x1xf32> } +// CHECK-LABEL: func @broadcast_in_dim_tf_style( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x1x16xf32>) -> tensor<3x8x8x16xf32> { +// CHECK: %[[VAL_1:.*]] = constant dense<[3, 8, 8, 16]> : tensor<4xi64> +// CHECK: %[[VAL_2:.*]] = "tf.BroadcastTo"(%[[VAL_0]], %[[VAL_1]]) : (tensor<8x1x16xf32>, tensor<4xi64>) -> tensor<3x8x8x16xf32> +// CHECK: return %[[VAL_2]] : tensor<3x8x8x16xf32> +// CHECK: } func @broadcast_in_dim_tf_style(%arg0: tensor<8x1x16xf32>) -> tensor<3x8x8x16xf32> { %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>, name = "broadcast.0"} : (tensor<8x1x16xf32>) -> tensor<3x8x8x16xf32> return %0 : tensor<3x8x8x16xf32> } +// CHECK-LABEL: func @broadcast_in_dim_general_case( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32> { +// CHECK: %[[VAL_1:.*]] = constant dense<[3, 1, 1, 16]> : tensor<4xi64> +// CHECK: %[[VAL_2:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_1]]) : (tensor<3x1x16xf32>, tensor<4xi64>) -> tensor<3x1x1x16xf32> +// CHECK: %[[VAL_3:.*]] = constant dense<[3, 8, 8, 16]> : tensor<4xi64> +// CHECK: %[[VAL_4:.*]] = "tf.BroadcastTo"(%[[VAL_2]], %[[VAL_3]]) : (tensor<3x1x1x16xf32>, tensor<4xi64>) -> tensor<3x8x8x16xf32> +// CHECK: return %[[VAL_4]] : tensor<3x8x8x16xf32> +// CHECK: } func @broadcast_in_dim_general_case(%arg0: tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32> { %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 2, 3]> : tensor<3xi64>, name = "broadcast.0"} : (tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32> return %0 : tensor<3x8x8x16xf32> } +// CHECK-LABEL: func @convert_dot_general( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x2x6x5x1xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32> { +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<[0, 3, 4, 1, 2]> : tensor<5xi64>} : () -> tensor<5xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_2]]) : (tensor<3x2x6x5x1xf32>, tensor<5xi64>) -> tensor<3x5x1x2x6xf32> +// CHECK: %[[VAL_4:.*]] = "tf.Const"() {value = dense<[0, 1, 3, 2]> : tensor<4xi64>} : () -> tensor<4xi64> +// CHECK: %[[VAL_5:.*]] = "tf.Transpose"(%[[VAL_1]], %[[VAL_4]]) : (tensor<3x2x4x6xf32>, tensor<4xi64>) -> tensor<3x2x6x4xf32> +// CHECK: %[[VAL_6:.*]] = constant dense<[3, 5, 12]> : tensor<3xi64> +// CHECK: %[[VAL_7:.*]] = "tf.Reshape"(%[[VAL_3]], %[[VAL_6]]) : 
(tensor<3x5x1x2x6xf32>, tensor<3xi64>) -> tensor<3x5x12xf32> +// CHECK: %[[VAL_8:.*]] = constant dense<[3, 12, 4]> : tensor<3xi64> +// CHECK: %[[VAL_9:.*]] = "tf.Reshape"(%[[VAL_5]], %[[VAL_8]]) : (tensor<3x2x6x4xf32>, tensor<3xi64>) -> tensor<3x12x4xf32> +// CHECK: %[[VAL_10:.*]] = "tf.BatchMatMulV2"(%[[VAL_7]], %[[VAL_9]]) {adj_x = false, adj_y = false} : (tensor<3x5x12xf32>, tensor<3x12x4xf32>) -> tensor<3x5x4xf32> +// CHECK: %[[VAL_11:.*]] = constant dense<[3, 5, 1, 4]> : tensor<4xi64> +// CHECK: %[[VAL_12:.*]] = "tf.Reshape"(%[[VAL_10]], %[[VAL_11]]) : (tensor<3x5x4xf32>, tensor<4xi64>) -> tensor<3x5x1x4xf32> +// CHECK: return %[[VAL_12]] : tensor<3x5x1x4xf32> +// CHECK: } func @convert_dot_general(%arg0: tensor<3x2x6x5x1xf32>, %arg1: tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32> { %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<[1, 2]> : tensor<2xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<[1, 3]> : tensor<2xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<3x2x6x5x1xf32>, tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32> return %0 : tensor<3x5x1x4xf32> } +// CHECK-LABEL: func @convert_conv2d( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x8x8x207xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { +// CHECK: %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> +// CHECK: return %[[VAL_2]] : tensor<1x8x8x16xf32> +// CHECK: } func @convert_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, @@ -736,6 +1573,12 @@ func @convert_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32> return %0 : tensor<1x8x8x16xf32> } +// CHECK-LABEL: func @convert_depthwise_conv2d( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x8x8x207xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { +// CHECK: %[[VAL_2:.*]] = "tf.DepthwiseConv2dNative"(%[[VAL_0]], %[[VAL_1]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> +// CHECK: return %[[VAL_2]] : tensor<1x8x8x16xf32> +// CHECK: } func @convert_depthwise_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, 
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, @@ -744,6 +1587,12 @@ func @convert_depthwise_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x2 return %0 : tensor<1x8x8x16xf32> } +// CHECK-LABEL: func @convert_conv2d_valid_padding( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x8x8x207xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { +// CHECK: %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> +// CHECK: return %[[VAL_2]] : tensor<1x8x8x16xf32> +// CHECK: } func @convert_conv2d_valid_padding(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, @@ -752,6 +1601,13 @@ func @convert_conv2d_valid_padding(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3 return %0 : tensor<1x8x8x16xf32> } +// CHECK-LABEL: func @convert_reduce_to_sum( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x256xf32>) -> tensor<1xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor} : () -> tensor +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Sum"(%[[VAL_0]], %[[VAL_2]]) {keep_dims = false} : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32> +// CHECK: return %[[VAL_3]] : tensor<1xf32> +// CHECK: } func @convert_reduce_to_sum(%arg0: tensor<1x256xf32>) -> tensor<1xf32> { %0 = mhlo.constant dense<0.000000e+00> : tensor %1 = "mhlo.reduce"(%arg0, %0) ( { @@ -762,6 +1618,13 @@ func @convert_reduce_to_sum(%arg0: tensor<1x256xf32>) -> tensor<1xf32> { return %1 : tensor<1xf32> } +// CHECK-LABEL: func @convert_reduce_to_max( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x256xf32>) -> tensor<1xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<0xFF800000> : tensor} : () -> tensor +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Max"(%[[VAL_0]], %[[VAL_2]]) {keep_dims = false} : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32> +// CHECK: return %[[VAL_3]] : tensor<1xf32> +// CHECK: } func @convert_reduce_to_max(%arg0: tensor<1x256xf32>) -> tensor<1xf32> { // "0xFF800000" represents -INF for f32. 
%0 = mhlo.constant dense<0xFF800000> : tensor<f32> @@ -773,7 +1636,13 @@ func @convert_reduce_to_max(%arg0: tensor<1x256xf32>) -> tensor<1xf32> { return %1 : tensor<1xf32> } - +// CHECK-LABEL: func @convert_reduce_to_min( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x256xf32>) -> tensor<1xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<0x7F800000> : tensor<f32>} : () -> tensor<f32> +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Min"(%[[VAL_0]], %[[VAL_2]]) {keep_dims = false} : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32> +// CHECK: return %[[VAL_3]] : tensor<1xf32> +// CHECK: } func @convert_reduce_to_min(%arg0: tensor<1x256xf32>) -> tensor<1xf32> { // "0x7F800000" represents INF for f32. %0 = mhlo.constant dense<0x7F800000> : tensor<f32> @@ -785,928 +1654,31 @@ func @convert_reduce_to_min(%arg0: tensor<1x256xf32>) -> tensor<1xf32> { return %1 : tensor<1xf32> } +// CHECK-LABEL: func @convert_iota_1d() -> tensor<123xf32> { +// CHECK: %[[VAL_0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32> +// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<1.230000e+02> : tensor<f32>} : () -> tensor<f32> +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32> +// CHECK: %[[VAL_3:.*]] = "tf.Range"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<123xf32> +// CHECK: return %[[VAL_3]] : tensor<123xf32> +// CHECK: } +func @convert_iota_1d() -> tensor<123xf32> { + %0 = "mhlo.iota"() { iota_dimension = 0 : i64 } : () -> tensor<123xf32> + return %0 : tensor<123xf32> +} + +// CHECK-LABEL: func @convert_iota_3d() -> tensor<5x7x9xi32> { +// CHECK: %[[VAL_0:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32> +// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<7> : tensor<i32>} : () -> tensor<i32> +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32> +// CHECK: %[[VAL_3:.*]] = "tf.Range"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<7xi32> +// CHECK: %[[VAL_4:.*]] = "tf.Const"() {value = dense<[1, 7, 1]> : tensor<3xi64>} : () -> tensor<3xi64> +// CHECK: %[[VAL_5:.*]] = "tf.Reshape"(%[[VAL_3]], %[[VAL_4]]) : (tensor<7xi32>, tensor<3xi64>) -> tensor<1x7x1xi32> +// CHECK: %[[VAL_6:.*]] = "tf.Const"() {value = dense<[5, 7, 9]> : tensor<3xi64>} : () -> tensor<3xi64> +// CHECK: %[[VAL_7:.*]] = "tf.BroadcastTo"(%[[VAL_5]], %[[VAL_6]]) : (tensor<1x7x1xi32>, tensor<3xi64>) -> tensor<5x7x9xi32> +// CHECK: return %[[VAL_7]] : tensor<5x7x9xi32> +// CHECK: } +func @convert_iota_3d() -> tensor<5x7x9xi32> { + %0 = "mhlo.iota"() { iota_dimension = 1 : i64 } : () -> tensor<5x7x9xi32> + return %0 : tensor<5x7x9xi32> +} - - -// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py - -// CHECK-LABEL: func @biasAdd_NHWC( -// CHECK-SAME: [[VAL_0:%.*]]: tensor<1x32x10x32xi32>, [[VAL_1:%.*]]: tensor<32xi32>) -> tensor<1x32x10x32xi32> { -// CHECK: [[VAL_2:%.*]] = "tf.AddV2"([[VAL_0]], [[VAL_1]]) : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> -// CHECK: return [[VAL_2]] : tensor<1x32x10x32xi32> -// CHECK: } - -// CHECK-LABEL: func @biasAdd_NCHW( -// CHECK-SAME: [[VAL_3:%.*]]: tensor<1x32x10x32xi32>, [[VAL_4:%.*]]: tensor<32xi32>) -> tensor<1x32x10x32xi32> { -// CHECK: [[VAL_5:%.*]] = "tf.AddV2"([[VAL_3]], [[VAL_4]]) : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> -// CHECK: return [[VAL_5]] : tensor<1x32x10x32xi32> -// CHECK: } - -// CHECK-LABEL: func @biasAdd_dynamic(
-// CHECK-SAME: [[VAL_6:%.*]]: tensor, [[VAL_7:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_8:%.*]] = "tf.AddV2"([[VAL_6]], [[VAL_7]]) : (tensor, tensor) -> tensor -// CHECK: return [[VAL_8]] : tensor -// CHECK: } - -// CHECK-LABEL: func @add( -// CHECK-SAME: [[VAL_9:%.*]]: tensor<2xi32>) -> tensor<2xi32> { -// CHECK: [[VAL_10:%.*]] = "tf.AddV2"([[VAL_9]], [[VAL_9]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK: [[VAL_11:%.*]] = "tf.AddV2"([[VAL_10]], [[VAL_9]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK: return [[VAL_11]] : tensor<2xi32> -// CHECK: } - -// CHECK-LABEL: func @broadcast_add( -// CHECK-SAME: [[VAL_12:%.*]]: tensor<1xi32>, [[VAL_13:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { -// CHECK: [[VAL_14:%.*]] = "tf.AddV2"([[VAL_12]], [[VAL_13]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> -// CHECK: return [[VAL_14]] : tensor<1x2xi32> -// CHECK: } - -// CHECK-LABEL: func @broadcast_multi_dim_add( -// CHECK-SAME: [[VAL_15:%.*]]: tensor<4x1x1xi32>, [[VAL_16:%.*]]: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { -// CHECK: [[VAL_17:%.*]] = "tf.AddV2"([[VAL_15]], [[VAL_16]]) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> -// CHECK: return [[VAL_17]] : tensor<4x4x4x4xi32> -// CHECK: } - -// CHECK-LABEL: func @div( -// CHECK-SAME: [[VAL_18:%.*]]: tensor<2xi32>) -> tensor<2xi32> { -// CHECK: [[VAL_19:%.*]] = "tf.Div"([[VAL_18]], [[VAL_18]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK: return [[VAL_19]] : tensor<2xi32> -// CHECK: } - -// CHECK-LABEL: func @broadcast_div( -// CHECK-SAME: [[VAL_20:%.*]]: tensor<1xi32>, [[VAL_21:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { -// CHECK: [[VAL_22:%.*]] = "tf.Div"([[VAL_20]], [[VAL_21]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> -// CHECK: return [[VAL_22]] : tensor<1x2xi32> -// CHECK: } - -// CHECK-LABEL: func @shift_left( -// CHECK-SAME: [[VAL_23:%.*]]: tensor<4xi32>, [[VAL_24:%.*]]: tensor<4xi32>) -> tensor<4xi32> { -// CHECK: [[VAL_25:%.*]] = "tf.LeftShift"([[VAL_23]], [[VAL_24]]) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> -// CHECK: return [[VAL_25]] : tensor<4xi32> -// CHECK: } - -// CHECK-LABEL: func @div_dynamic( -// CHECK-SAME: [[VAL_26:%.*]]: tensor, [[VAL_27:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_28:%.*]] = "tf.Div"([[VAL_26]], [[VAL_27]]) : (tensor, tensor) -> tensor -// CHECK: return [[VAL_28]] : tensor -// CHECK: } - -// CHECK-LABEL: func @maximum( -// CHECK-SAME: [[VAL_29:%.*]]: tensor<4xf32>, [[VAL_30:%.*]]: tensor<4xf32>) -> tensor<4xf32> { -// CHECK: [[VAL_31:%.*]] = "tf.Maximum"([[VAL_29]], [[VAL_30]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> -// CHECK: return [[VAL_31]] : tensor<4xf32> -// CHECK: } - -// CHECK-LABEL: func @minimum( -// CHECK-SAME: [[VAL_32:%.*]]: tensor<4xf32>, [[VAL_33:%.*]]: tensor<4xf32>) -> tensor<4xf32> { -// CHECK: [[VAL_34:%.*]] = "tf.Minimum"([[VAL_32]], [[VAL_33]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> -// CHECK: return [[VAL_34]] : tensor<4xf32> -// CHECK: } - -// CHECK-LABEL: func @mul( -// CHECK-SAME: [[VAL_35:%.*]]: tensor<2xi32>) -> tensor<2xi32> { -// CHECK: [[VAL_36:%.*]] = "tf.Mul"([[VAL_35]], [[VAL_35]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK: return [[VAL_36]] : tensor<2xi32> -// CHECK: } - -// CHECK-LABEL: func @broadcast_mul( -// CHECK-SAME: [[VAL_37:%.*]]: tensor<1xi32>, [[VAL_38:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { -// CHECK: [[VAL_39:%.*]] = "tf.Mul"([[VAL_37]], [[VAL_38]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> -// CHECK: return 
[[VAL_39]] : tensor<1x2xi32> -// CHECK: } - -// CHECK-LABEL: func @real_div( -// CHECK-SAME: [[VAL_40:%.*]]: tensor<2xi32>) -> tensor<2xi32> { -// CHECK: [[VAL_41:%.*]] = "tf.Div"([[VAL_40]], [[VAL_40]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK: return [[VAL_41]] : tensor<2xi32> -// CHECK: } - -// CHECK-LABEL: func @broadcast_real_div( -// CHECK-SAME: [[VAL_42:%.*]]: tensor<1xi32>, [[VAL_43:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { -// CHECK: [[VAL_44:%.*]] = "tf.Div"([[VAL_42]], [[VAL_43]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> -// CHECK: return [[VAL_44]] : tensor<1x2xi32> -// CHECK: } - -// CHECK-LABEL: func @sub( -// CHECK-SAME: [[VAL_45:%.*]]: tensor<2xi32>) -> tensor<2xi32> { -// CHECK: [[VAL_46:%.*]] = "tf.Sub"([[VAL_45]], [[VAL_45]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK: return [[VAL_46]] : tensor<2xi32> -// CHECK: } - -// CHECK-LABEL: func @broadcast_sub( -// CHECK-SAME: [[VAL_47:%.*]]: tensor<1xi32>, [[VAL_48:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { -// CHECK: [[VAL_49:%.*]] = "tf.Sub"([[VAL_47]], [[VAL_48]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> -// CHECK: return [[VAL_49]] : tensor<1x2xi32> -// CHECK: } - -// CHECK-LABEL: func @shift_right( -// CHECK-SAME: [[VAL_50:%.*]]: tensor<4xi32>, [[VAL_51:%.*]]: tensor<4xi32>) -> tensor<4xi32> { -// CHECK: [[VAL_52:%.*]] = "tf.RightShift"([[VAL_50]], [[VAL_51]]) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> -// CHECK: return [[VAL_52]] : tensor<4xi32> -// CHECK: } - -// CHECK-LABEL: func @broadcast_shift_right( -// CHECK-SAME: [[VAL_53:%.*]]: tensor<4xi32>, [[VAL_54:%.*]]: tensor<2x4xi32>) -> tensor<2x4xi32> { -// CHECK: [[VAL_55:%.*]] = "tf.RightShift"([[VAL_53]], [[VAL_54]]) : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> -// CHECK: return [[VAL_55]] : tensor<2x4xi32> -// CHECK: } - -// CHECK-LABEL: func @and( -// CHECK-SAME: [[VAL_56:%.*]]: tensor<2xi1>) -> tensor<2xi1> { -// CHECK: [[VAL_57:%.*]] = "tf.LogicalAnd"([[VAL_56]], [[VAL_56]]) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> -// CHECK: return [[VAL_57]] : tensor<2xi1> -// CHECK: } - -// CHECK-LABEL: func @and_broadcast( -// CHECK-SAME: [[VAL_58:%.*]]: tensor<1xi1>, [[VAL_59:%.*]]: tensor<1x2xi1>) -> tensor<1x2xi1> { -// CHECK: [[VAL_60:%.*]] = "tf.LogicalAnd"([[VAL_58]], [[VAL_59]]) : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> -// CHECK: return [[VAL_60]] : tensor<1x2xi1> -// CHECK: } - -// CHECK-LABEL: func @and_dynamic( -// CHECK-SAME: [[VAL_61:%.*]]: tensor, [[VAL_62:%.*]]: tensor<1xi1>) -> tensor { -// CHECK: [[VAL_63:%.*]] = "tf.LogicalAnd"([[VAL_61]], [[VAL_62]]) : (tensor, tensor<1xi1>) -> tensor -// CHECK: return [[VAL_63]] : tensor -// CHECK: } - -// CHECK-LABEL: func @or( -// CHECK-SAME: [[VAL_64:%.*]]: tensor<2xi1>) -> tensor<2xi1> { -// CHECK: [[VAL_65:%.*]] = "tf.LogicalOr"([[VAL_64]], [[VAL_64]]) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> -// CHECK: return [[VAL_65]] : tensor<2xi1> -// CHECK: } - -// CHECK-LABEL: func @or_broadcast( -// CHECK-SAME: [[VAL_66:%.*]]: tensor<1xi1>, [[VAL_67:%.*]]: tensor<1x2xi1>) -> tensor<1x2xi1> { -// CHECK: [[VAL_68:%.*]] = "tf.LogicalOr"([[VAL_66]], [[VAL_67]]) : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> -// CHECK: return [[VAL_68]] : tensor<1x2xi1> -// CHECK: } - -// CHECK-LABEL: func @or_dynamic( -// CHECK-SAME: [[VAL_69:%.*]]: tensor, [[VAL_70:%.*]]: tensor<1xi1>) -> tensor { -// CHECK: [[VAL_71:%.*]] = "tf.LogicalOr"([[VAL_69]], [[VAL_70]]) : (tensor, tensor<1xi1>) -> tensor -// CHECK: return [[VAL_71]] : tensor -// 
CHECK: } - -// CHECK-LABEL: func @bitwise_or( -// CHECK-SAME: [[VAL_72:%.*]]: tensor<4xi32>, [[VAL_73:%.*]]: tensor<4xi32>) -> tensor<4xi32> { -// CHECK: [[VAL_74:%.*]] = "tf.BitwiseOr"([[VAL_72]], [[VAL_73]]) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> -// CHECK: return [[VAL_74]] : tensor<4xi32> -// CHECK: } - -// CHECK-LABEL: func @bitwise_or_broadcast( -// CHECK-SAME: [[VAL_75:%.*]]: tensor<1xi8>, [[VAL_76:%.*]]: tensor<1x4xi8>) -> tensor<1x4xi8> { -// CHECK: [[VAL_77:%.*]] = "tf.BitwiseOr"([[VAL_75]], [[VAL_76]]) : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> -// CHECK: return [[VAL_77]] : tensor<1x4xi8> -// CHECK: } - -// CHECK-LABEL: func @bitwise_or_dynamic( -// CHECK-SAME: [[VAL_78:%.*]]: tensor, [[VAL_79:%.*]]: tensor<1xi32>) -> tensor { -// CHECK: [[VAL_80:%.*]] = "tf.BitwiseOr"([[VAL_78]], [[VAL_79]]) : (tensor, tensor<1xi32>) -> tensor -// CHECK: return [[VAL_80]] : tensor -// CHECK: } - -// CHECK-LABEL: func @bitwise_and( -// CHECK-SAME: [[VAL_81:%.*]]: tensor<4xi32>, [[VAL_82:%.*]]: tensor<4xi32>) -> tensor<4xi32> { -// CHECK: [[VAL_83:%.*]] = "tf.BitwiseAnd"([[VAL_81]], [[VAL_82]]) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> -// CHECK: return [[VAL_83]] : tensor<4xi32> -// CHECK: } - -// CHECK-LABEL: func @bitwise_and_broadcast( -// CHECK-SAME: [[VAL_84:%.*]]: tensor<1xi8>, [[VAL_85:%.*]]: tensor<1x4xi8>) -> tensor<1x4xi8> { -// CHECK: [[VAL_86:%.*]] = "tf.BitwiseAnd"([[VAL_84]], [[VAL_85]]) : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> -// CHECK: return [[VAL_86]] : tensor<1x4xi8> -// CHECK: } - -// CHECK-LABEL: func @bitwise_and_dynamic( -// CHECK-SAME: [[VAL_87:%.*]]: tensor, [[VAL_88:%.*]]: tensor<1xi32>) -> tensor { -// CHECK: [[VAL_89:%.*]] = "tf.BitwiseAnd"([[VAL_87]], [[VAL_88]]) : (tensor, tensor<1xi32>) -> tensor -// CHECK: return [[VAL_89]] : tensor -// CHECK: } - -// CHECK-LABEL: func @pow( -// CHECK-SAME: [[VAL_90:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_91:%.*]] = "tf.Pow"([[VAL_90]], [[VAL_90]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_91]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @pow_dynamic( -// CHECK-SAME: [[VAL_92:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_93:%.*]] = "tf.Pow"([[VAL_92]], [[VAL_92]]) : (tensor, tensor) -> tensor -// CHECK: return [[VAL_93]] : tensor -// CHECK: } - -// CHECK-LABEL: func @floordiv_broadcast_i32( -// CHECK-SAME: [[VAL_94:%.*]]: tensor<2x3xi32>, [[VAL_95:%.*]]: tensor<3xi32>) -> tensor<2x3xi32> { -// CHECK: [[VAL_96:%.*]] = "tf.Const"() {value = dense<0> : tensor<2x3xi32>} : () -> tensor<2x3xi32> -// CHECK: [[VAL_97:%.*]] = "tf.Less"([[VAL_94]], [[VAL_96]]) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> -// CHECK: [[VAL_98:%.*]] = "tf.Const"() {value = dense<0> : tensor<3xi32>} : () -> tensor<3xi32> -// CHECK: [[VAL_99:%.*]] = "tf.Less"([[VAL_95]], [[VAL_98]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> -// CHECK: [[VAL_100:%.*]] = "tf.Equal"([[VAL_97]], [[VAL_99]]) {incompatible_shape_error = true} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1> -// CHECK: [[VAL_101:%.*]] = "tf.Div"([[VAL_94]], [[VAL_95]]) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> -// CHECK: [[VAL_102:%.*]] = "tf.Abs"([[VAL_94]]) : (tensor<2x3xi32>) -> tensor<2x3xi32> -// CHECK: [[VAL_103:%.*]] = "tf.Abs"([[VAL_95]]) : (tensor<3xi32>) -> tensor<3xi32> -// CHECK: [[VAL_104:%.*]] = "tf.Const"() {value = dense<1> : tensor<3xi32>} : () -> tensor<3xi32> -// CHECK: [[VAL_105:%.*]] = "tf.Sub"([[VAL_103]], [[VAL_104]]) : (tensor<3xi32>, tensor<3xi32>) -> 
tensor<3xi32> -// CHECK: [[VAL_106:%.*]] = "tf.AddV2"([[VAL_102]], [[VAL_105]]) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> -// CHECK: [[VAL_107:%.*]] = "tf.Neg"([[VAL_106]]) : (tensor<2x3xi32>) -> tensor<2x3xi32> -// CHECK: [[VAL_108:%.*]] = "tf.Abs"([[VAL_95]]) : (tensor<3xi32>) -> tensor<3xi32> -// CHECK: [[VAL_109:%.*]] = "tf.Div"([[VAL_107]], [[VAL_108]]) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> -// CHECK: [[VAL_110:%.*]] = "tf.Select"([[VAL_100]], [[VAL_101]], [[VAL_109]]) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> -// CHECK: return [[VAL_110]] : tensor<2x3xi32> -// CHECK: } - -// CHECK-LABEL: func @floordiv_reverse_broadcast_i32( -// CHECK-SAME: [[VAL_111:%.*]]: tensor<3xi32>, [[VAL_112:%.*]]: tensor<2x3xi32>) -> tensor<2x3xi32> { -// CHECK: [[VAL_113:%.*]] = "tf.Const"() {value = dense<0> : tensor<3xi32>} : () -> tensor<3xi32> -// CHECK: [[VAL_114:%.*]] = "tf.Less"([[VAL_111]], [[VAL_113]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> -// CHECK: [[VAL_115:%.*]] = "tf.Const"() {value = dense<0> : tensor<2x3xi32>} : () -> tensor<2x3xi32> -// CHECK: [[VAL_116:%.*]] = "tf.Less"([[VAL_112]], [[VAL_115]]) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> -// CHECK: [[VAL_117:%.*]] = "tf.Equal"([[VAL_114]], [[VAL_116]]) {incompatible_shape_error = true} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1> -// CHECK: [[VAL_118:%.*]] = "tf.Div"([[VAL_111]], [[VAL_112]]) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> -// CHECK: [[VAL_119:%.*]] = "tf.Abs"([[VAL_111]]) : (tensor<3xi32>) -> tensor<3xi32> -// CHECK: [[VAL_120:%.*]] = "tf.Abs"([[VAL_112]]) : (tensor<2x3xi32>) -> tensor<2x3xi32> -// CHECK: [[VAL_121:%.*]] = "tf.Const"() {value = dense<1> : tensor<2x3xi32>} : () -> tensor<2x3xi32> -// CHECK: [[VAL_122:%.*]] = "tf.Sub"([[VAL_120]], [[VAL_121]]) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> -// CHECK: [[VAL_123:%.*]] = "tf.AddV2"([[VAL_119]], [[VAL_122]]) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> -// CHECK: [[VAL_124:%.*]] = "tf.Neg"([[VAL_123]]) : (tensor<2x3xi32>) -> tensor<2x3xi32> -// CHECK: [[VAL_125:%.*]] = "tf.Abs"([[VAL_112]]) : (tensor<2x3xi32>) -> tensor<2x3xi32> -// CHECK: [[VAL_126:%.*]] = "tf.Div"([[VAL_124]], [[VAL_125]]) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> -// CHECK: [[VAL_127:%.*]] = "tf.Select"([[VAL_117]], [[VAL_118]], [[VAL_126]]) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> -// CHECK: return [[VAL_127]] : tensor<2x3xi32> -// CHECK: } - -// CHECK-LABEL: func @floordiv_f32( -// CHECK-SAME: [[VAL_128:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_129:%.*]] = "tf.Div"([[VAL_128]], [[VAL_128]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> -// CHECK: [[VAL_130:%.*]] = "tf.Div"([[VAL_128]], [[VAL_128]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> -// CHECK: [[VAL_131:%.*]] = "tf.FloorDiv"([[VAL_128]], [[VAL_128]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_131]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @floordiv_f16_broadcast( -// CHECK-SAME: [[VAL_132:%.*]]: tensor<2x3xf16>, [[VAL_133:%.*]]: tensor<3xf16>) -> tensor<2x3xf16> { -// CHECK: [[VAL_134:%.*]] = "tf.Div"([[VAL_132]], [[VAL_133]]) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> -// CHECK: [[VAL_135:%.*]] = "tf.Div"([[VAL_132]], [[VAL_133]]) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> -// CHECK: [[VAL_136:%.*]] = "tf.FloorDiv"([[VAL_132]], [[VAL_133]]) : (tensor<2x3xf16>, tensor<3xf16>) -> 
tensor<2x3xf16> -// CHECK: return [[VAL_136]] : tensor<2x3xf16> -// CHECK: } - -// CHECK-LABEL: func @equal( -// CHECK-SAME: [[VAL_137:%.*]]: tensor<2xi32>) -> tensor<2xi1> { -// CHECK: [[VAL_138:%.*]] = "tf.Equal"([[VAL_137]], [[VAL_137]]) {incompatible_shape_error = true} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> -// CHECK: return [[VAL_138]] : tensor<2xi1> -// CHECK: } - -// CHECK-LABEL: func @equal_dynamic( -// CHECK-SAME: [[VAL_139:%.*]]: tensor, [[VAL_140:%.*]]: tensor<1xi32>) -> tensor { -// CHECK: [[VAL_141:%.*]] = "tf.Equal"([[VAL_139]], [[VAL_140]]) {incompatible_shape_error = true} : (tensor, tensor<1xi32>) -> tensor -// CHECK: return [[VAL_141]] : tensor -// CHECK: } - -// CHECK-LABEL: func @equal_broadcast( -// CHECK-SAME: [[VAL_142:%.*]]: tensor<1xi32>, [[VAL_143:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { -// CHECK: [[VAL_144:%.*]] = "tf.Equal"([[VAL_142]], [[VAL_143]]) {incompatible_shape_error = true} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> -// CHECK: return [[VAL_144]] : tensor<1x2xi1> -// CHECK: } - -// CHECK-LABEL: func @equal_broadcast_no_incompatible_shapes_error( -// CHECK-SAME: [[VAL_145:%.*]]: tensor<2xi32>, [[VAL_146:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { -// CHECK: [[VAL_147:%.*]] = "tf.Equal"([[VAL_145]], [[VAL_146]]) {incompatible_shape_error = true} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> -// CHECK: return [[VAL_147]] : tensor<1x2xi1> -// CHECK: } - -// CHECK-LABEL: func @equal_incompatible_shape_broadcastable( -// CHECK-SAME: [[VAL_148:%.*]]: tensor, [[VAL_149:%.*]]: tensor<1xi32>) -> tensor { -// CHECK: [[VAL_150:%.*]] = "tf.Equal"([[VAL_148]], [[VAL_149]]) {incompatible_shape_error = true} : (tensor, tensor<1xi32>) -> tensor -// CHECK: return [[VAL_150]] : tensor -// CHECK: } - -// CHECK-LABEL: func @notequal( -// CHECK-SAME: [[VAL_151:%.*]]: tensor<2xi32>) -> tensor<2xi1> { -// CHECK: [[VAL_152:%.*]] = "tf.NotEqual"([[VAL_151]], [[VAL_151]]) {incompatible_shape_error = true} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> -// CHECK: return [[VAL_152]] : tensor<2xi1> -// CHECK: } - -// CHECK-LABEL: func @notequal_broadcast( -// CHECK-SAME: [[VAL_153:%.*]]: tensor<1xi32>, [[VAL_154:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { -// CHECK: [[VAL_155:%.*]] = "tf.NotEqual"([[VAL_153]], [[VAL_154]]) {incompatible_shape_error = true} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> -// CHECK: return [[VAL_155]] : tensor<1x2xi1> -// CHECK: } - -// CHECK-LABEL: func @notequal_broadcast_no_incompatible_shapes_error( -// CHECK-SAME: [[VAL_156:%.*]]: tensor<2xi32>, [[VAL_157:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { -// CHECK: [[VAL_158:%.*]] = "tf.NotEqual"([[VAL_156]], [[VAL_157]]) {incompatible_shape_error = true} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> -// CHECK: return [[VAL_158]] : tensor<1x2xi1> -// CHECK: } - -// CHECK-LABEL: func @notequal_incompatible_shape_broadcastable( -// CHECK-SAME: [[VAL_159:%.*]]: tensor, [[VAL_160:%.*]]: tensor<1xi32>) -> tensor { -// CHECK: [[VAL_161:%.*]] = "tf.NotEqual"([[VAL_159]], [[VAL_160]]) {incompatible_shape_error = true} : (tensor, tensor<1xi32>) -> tensor -// CHECK: return [[VAL_161]] : tensor -// CHECK: } - -// CHECK-LABEL: func @greater( -// CHECK-SAME: [[VAL_162:%.*]]: tensor<2xi32>) -> tensor<2xi1> { -// CHECK: [[VAL_163:%.*]] = "tf.Greater"([[VAL_162]], [[VAL_162]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> -// CHECK: return [[VAL_163]] : tensor<2xi1> -// CHECK: } - -// CHECK-LABEL: func @broadcast_greater( -// CHECK-SAME: [[VAL_164:%.*]]: tensor<1xi32>, 
[[VAL_165:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { -// CHECK: [[VAL_166:%.*]] = "tf.Greater"([[VAL_164]], [[VAL_165]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> -// CHECK: return [[VAL_166]] : tensor<1x2xi1> -// CHECK: } - -// CHECK-LABEL: func @greater_equal( -// CHECK-SAME: [[VAL_167:%.*]]: tensor<2xi32>) -> tensor<2xi1> { -// CHECK: [[VAL_168:%.*]] = "tf.GreaterEqual"([[VAL_167]], [[VAL_167]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> -// CHECK: return [[VAL_168]] : tensor<2xi1> -// CHECK: } - -// CHECK-LABEL: func @broadcast_greater_equal( -// CHECK-SAME: [[VAL_169:%.*]]: tensor<1xi32>, [[VAL_170:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { -// CHECK: [[VAL_171:%.*]] = "tf.GreaterEqual"([[VAL_169]], [[VAL_170]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> -// CHECK: return [[VAL_171]] : tensor<1x2xi1> -// CHECK: } - -// CHECK-LABEL: func @less( -// CHECK-SAME: [[VAL_172:%.*]]: tensor<2xi32>) -> tensor<2xi1> { -// CHECK: [[VAL_173:%.*]] = "tf.Less"([[VAL_172]], [[VAL_172]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> -// CHECK: return [[VAL_173]] : tensor<2xi1> -// CHECK: } - -// CHECK-LABEL: func @broadcast_less( -// CHECK-SAME: [[VAL_174:%.*]]: tensor<1xi32>, [[VAL_175:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { -// CHECK: [[VAL_176:%.*]] = "tf.Less"([[VAL_174]], [[VAL_175]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> -// CHECK: return [[VAL_176]] : tensor<1x2xi1> -// CHECK: } - -// CHECK-LABEL: func @less_equal( -// CHECK-SAME: [[VAL_177:%.*]]: tensor<2xi32>) -> tensor<2xi1> { -// CHECK: [[VAL_178:%.*]] = "tf.LessEqual"([[VAL_177]], [[VAL_177]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> -// CHECK: return [[VAL_178]] : tensor<2xi1> -// CHECK: } - -// CHECK-LABEL: func @broadcast_less_equal( -// CHECK-SAME: [[VAL_179:%.*]]: tensor<1xi32>, [[VAL_180:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { -// CHECK: [[VAL_181:%.*]] = "tf.LessEqual"([[VAL_179]], [[VAL_180]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> -// CHECK: return [[VAL_181]] : tensor<1x2xi1> -// CHECK: } - -// CHECK-LABEL: func @concat_v2( -// CHECK-SAME: [[VAL_182:%.*]]: tensor<3x3xf32>, [[VAL_183:%.*]]: tensor<3x3xf32>) -> tensor<6x3xf32> { -// CHECK: [[VAL_184:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor -// CHECK: [[VAL_185:%.*]] = "tf.ConcatV2"([[VAL_182]], [[VAL_183]], [[VAL_184]]) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor) -> tensor<6x3xf32> -// CHECK: return [[VAL_185]] : tensor<6x3xf32> -// CHECK: } - -// CHECK-LABEL: func @concat_v2_1d_axis( -// CHECK-SAME: [[VAL_186:%.*]]: tensor<3x3xf32>, [[VAL_187:%.*]]: tensor<3x3xf32>) -> tensor<3x6xf32> { -// CHECK: [[VAL_188:%.*]] = "tf.Const"() {value = dense<1> : tensor} : () -> tensor -// CHECK: [[VAL_189:%.*]] = "tf.ConcatV2"([[VAL_186]], [[VAL_187]], [[VAL_188]]) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor) -> tensor<3x6xf32> -// CHECK: return [[VAL_189]] : tensor<3x6xf32> -// CHECK: } - -// CHECK-LABEL: func @const() -> tensor<2xi32> { -// CHECK: [[VAL_190:%.*]] = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> tensor<2xi32> -// CHECK: return [[VAL_190]] : tensor<2xi32> -// CHECK: } - -// CHECK-LABEL: func @relu( -// CHECK-SAME: [[VAL_192:%.*]]: tensor<1xi32>) -> tensor<1xi32> { -// CHECK: [[VAL_193:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor -// CHECK: [[VAL_194:%.*]] = "tf.Maximum"([[VAL_193]], [[VAL_192]]) : (tensor, tensor<1xi32>) -> tensor<1xi32> -// CHECK: return [[VAL_194]] : tensor<1xi32> -// CHECK: } - -// CHECK-LABEL: func @relu_unranked( -// CHECK-SAME: 
[[VAL_195:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_196:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor -// CHECK: [[VAL_197:%.*]] = "tf.Maximum"([[VAL_196]], [[VAL_195]]) : (tensor, tensor) -> tensor -// CHECK: return [[VAL_197]] : tensor -// CHECK: } - -// CHECK-LABEL: func @relu6( -// CHECK-SAME: [[VAL_198:%.*]]: tensor<1xi32>) -> tensor<1xi32> { -// CHECK: [[VAL_199:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor -// CHECK: [[VAL_200:%.*]] = "tf.Const"() {value = dense<6> : tensor} : () -> tensor -// CHECK: [[VAL_201:%.*]] = "tf.Minimum"([[VAL_198]], [[VAL_200]]) : (tensor<1xi32>, tensor) -> tensor<1xi32> -// CHECK: [[VAL_202:%.*]] = "tf.Maximum"([[VAL_201]], [[VAL_199]]) : (tensor<1xi32>, tensor) -> tensor<1xi32> -// CHECK: return [[VAL_202]] : tensor<1xi32> -// CHECK: } - -// CHECK-LABEL: func @relu6_unranked( -// CHECK-SAME: [[VAL_203:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_204:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor -// CHECK: [[VAL_205:%.*]] = "tf.Const"() {value = dense<6> : tensor} : () -> tensor -// CHECK: [[VAL_206:%.*]] = "tf.Minimum"([[VAL_203]], [[VAL_205]]) : (tensor, tensor) -> tensor -// CHECK: [[VAL_207:%.*]] = "tf.Maximum"([[VAL_206]], [[VAL_204]]) : (tensor, tensor) -> tensor -// CHECK: return [[VAL_207]] : tensor -// CHECK: } - -// CHECK-LABEL: func @relu_grad( -// CHECK-SAME: [[VAL_208:%.*]]: tensor<4x8xf32>, [[VAL_209:%.*]]: tensor) -> tensor<4x8xf32> { -// CHECK: [[VAL_210:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor} : () -> tensor -// CHECK: [[VAL_211:%.*]] = "tf.Greater"([[VAL_209]], [[VAL_210]]) : (tensor, tensor) -> tensor -// CHECK: [[VAL_212:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<4x8xf32>} : () -> tensor<4x8xf32> -// CHECK: [[VAL_213:%.*]] = "tf.Select"([[VAL_211]], [[VAL_208]], [[VAL_212]]) : (tensor, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> -// CHECK: return [[VAL_213]] : tensor<4x8xf32> -// CHECK: } - -// CHECK-LABEL: func @select( -// CHECK-SAME: [[VAL_214:%.*]]: tensor<2xi1>, [[VAL_215:%.*]]: tensor<2xi32>, [[VAL_216:%.*]]: tensor<2xi32>) -> tensor<2xi32> { -// CHECK: [[VAL_217:%.*]] = "tf.Select"([[VAL_214]], [[VAL_215]], [[VAL_216]]) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK: return [[VAL_217]] : tensor<2xi32> -// CHECK: } - -// CHECK-LABEL: func @select_float( -// CHECK-SAME: [[VAL_218:%.*]]: tensor<2xi1>, [[VAL_219:%.*]]: tensor<2xf32>, [[VAL_220:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_221:%.*]] = "tf.Select"([[VAL_218]], [[VAL_219]], [[VAL_220]]) : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_221]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @select_multidimensional( -// CHECK-SAME: [[VAL_222:%.*]]: tensor<3x2xi1>, [[VAL_223:%.*]]: tensor<3x2xi32>, [[VAL_224:%.*]]: tensor<3x2xi32>) -> tensor<3x2xi32> { -// CHECK: [[VAL_225:%.*]] = "tf.Select"([[VAL_222]], [[VAL_223]], [[VAL_224]]) : (tensor<3x2xi1>, tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> -// CHECK: return [[VAL_225]] : tensor<3x2xi32> -// CHECK: } - -// CHECK-LABEL: func @selectv2( -// CHECK-SAME: [[VAL_226:%.*]]: tensor<2xi1>, [[VAL_227:%.*]]: tensor<2xi32>, [[VAL_228:%.*]]: tensor<2xi32>) -> tensor<2xi32> { -// CHECK: [[VAL_229:%.*]] = "tf.Select"([[VAL_226]], [[VAL_227]], [[VAL_228]]) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK: return [[VAL_229]] : tensor<2xi32> -// CHECK: } - -// CHECK-LABEL: func @selectv2_pred_scalar( -// CHECK-SAME: 
[[VAL_230:%.*]]: tensor, [[VAL_231:%.*]]: tensor<2xi32>, [[VAL_232:%.*]]: tensor<2xi32>) -> tensor<2xi32> { -// CHECK: [[VAL_233:%.*]] = "tf.Select"([[VAL_230]], [[VAL_231]], [[VAL_232]]) : (tensor, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK: return [[VAL_233]] : tensor<2xi32> -// CHECK: } - -// CHECK-LABEL: func @transpose_2d( -// CHECK-SAME: [[VAL_234:%.*]]: tensor<2x3xf32>) -> tensor<3x2xf32> { -// CHECK: [[VAL_235:%.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> -// CHECK: [[VAL_236:%.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> -// CHECK: [[VAL_237:%.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> -// CHECK: [[VAL_238:%.*]] = "tf.Transpose"([[VAL_234]], [[VAL_237]]) : (tensor<2x3xf32>, tensor<2xi64>) -> tensor<3x2xf32> -// CHECK: return [[VAL_238]] : tensor<3x2xf32> -// CHECK: } - -// CHECK-LABEL: func @transpose_3d_int32( -// CHECK-SAME: [[VAL_239:%.*]]: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { -// CHECK: [[VAL_240:%.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi32>} : () -> tensor<3xi32> -// CHECK: [[VAL_241:%.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> tensor<3xi64> -// CHECK: [[VAL_242:%.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> tensor<3xi64> -// CHECK: [[VAL_243:%.*]] = "tf.Transpose"([[VAL_239]], [[VAL_242]]) : (tensor<1x2x3xf32>, tensor<3xi64>) -> tensor<3x2x1xf32> -// CHECK: return [[VAL_243]] : tensor<3x2x1xf32> -// CHECK: } - -// CHECK-LABEL: func @transpose_3d( -// CHECK-SAME: [[VAL_244:%.*]]: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { -// CHECK: [[VAL_245:%.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> tensor<3xi64> -// CHECK: [[VAL_246:%.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> tensor<3xi64> -// CHECK: [[VAL_247:%.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> tensor<3xi64> -// CHECK: [[VAL_248:%.*]] = "tf.Transpose"([[VAL_244]], [[VAL_247]]) : (tensor<1x2x3xf32>, tensor<3xi64>) -> tensor<3x2x1xf32> -// CHECK: return [[VAL_248]] : tensor<3x2x1xf32> -// CHECK: } - -// CHECK-LABEL: func @transpose_dynamic_2d( -// CHECK-SAME: [[VAL_249:%.*]]: tensor) -> tensor<4x?xf32> { -// CHECK: [[VAL_250:%.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> -// CHECK: [[VAL_251:%.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> -// CHECK: [[VAL_252:%.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> -// CHECK: [[VAL_253:%.*]] = "tf.Transpose"([[VAL_249]], [[VAL_252]]) : (tensor, tensor<2xi64>) -> tensor<4x?xf32> -// CHECK: return [[VAL_253]] : tensor<4x?xf32> -// CHECK: } - -// CHECK-LABEL: func @transpose_unranked_2d( -// CHECK-SAME: [[VAL_254:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_255:%.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> -// CHECK: [[VAL_256:%.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> -// CHECK: [[VAL_257:%.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> -// CHECK: [[VAL_258:%.*]] = "tf.Transpose"([[VAL_254]], [[VAL_257]]) : (tensor<*xf32>, tensor<2xi64>) -> tensor<*xf32> -// CHECK: return [[VAL_258]] : tensor<*xf32> -// CHECK: } - -// CHECK-LABEL: func @abs( -// CHECK-SAME: [[VAL_259:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_260:%.*]] = "tf.Abs"([[VAL_259]]) : (tensor<2xf32>) -> 
tensor<2xf32> -// CHECK: return [[VAL_260]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @abs_dynamic( -// CHECK-SAME: [[VAL_261:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_262:%.*]] = "tf.Abs"([[VAL_261]]) : (tensor) -> tensor -// CHECK: return [[VAL_262]] : tensor -// CHECK: } - -// CHECK-LABEL: func @abs_unranked( -// CHECK-SAME: [[VAL_263:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_264:%.*]] = "tf.Abs"([[VAL_263]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_264]] : tensor<*xf32> -// CHECK: } - -// CHECK-LABEL: func @ceil( -// CHECK-SAME: [[VAL_265:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_266:%.*]] = "tf.Ceil"([[VAL_265]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_266]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @ceil_dynamic( -// CHECK-SAME: [[VAL_267:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_268:%.*]] = "tf.Ceil"([[VAL_267]]) : (tensor) -> tensor -// CHECK: return [[VAL_268]] : tensor -// CHECK: } - -// CHECK-LABEL: func @ceil_unranked( -// CHECK-SAME: [[VAL_269:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_270:%.*]] = "tf.Ceil"([[VAL_269]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_270]] : tensor<*xf32> -// CHECK: } - -// CHECK-LABEL: func @complex_abs( -// CHECK-SAME: [[VAL_271:%.*]]: tensor<2xcomplex>) -> tensor<2xf32> { -// CHECK: [[VAL_272:%.*]] = "tf.ComplexAbs"([[VAL_271]]) : (tensor<2xcomplex>) -> tensor<2xf32> -// CHECK: return [[VAL_272]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @cos( -// CHECK-SAME: [[VAL_273:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_274:%.*]] = "tf.Cos"([[VAL_273]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_274]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @cos_dynamic( -// CHECK-SAME: [[VAL_275:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_276:%.*]] = "tf.Cos"([[VAL_275]]) : (tensor) -> tensor -// CHECK: return [[VAL_276]] : tensor -// CHECK: } - -// CHECK-LABEL: func @cos_unranked( -// CHECK-SAME: [[VAL_277:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_278:%.*]] = "tf.Cos"([[VAL_277]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_278]] : tensor<*xf32> -// CHECK: } - -// CHECK-LABEL: func @exp( -// CHECK-SAME: [[VAL_279:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_280:%.*]] = "tf.Exp"([[VAL_279]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_280]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @exp_dynamic( -// CHECK-SAME: [[VAL_281:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_282:%.*]] = "tf.Exp"([[VAL_281]]) : (tensor) -> tensor -// CHECK: return [[VAL_282]] : tensor -// CHECK: } - -// CHECK-LABEL: func @exp_unranked( -// CHECK-SAME: [[VAL_283:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_284:%.*]] = "tf.Exp"([[VAL_283]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_284]] : tensor<*xf32> -// CHECK: } - -// CHECK-LABEL: func @floor( -// CHECK-SAME: [[VAL_285:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_286:%.*]] = "tf.Floor"([[VAL_285]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_286]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @floor_dynamic( -// CHECK-SAME: [[VAL_287:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_288:%.*]] = "tf.Floor"([[VAL_287]]) : (tensor) -> tensor -// CHECK: return [[VAL_288]] : tensor -// CHECK: } - -// CHECK-LABEL: func @floor_unranked( -// CHECK-SAME: [[VAL_289:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_290:%.*]] = 
"tf.Floor"([[VAL_289]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_290]] : tensor<*xf32> -// CHECK: } - -// CHECK-LABEL: func @is_finite( -// CHECK-SAME: [[VAL_291:%.*]]: tensor<2xf32>) -> tensor<2xi1> { -// CHECK: [[VAL_292:%.*]] = "tf.IsFinite"([[VAL_291]]) : (tensor<2xf32>) -> tensor<2xi1> -// CHECK: return [[VAL_292]] : tensor<2xi1> -// CHECK: } - -// CHECK-LABEL: func @is_finite_dynamic( -// CHECK-SAME: [[VAL_293:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_294:%.*]] = "tf.IsFinite"([[VAL_293]]) : (tensor) -> tensor -// CHECK: return [[VAL_294]] : tensor -// CHECK: } - -// CHECK-LABEL: func @is_finite_unranked( -// CHECK-SAME: [[VAL_295:%.*]]: tensor<*xf32>) -> tensor<*xi1> { -// CHECK: [[VAL_296:%.*]] = "tf.IsFinite"([[VAL_295]]) : (tensor<*xf32>) -> tensor<*xi1> -// CHECK: return [[VAL_296]] : tensor<*xi1> -// CHECK: } - -// CHECK-LABEL: func @log( -// CHECK-SAME: [[VAL_297:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_298:%.*]] = "tf.Log"([[VAL_297]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_298]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @log_dynamic( -// CHECK-SAME: [[VAL_299:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_300:%.*]] = "tf.Log"([[VAL_299]]) : (tensor) -> tensor -// CHECK: return [[VAL_300]] : tensor -// CHECK: } - -// CHECK-LABEL: func @log_unranked( -// CHECK-SAME: [[VAL_301:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_302:%.*]] = "tf.Log"([[VAL_301]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_302]] : tensor<*xf32> -// CHECK: } - -// CHECK-LABEL: func @log1p( -// CHECK-SAME: [[VAL_303:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_304:%.*]] = "tf.Log1p"([[VAL_303]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_304]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @log1p_dynamic( -// CHECK-SAME: [[VAL_305:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_306:%.*]] = "tf.Log1p"([[VAL_305]]) : (tensor) -> tensor -// CHECK: return [[VAL_306]] : tensor -// CHECK: } - -// CHECK-LABEL: func @log1p_unranked( -// CHECK-SAME: [[VAL_307:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_308:%.*]] = "tf.Log1p"([[VAL_307]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_308]] : tensor<*xf32> -// CHECK: } - -// CHECK-LABEL: func @neg( -// CHECK-SAME: [[VAL_309:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_310:%.*]] = "tf.Neg"([[VAL_309]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_310]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @neg_dynamic( -// CHECK-SAME: [[VAL_311:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_312:%.*]] = "tf.Neg"([[VAL_311]]) : (tensor) -> tensor -// CHECK: return [[VAL_312]] : tensor -// CHECK: } - -// CHECK-LABEL: func @neg_unranked( -// CHECK-SAME: [[VAL_313:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_314:%.*]] = "tf.Neg"([[VAL_313]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_314]] : tensor<*xf32> -// CHECK: } - -// CHECK-LABEL: func @sigmoid( -// CHECK-SAME: [[VAL_315:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_316:%.*]] = "tf.Const"() {value = dense<5.000000e-01> : tensor} : () -> tensor -// CHECK: [[VAL_317:%.*]] = "tf.Const"() {value = dense<2> : tensor<1xi64>} : () -> tensor<1xi64> -// CHECK: [[VAL_318:%.*]] = "tf.Const"() {value = dense<5.000000e-01> : tensor<2xf32>} : () -> tensor<2xf32> -// CHECK: [[VAL_319:%.*]] = "tf.Mul"([[VAL_315]], [[VAL_318]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> -// CHECK: [[VAL_320:%.*]] = 
"tf.Tanh"([[VAL_319]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: [[VAL_321:%.*]] = "tf.Mul"([[VAL_320]], [[VAL_318]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> -// CHECK: [[VAL_322:%.*]] = "tf.AddV2"([[VAL_321]], [[VAL_318]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_322]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @sin( -// CHECK-SAME: [[VAL_323:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_324:%.*]] = "tf.Sin"([[VAL_323]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_324]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @sin_dynamic( -// CHECK-SAME: [[VAL_325:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_326:%.*]] = "tf.Sin"([[VAL_325]]) : (tensor) -> tensor -// CHECK: return [[VAL_326]] : tensor -// CHECK: } - -// CHECK-LABEL: func @sin_unranked( -// CHECK-SAME: [[VAL_327:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_328:%.*]] = "tf.Sin"([[VAL_327]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_328]] : tensor<*xf32> -// CHECK: } - -// CHECK-LABEL: func @rsqrt( -// CHECK-SAME: [[VAL_329:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_330:%.*]] = "tf.Rsqrt"([[VAL_329]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_330]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @rsqrt_dynamic( -// CHECK-SAME: [[VAL_331:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_332:%.*]] = "tf.Rsqrt"([[VAL_331]]) : (tensor) -> tensor -// CHECK: return [[VAL_332]] : tensor -// CHECK: } - -// CHECK-LABEL: func @rsqrt_unranked( -// CHECK-SAME: [[VAL_333:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_334:%.*]] = "tf.Rsqrt"([[VAL_333]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_334]] : tensor<*xf32> -// CHECK: } - -// CHECK-LABEL: func @sqrt( -// CHECK-SAME: [[VAL_335:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_336:%.*]] = "tf.Sqrt"([[VAL_335]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_336]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @sqrt_dynamic( -// CHECK-SAME: [[VAL_337:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_338:%.*]] = "tf.Sqrt"([[VAL_337]]) : (tensor) -> tensor -// CHECK: return [[VAL_338]] : tensor -// CHECK: } - -// CHECK-LABEL: func @sqrt_unranked( -// CHECK-SAME: [[VAL_339:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_340:%.*]] = "tf.Sqrt"([[VAL_339]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_340]] : tensor<*xf32> -// CHECK: } - -// CHECK-LABEL: func @tanh( -// CHECK-SAME: [[VAL_341:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_342:%.*]] = "tf.Tanh"([[VAL_341]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_342]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @tanh_dynamic( -// CHECK-SAME: [[VAL_343:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_344:%.*]] = "tf.Tanh"([[VAL_343]]) : (tensor) -> tensor -// CHECK: return [[VAL_344]] : tensor -// CHECK: } - -// CHECK-LABEL: func @tanh_unranked( -// CHECK-SAME: [[VAL_345:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_346:%.*]] = "tf.Tanh"([[VAL_345]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_346]] : tensor<*xf32> -// CHECK: } - -// CHECK-LABEL: func @bitcast( -// CHECK-SAME: [[VAL_347:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_348:%.*]] = "tf.Bitcast"([[VAL_347]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_348]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @bitcast_dynamic( -// CHECK-SAME: [[VAL_349:%.*]]: tensor) -> tensor { -// 
CHECK: [[VAL_350:%.*]] = "tf.Bitcast"([[VAL_349]]) : (tensor) -> tensor -// CHECK: return [[VAL_350]] : tensor -// CHECK: } - -// CHECK-LABEL: func @bitcast_unranked( -// CHECK-SAME: [[VAL_351:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_352:%.*]] = "tf.Bitcast"([[VAL_351]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_352]] : tensor<*xf32> -// CHECK: } - -// CHECK-LABEL: func @bitcast_same_widths( -// CHECK-SAME: [[VAL_353:%.*]]: tensor<2xf32>) -> tensor<2xi32> { -// CHECK: [[VAL_354:%.*]] = "tf.Bitcast"([[VAL_353]]) : (tensor<2xf32>) -> tensor<2xi32> -// CHECK: return [[VAL_354]] : tensor<2xi32> -// CHECK: } - -// CHECK-LABEL: func @sign( -// CHECK-SAME: [[VAL_355:%.*]]: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { -// CHECK: [[VAL_356:%.*]] = "tf.NotEqual"([[VAL_355]], [[VAL_355]]) {incompatible_shape_error = true} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1> -// CHECK: [[VAL_357:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1x2x3x4xf32>} : () -> tensor<1x2x3x4xf32> -// CHECK: [[VAL_358:%.*]] = "tf.NotEqual"([[VAL_355]], [[VAL_355]]) {incompatible_shape_error = true} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1> -// CHECK: [[VAL_359:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1x2x3x4xf32>} : () -> tensor<1x2x3x4xf32> -// CHECK: [[VAL_360:%.*]] = "tf.Sign"([[VAL_355]]) : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> -// CHECK: [[VAL_361:%.*]] = "tf.Select"([[VAL_358]], [[VAL_359]], [[VAL_360]]) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> -// CHECK: [[VAL_362:%.*]] = "tf.Select"([[VAL_356]], [[VAL_357]], [[VAL_361]]) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> -// CHECK: return [[VAL_362]] : tensor<1x2x3x4xf32> -// CHECK: } - -// CHECK-LABEL: func @size_rank_one_i32( -// CHECK-SAME: [[VAL_363:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_364:%.*]] = "tf.Const"() {value = dense<1> : tensor} : () -> tensor -// CHECK: return [[VAL_364]] : tensor -// CHECK: } - -// CHECK-LABEL: func @size_rank_one_i64( -// CHECK-SAME: [[VAL_365:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_366:%.*]] = "tf.Const"() {value = dense<1> : tensor} : () -> tensor -// CHECK: return [[VAL_366]] : tensor -// CHECK: } - -// CHECK-LABEL: func @complex( -// CHECK-SAME: [[VAL_367:%.*]]: tensor<3xf32>, [[VAL_368:%.*]]: tensor<3xf32>) -> tensor<3xcomplex> { -// CHECK: [[VAL_369:%.*]] = "tf.Complex"([[VAL_367]], [[VAL_368]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex> -// CHECK: return [[VAL_369]] : tensor<3xcomplex> -// CHECK: } - -// CHECK-LABEL: func @convert_i32_f32( -// CHECK-SAME: [[VAL_370:%.*]]: tensor<2xi32>) -> tensor<2xf32> { -// CHECK: [[VAL_371:%.*]] = "tf.Cast"([[VAL_370]]) {Truncate = false} : (tensor<2xi32>) -> tensor<2xf32> -// CHECK: return [[VAL_371]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @convert_slice( -// CHECK-SAME: [[VAL_372:%.*]]: tensor<1x4672xf32>) -> tensor<1x519xf32> { -// CHECK: [[VAL_373:%.*]] = "tf.Const"() {value = dense<[0, 4153]> : tensor<2xi64>} : () -> tensor<2xi64> -// CHECK: [[VAL_374:%.*]] = "tf.Const"() {value = dense<[1, 519]> : tensor<2xi64>} : () -> tensor<2xi64> -// CHECK: [[VAL_375:%.*]] = "tf.Slice"([[VAL_372]], [[VAL_373]], [[VAL_374]]) : (tensor<1x4672xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x519xf32> -// CHECK: return [[VAL_375]] : tensor<1x519xf32> -// CHECK: } - -// CHECK-LABEL: func @reshape( -// CHECK-SAME: [[VAL_372:%.*]]: tensor<4x6xf32>) -> tensor<2x2x6xf32> { 
-// CHECK: [[VAL_373:%.*]] = constant dense<[2, 2, 6]> : tensor<3xi64> -// CHECK: [[VAL_374:%.*]] = "tf.Reshape"([[VAL_372]], [[VAL_373]]) : (tensor<4x6xf32>, tensor<3xi64>) -> tensor<2x2x6xf32> -// CHECK: return [[VAL_374]] : tensor<2x2x6xf32> -// CHECK: } - -// CHECK-LABEL: func @convert_dot_1d_2d( -// CHECK-SAME: [[VAL_376:%.*]]: tensor<256xf32>, [[VAL_377:%.*]]: tensor<256x1xf32>) -> tensor<1xf32> { -// CHECK: [[VAL_378:%.*]] = "tf.Reshape"([[VAL_376]], {{.*}}) : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> -// CHECK: [[VAL_379:%.*]] = "tf.MatMul"([[VAL_378]], [[VAL_377]]) {transpose_a = false, transpose_b = false} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32> -// CHECK: [[VAL_380:%.*]] = "tf.Reshape"([[VAL_379]], {{.*}}) : (tensor<1x1xf32>, tensor<1xi64>) -> tensor<1xf32> -// CHECK: return [[VAL_380]] : tensor<1xf32> -// CHECK: } - -// CHECK-LABEL: func @convert_dot_2d_1d( -// CHECK-SAME: [[VAL_381:%.*]]: tensor<1x256xf32>, [[VAL_382:%.*]]: tensor<256xf32>) -> tensor<1xf32> { -// CHECK: [[VAL_383:%.*]] = "tf.Reshape"([[VAL_382]], {{.*}}) : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> -// CHECK: [[VAL_384:%.*]] = "tf.MatMul"([[VAL_381]], [[VAL_383]]) {transpose_a = false, transpose_b = true} : (tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x1xf32> -// CHECK: [[VAL_385:%.*]] = "tf.Reshape"([[VAL_384]], {{.*}}) : (tensor<1x1xf32>, tensor<1xi64>) -> tensor<1xf32> -// CHECK: return [[VAL_385]] : tensor<1xf32> -// CHECK: } - -// CHECK-LABEL: func @convert_dot_1d_1d( -// CHECK-SAME: [[VAL_386:%.*]]: tensor<256xf32>, [[VAL_387:%.*]]: tensor<256xf32>) -> tensor { -// CHECK-DAG: [[VAL_388:%.*]] = "tf.Reshape"([[VAL_386]], {{.*}}) : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> -// CHECK-DAG: [[VAL_389:%.*]] = "tf.Reshape"([[VAL_387]], {{.*}}) : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> -// CHECK: [[VAL_390:%.*]] = "tf.MatMul"([[VAL_388]], [[VAL_389]]) {transpose_a = false, transpose_b = true} : (tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x1xf32> -// CHECK: [[VAL_391:%.*]] = "tf.Reshape"([[VAL_390]], {{.*}}) : (tensor<1x1xf32>, tensor<0xi64>) -> tensor -// CHECK: return [[VAL_391]] : tensor -// CHECK: } - -// CHECK-LABEL: func @convert_dot_2d_2d( -// CHECK-SAME: [[VAL_392:%.*]]: tensor<1x256xf32>, [[VAL_393:%.*]]: tensor<256x1xf32>) -> tensor<1x1xf32> { -// CHECK: [[VAL_394:%.*]] = "tf.MatMul"([[VAL_392]], [[VAL_393]]) {transpose_a = false, transpose_b = false} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32> -// CHECK: return [[VAL_394]] : tensor<1x1xf32> -// CHECK: } - -// CHECK-LABEL: func @broadcast_in_dim_tf_style( -// CHECK-SAME: [[VAL_395:%.*]]: tensor<8x1x16xf32>) -> tensor<3x8x8x16xf32> { -// CHECK: [[VAL_396:%.*]] = constant dense<[3, 8, 8, 16]> : tensor<4xi64> -// CHECK: [[VAL_397:%.*]] = "tf.BroadcastTo"([[VAL_395]], [[VAL_396]]) : (tensor<8x1x16xf32>, tensor<4xi64>) -> tensor<3x8x8x16xf32> -// CHECK: return [[VAL_397]] : tensor<3x8x8x16xf32> -// CHECK: } - -// CHECK-LABEL: func @broadcast_in_dim_general_case( -// CHECK-SAME: [[VAL_398:%.*]]: tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32> { -// CHECK: [[VAL_399:%.*]] = constant dense<[3, 1, 1, 16]> : tensor<4xi64> -// CHECK: [[VAL_400:%.*]] = "tf.Reshape"([[VAL_398]], [[VAL_399]]) : (tensor<3x1x16xf32>, tensor<4xi64>) -> tensor<3x1x1x16xf32> -// CHECK: [[VAL_401:%.*]] = constant dense<[3, 8, 8, 16]> : tensor<4xi64> -// CHECK: [[VAL_402:%.*]] = "tf.BroadcastTo"([[VAL_400]], [[VAL_401]]) : (tensor<3x1x1x16xf32>, tensor<4xi64>) -> tensor<3x8x8x16xf32> -// CHECK: return 
[[VAL_402]] : tensor<3x8x8x16xf32> -// CHECK: } - -// CHECK-LABEL: func @convert_dot_general( -// CHECK-SAME: [[VAL_396:%.*]]: tensor<3x2x6x5x1xf32>, [[VAL_397:%.*]]: tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32> { -// CHECK: [[VAL_398:%.*]] = "tf.Transpose"([[VAL_396]], {{.*}}) : (tensor<3x2x6x5x1xf32>, tensor<5xi64>) -> tensor<3x5x1x2x6xf32> -// CHECK: [[VAL_399:%.*]] = "tf.Transpose"([[VAL_397]], {{.*}}) : (tensor<3x2x4x6xf32>, tensor<4xi64>) -> tensor<3x2x6x4xf32> -// CHECK: [[VAL_400:%.*]] = "tf.Reshape"([[VAL_398]], {{.*}}) : (tensor<3x5x1x2x6xf32>, tensor<3xi64>) -> tensor<3x5x12xf32> -// CHECK: [[VAL_401:%.*]] = "tf.Reshape"([[VAL_399]], {{.*}}) : (tensor<3x2x6x4xf32>, tensor<3xi64>) -> tensor<3x12x4xf32> -// CHECK: [[VAL_402:%.*]] = "tf.BatchMatMulV2"([[VAL_400]], [[VAL_401]]) {adj_x = false, adj_y = false} : (tensor<3x5x12xf32>, tensor<3x12x4xf32>) -> tensor<3x5x4xf32> -// CHECK: [[VAL_403:%.*]] = "tf.Reshape"([[VAL_402]], {{.*}}) : (tensor<3x5x4xf32>, tensor<4xi64>) -> tensor<3x5x1x4xf32> -// CHECK: return [[VAL_403]] : tensor<3x5x1x4xf32> -// CHECK: } - -// CHECK-LABEL: func @convert_conv2d( -// CHECK-SAME: [[VAL_404:%.*]]: tensor<1x8x8x207xf32>, [[VAL_405:%.*]]: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { -// CHECK: [[VAL_406:%.*]] = "tf.Conv2D"([[VAL_404]], [[VAL_405]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> -// CHECK: return [[VAL_406]] : tensor<1x8x8x16xf32> -// CHECK: } - -// CHECK-LABEL: func @convert_depthwise_conv2d( -// CHECK-SAME: [[VAL_407:%.*]]: tensor<1x8x8x207xf32>, [[VAL_408:%.*]]: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { -// CHECK: [[VAL_409:%.*]] = "tf.DepthwiseConv2dNative"([[VAL_407]], [[VAL_408]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> -// CHECK: return [[VAL_409]] : tensor<1x8x8x16xf32> -// CHECK: } - -// CHECK-LABEL: func @convert_conv2d_valid_padding( -// CHECK-SAME: [[VAL_410:%.*]]: tensor<1x8x8x207xf32>, [[VAL_411:%.*]]: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { -// CHECK: [[VAL_412:%.*]] = "tf.Conv2D"([[VAL_410]], [[VAL_411]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> -// CHECK: return [[VAL_412]] : tensor<1x8x8x16xf32> -// CHECK: } - -// CHECK-LABEL: func @convert_reduce_to_sum( -// CHECK-SAME: [[VAL_413:%.*]]: tensor<1x256xf32>) -> tensor<1xf32> { -// CHECK: [[VAL_414:%.*]] = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64> -// CHECK: [[VAL_415:%.*]] = "tf.Sum"([[VAL_413:%.*]], [[VAL_414:%.*]]) {keep_dims = false} : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32> -// CHECK: return [[VAL_415]] : tensor<1xf32> -// CHECK: } - -// CHECK-LABEL: func @convert_reduce_to_max( -// CHECK-SAME: [[VAL_416:%.*]]: tensor<1x256xf32>) -> tensor<1xf32> { -// CHECK: [[VAL_417:%.*]] = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64> -// CHECK: [[VAL_418:%.*]] = "tf.Max"([[VAL_416:%.*]], [[VAL_417:%.*]]) {keep_dims = false} : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32> -// CHECK: return [[VAL_418]] : tensor<1xf32> -// CHECK: } - -// CHECK-LABEL: func @convert_reduce_to_min( -// CHECK-SAME: [[VAL_419:%.*]]: 
tensor<1x256xf32>) -> tensor<1xf32> { -// CHECK: [[VAL_420:%.*]] = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64> -// CHECK: [[VAL_421:%.*]] = "tf.Min"([[VAL_419:%.*]], [[VAL_420:%.*]]) {keep_dims = false} : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32> -// CHECK: return [[VAL_421]] : tensor<1xf32> -// CHECK: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir index ea55e50db30..b1787546d67 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir @@ -86,10 +86,19 @@ func @mul_no_nan(%arg0: tensor<2x3xf32>, %arg1: tensor<3xf32>) -> tensor<2x3xf32 return %0 : tensor<2x3xf32> } +// CHECK-LABEL: @is_inf +func @is_inf(%arg0: tensor<3x4xf32>) -> tensor<3x4xi1> { + // CHECK: %[[INF:.*]] = "tf.Const"() {value = dense<0x7F800000> : tensor} : () -> tensor + // CHECK: %[[ABS:.*]] = "tf.Abs"(%arg0) : (tensor<3x4xf32>) -> tensor<3x4xf32> + // CHECK: %[[RESULT:.*]] = "tf.Equal"(%[[ABS]], %[[INF]]) {incompatible_shape_error = true} : (tensor<3x4xf32>, tensor) -> tensor<3x4xi1> + %0 = "tf.IsInf"(%arg0) : (tensor<3x4xf32>) -> tensor<3x4xi1> + // CHECK: return %[[RESULT]] + return %0 : tensor<3x4xi1> +} + // CHECK-LABEL: @is_nan func @is_nan(%arg0: tensor<3x4xf32>) -> tensor<3x4xi1> { - // CHECK: %[[NAN:.*]] = "tf.Const"() {value = dense<0x7FC00000> : tensor} : () -> tensor - // CHECK: %[[RESULT:.*]] = "tf.Equal"(%arg0, %[[NAN]]) {incompatible_shape_error = true} : (tensor<3x4xf32>, tensor) -> tensor<3x4xi1> + // CHECK: %[[RESULT:.*]] = "tf.NotEqual"(%arg0, %arg0) {incompatible_shape_error = true} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xi1> %0 = "tf.IsNan"(%arg0) : (tensor<3x4xf32>) -> tensor<3x4xi1> // CHECK: return %[[RESULT]] return %0 : tensor<3x4xi1> @@ -215,6 +224,112 @@ func @rsqrt_grad_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor< return %0 : tensor<*xf32> } +// %input has 1 batch dimension then 2 block dimensions then 1 remainder +// dimension. 
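The fourdim test below pins down the PadV2 -> Reshape -> Transpose -> Reshape sequence that the SpaceToBatchND lowering emits. As a rough, hypothetical NumPy sketch of the same decomposition for a static [batch, s0, s1, depth] input (the shapes and helper name here are illustrative and not taken from the test):

import numpy as np

def space_to_batch_nd_4d(x, block_shape, paddings):
    # Reference decomposition mirroring the lowering checked below:
    # pad the spatial dims, split each into (outer, block), move the block
    # dims in front of batch, then fold them into batch. Illustrative only.
    batch, _, _, depth = x.shape
    b0, b1 = block_shape
    padded = np.pad(x, [(0, 0), tuple(paddings[0]), tuple(paddings[1]), (0, 0)])
    p0, p1 = padded.shape[1], padded.shape[2]
    reshaped = padded.reshape(batch, p0 // b0, b0, p1 // b1, b1, depth)
    # Permutation [2, 4, 0, 1, 3, 5], as in the tf.Transpose check below.
    permuted = reshaped.transpose(2, 4, 0, 1, 3, 5)
    return permuted.reshape(batch * b0 * b1, p0 // b0, p1 // b1, depth)

x = np.arange(3 * 5 * 8 * 10, dtype=np.float32).reshape(3, 5, 8, 10)
print(space_to_batch_nd_4d(x, (1, 2), [(0, 0), (0, 0)]).shape)  # (6, 5, 4, 10)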
+// CHECK-LABEL: fourdim_SpaceToBatchND +func @fourdim_SpaceToBatchND(%input: tensor<3x5x7x10xf32>, %block_shape: tensor<2xi64>, %paddings: tensor<2x2xi64>) -> tensor { + // CHECK-DAG: [[PAD00:%.+]] = "tf.Const"() {value = dense<0> : tensor<1x2xi64>} + // CHECK-DAG: [[ZERO_I32:%.+]] = "tf.Const"() {value = dense<0> : tensor} + // CHECK-DAG: [[ZERO_I64:%.+]] = "tf.Const"() {value = dense<0> : tensor} + // CHECK-DAG: [[ONE_I64:%.+]] = "tf.Const"() {value = dense<1> : tensor} + // CHECK-DAG: [[FULL_PADDINGS:%.+]] = "tf.ConcatV2"([[PAD00]], %arg2, [[PAD00]], [[ZERO_I64]]) + // CHECK-DAG: [[PAD_DEFAULT:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor} + // CHECK-DAG: [[PADDED:%.+]] = "tf.PadV2"(%arg0, [[FULL_PADDINGS]], [[PAD_DEFAULT]]) + // CHECK-DAG: [[PADDINGS_SUM:%.+]] = "tf.Sum"([[FULL_PADDINGS]], [[ONE_I64]]) + // CHECK-DAG: [[INPUT_SHAPE:%.+]] = "tf.Const"() {value = dense<[3, 5, 7, 10]> : tensor<4xi64>} + // CHECK-DAG: [[PADDED_SHAPE:%.+]] = "tf.Add"([[PADDINGS_SUM]], [[INPUT_SHAPE]]) + // CHECK-DAG: [[PADDED_SHAPE_SPLITS:%.+]]:4 = "tf.Split"([[ZERO_I32]], [[PADDED_SHAPE]]) + // CHECK-DAG: [[BLOCK_SHAPE_SPLITS:%.+]]:2 = "tf.Split"([[ZERO_I32]], %arg1) + // CHECK-DAG: [[OUTER_SHAPE_0:%.+]] = "tf.Div"([[PADDED_SHAPE_SPLITS]]#1, [[BLOCK_SHAPE_SPLITS]]#0) + // CHECK-DAG: [[OUTER_SHAPE_1:%.+]] = "tf.Div"([[PADDED_SHAPE_SPLITS]]#2, [[BLOCK_SHAPE_SPLITS]]#1) + // CHECK-DAG: [[RESHAPED_SHAPE:%.+]] = "tf.ConcatV2"([[PADDED_SHAPE_SPLITS]]#0, [[OUTER_SHAPE_0]], [[BLOCK_SHAPE_SPLITS]]#0, [[OUTER_SHAPE_1]], [[BLOCK_SHAPE_SPLITS]]#1, [[PADDED_SHAPE_SPLITS]]#3, [[ZERO_I64]]) + // CHECK-DAG: [[PERMUTATION:%.+]] = "tf.Const"() {value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi64>} + // CHECK-DAG: [[OUTPUT_BATCH_PART:%.+]] = "tf.Mul"([[PADDED_SHAPE_SPLITS]]#0, [[BLOCK_SHAPE_SPLITS]]#0) + // CHECK-DAG: [[OUTPUT_BATCH:%.+]] = "tf.Mul"([[OUTPUT_BATCH_PART]], [[BLOCK_SHAPE_SPLITS]]#1) + // CHECK-DAG: [[OUTPUT_SHAPE:%.+]] = "tf.ConcatV2"([[OUTPUT_BATCH]], [[OUTER_SHAPE_0]], [[OUTER_SHAPE_1]], [[PADDED_SHAPE_SPLITS]]#3, [[ZERO_I64]]) + // CHECK-DAG: [[RESHAPED:%.+]] = "tf.Reshape"([[PADDED]], [[RESHAPED_SHAPE]]) + // CHECK-DAG: [[PERMUTED:%.+]] = "tf.Transpose"([[RESHAPED]], [[PERMUTATION]]) + // CHECK-DAG: [[RESULT:%.+]] = "tf.Reshape"([[PERMUTED]], [[OUTPUT_SHAPE]]) + // CHECK-DAG: return [[RESULT]] + %0 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5x7x10xf32>, tensor<2xi64>, tensor<2x2xi64>) -> tensor + return %0 : tensor +} + +// %input has 1 batch dimension then 3 block dimensions then 2 remainder +// dimensions. This checks only ops that are specific to the case with 3 block +// dimension and 2 remainder dimensions. 
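The sixdim test that follows only spot-checks the pieces that change when there are 3 block dimensions and 2 remainder dimensions. The shape bookkeeping it verifies (reshaped shape, permutation, output batch) generalizes as in this hypothetical Python helper, which is not part of the pass or the tests:

def space_to_batch_shapes(input_shape, block_shape, paddings):
    # Computes the values the CHECK-DAG lines derive with tf.Split, tf.Div,
    # tf.Mul and tf.ConcatV2.
    batch, rest = input_shape[0], input_shape[1:]
    m = len(block_shape)
    spatial, remainder = rest[:m], rest[m:]
    padded = [s + lo + hi for s, (lo, hi) in zip(spatial, paddings)]
    outer = [p // b for p, b in zip(padded, block_shape)]
    reshaped = [batch]
    for o, b in zip(outer, block_shape):   # interleave (outer_i, block_i)
        reshaped += [o, b]
    reshaped += list(remainder)
    # Block axes come first in the transpose, then batch, outer and remainder axes.
    permutation = ([2 + 2 * i for i in range(m)] + [0]
                   + [1 + 2 * i for i in range(m)]
                   + list(range(1 + 2 * m, 1 + 2 * m + len(remainder))))
    out_batch = batch
    for b in block_shape:                  # the chain of tf.Mul ops
        out_batch *= b
    return reshaped, permutation, [out_batch] + outer + list(remainder)

# For the 6-D test below: permutation == [2, 4, 6, 0, 1, 3, 5, 7, 8].
print(space_to_batch_shapes([3, 5, 7, 9, 10, 11], [1, 1, 3], [(0, 0)] * 3))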
+// CHECK-LABEL: sixdim_SpaceToBatchND +func @sixdim_SpaceToBatchND(%input: tensor<3x5x7x9x10x11xf32>, %block_shape: tensor<3xi64>, %paddings: tensor<3x2xi64>) -> tensor { + // CHECK-DAG: [[PAD00:%.+]] = "tf.Const"() + // CHECK-DAG: [[FULL_PADDINGS:%.+]] = "tf.ConcatV2"([[PAD00]], %arg2, [[PAD00]], [[PAD00]], {{.+}}) + // CHECK-DAG: [[INPUT_SHAPE:%.+]] = "tf.Const"() {value = dense<[3, 5, 7, 9, 10, 11]> : tensor<6xi64>} + // CHECK-DAG: [[PADDED_SHAPE_SPLITS:%.+]]:6 = "tf.Split" + // CHECK-DAG: [[BLOCK_SHAPE_SPLITS:%.+]]:3 = "tf.Split" + // CHECK-DAG: [[OUTER_SHAPE_0:%.+]] = "tf.Div"([[PADDED_SHAPE_SPLITS]]#1, [[BLOCK_SHAPE_SPLITS]]#0) + // CHECK-DAG: [[OUTER_SHAPE_1:%.+]] = "tf.Div"([[PADDED_SHAPE_SPLITS]]#2, [[BLOCK_SHAPE_SPLITS]]#1) + // CHECK-DAG: [[OUTER_SHAPE_2:%.+]] = "tf.Div"([[PADDED_SHAPE_SPLITS]]#3, [[BLOCK_SHAPE_SPLITS]]#2) + // CHECK-DAG: [[RESHAPED_SHAPE:%.+]] = "tf.ConcatV2"([[PADDED_SHAPE_SPLITS]]#0, [[OUTER_SHAPE_0]], [[BLOCK_SHAPE_SPLITS]]#0, [[OUTER_SHAPE_1]], [[BLOCK_SHAPE_SPLITS]]#1, [[OUTER_SHAPE_2]], [[BLOCK_SHAPE_SPLITS]]#2, [[PADDED_SHAPE_SPLITS]]#4, [[PADDED_SHAPE_SPLITS]]#5, {{.+}}) + // CHECK-DAG: [[PERMUTATION:%.+]] = "tf.Const"() {value = dense<[2, 4, 6, 0, 1, 3, 5, 7, 8]> : tensor<9xi64>} + // CHECK-DAG: [[OUTPUT_BATCH_PART1:%.+]] = "tf.Mul"([[PADDED_SHAPE_SPLITS]]#0, [[BLOCK_SHAPE_SPLITS]]#0) + // CHECK-DAG: [[OUTPUT_BATCH_PART2:%.+]] = "tf.Mul"([[OUTPUT_BATCH_PART1]], [[BLOCK_SHAPE_SPLITS]]#1) + // CHECK-DAG: [[OUTPUT_BATCH:%.+]] = "tf.Mul"([[OUTPUT_BATCH_PART2]], [[BLOCK_SHAPE_SPLITS]]#2) + // CHECK-DAG: [[OUTPUT_SHAPE:%.+]] = "tf.ConcatV2"([[OUTPUT_BATCH]], [[OUTER_SHAPE_0]], [[OUTER_SHAPE_1]], [[OUTER_SHAPE_2]], [[PADDED_SHAPE_SPLITS]]#4, [[PADDED_SHAPE_SPLITS]]#5, {{.+}}) + %0 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5x7x9x10x11xf32>, tensor<3xi64>, tensor<3x2xi64>) -> tensor + return %0 : tensor +} + +func @fake_quant_with_min_max_args(%arg0 : tensor) -> tensor { + // CHECK-DAG: [[VAL0:%.+]] = "tf.Const"() {value = dense<1.275000e+02> : tensor} + // CHECK-DAG: [[VAL1:%.+]] = "tf.Const"() {value = dense<1.00392163> : tensor} + // CHECK-DAG: [[VAL2:%.+]] = "tf.Const"() {value = dense<-0.996078491> : tensor} + // CHECK-DAG: [[VAL3:%.+]] = "tf.Const"() {value = dense<0.00784313772> : tensor} + // CHECK-DAG: [[VAL4:%.+]] = "tf.Const"() {value = dense<5.000000e-01> : tensor} + // CHECK-DAG: [[VAL5:%.+]] = "tf.ClipByValue"(%arg0, [[VAL2]], [[VAL1]]) + // CHECK-DAG: [[VAL6:%.+]] = "tf.Sub"([[VAL5]], [[VAL2]]) + // CHECK-DAG: [[VAL7:%.+]] = "tf.Mul"([[VAL6]], [[VAL0]]) + // CHECK-DAG: [[VAL8:%.+]] = "tf.Add"([[VAL7]], [[VAL4]]) + // CHECK-DAG: [[VAL9:%.+]] = "tf.Floor"([[VAL8]]) + // CHECK-DAG: [[VAL10:%.+]] = "tf.Mul"([[VAL9]], [[VAL3]]) + // CHECK-DAG: [[VAL11:%.+]] = "tf.Add"([[VAL10]], [[VAL2]]) + %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {max = 1.0 : f32, min = -1.0 : f32, narrow_range = false, num_bits = 8 : i64} : (tensor) -> tensor + + // CHECK: return [[VAL11]] + return %0 : tensor +} + +func @fake_quant_with_min_max_vars(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { + // CHECK-DAG: [[VAL0:%.+]] = "tf.Const"() {value = dense<0.000000e+00> + // CHECK-DAG: [[VAL1:%.+]] = "tf.Const"() {value = dense<2.550000e+02> + // CHECK-DAG: [[VAL2:%.+]] = "tf.Const"() {value = dense<1.000000e+00> + // CHECK-DAG: [[VAL3:%.+]] = "tf.Const"() {value = dense<5.000000e-01> + // CHECK-DAG: [[VAL4:%.+]] = "tf.Sub"(%arg2, %arg1) + // CHECK-DAG: [[VAL5:%.+]] = "tf.Div"([[VAL4]], [[VAL1]]) + // CHECK-DAG: [[VAL6:%.+]] = "tf.Div"([[VAL1]], 
[[VAL4]]) + // CHECK-DAG: [[VAL7:%.+]] = "tf.Div"(%arg1, [[VAL5]]) + // CHECK-DAG: [[VAL8:%.+]] = "tf.Sub"([[VAL0]], [[VAL7]]) + // CHECK-DAG: [[VAL9:%.+]] = "tf.Floor"([[VAL8]]) + // CHECK-DAG: [[VAL10:%.+]] = "tf.Sub"([[VAL8]], [[VAL9]]) + // CHECK-DAG: [[VAL11:%.+]] = "tf.Less"([[VAL10]], [[VAL3]]) + // CHECK-DAG: [[VAL12:%.+]] = "tf.Add"([[VAL2]], [[VAL9]]) + // CHECK-DAG: [[VAL13:%.+]] = "tf.Select"([[VAL11]], [[VAL9]], [[VAL12]]) + // CHECK-DAG: [[VAL14:%.+]] = "tf.ClipByValue"([[VAL13]], [[VAL0]], [[VAL1]]) : + // CHECK-DAG: [[VAL15:%.+]] = "tf.Sub"([[VAL0]], [[VAL14]]) + // CHECK-DAG: [[VAL16:%.+]] = "tf.Sub"([[VAL1]], [[VAL14]]) + // CHECK-DAG: [[VAL17:%.+]] = "tf.Mul"([[VAL15]], [[VAL5]]) + // CHECK-DAG: [[VAL18:%.+]] = "tf.Mul"([[VAL16]], [[VAL5]]) + // CHECK-DAG: [[VAL19:%.+]] = "tf.ClipByValue"(%arg0, [[VAL17]], [[VAL18]]) + // CHECK-DAG: [[VAL20:%.+]] = "tf.Sub"([[VAL19]], [[VAL17]]) + // CHECK-DAG: [[VAL21:%.+]] = "tf.Mul"([[VAL20]], [[VAL6]]) + // CHECK-DAG: [[VAL22:%.+]] = "tf.Add"([[VAL21]], [[VAL3]]) + // CHECK-DAG: [[VAL23:%.+]] = "tf.Floor"([[VAL22]]) + // CHECK-DAG: [[VAL24:%.+]] = "tf.Mul"([[VAL23]], [[VAL5]]) + // CHECK-DAG: [[VAL25:%.+]] = "tf.Add"([[VAL24]], [[VAL17]]) + %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {narrow_range = false, num_bits = 8 : i64} : (tensor, tensor, tensor) -> tensor + + // CHECK: return [[VAL25]] + return %0 : tensor +} + // CHECK-LABEL: SoftmaxCrossEntropyWithLogits // CHECK-SAME: %[[FEATURES:.*]]: tensor<2x3xf32>, %[[LABELS:.*]]: tensor<2x3xf32> func @SoftmaxCrossEntropyWithLogits(%features: tensor<2x3xf32>, %labels: tensor<2x3xf32>) -> (tensor<2xf32>, tensor<2x3xf32>) { @@ -533,3 +648,59 @@ func @_UnaryOpsComposition(%arg0: tensor<4xf32>) -> tensor<4xf32> { %0 = "tf._UnaryOpsComposition"(%arg0) {op_names = ["Asin", "Abs", "Log"]} : (tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } + + +// CHECK-LABEL: @round_int +func @round_int(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // CHECK: [[IDENTITY:%.+]] = "tf.Identity"(%arg0) + %0 = "tf.Round"(%arg0) : (tensor<2xi32>) -> tensor<2xi32> + // CHECK: return [[IDENTITY]] + return %0 : tensor<2xi32> +} + +// CHECK-LABEL: @round +func @round(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK-DAG: [[FLOOR:%.+]] = "tf.Floor"(%arg0) + // CHECK-DAG: [[SUB:%.+]] = "tf.Sub"(%arg0, [[FLOOR]]) + // CHECK-DAG: [[HALF:%.+]] = "tf.Const"() {value = dense<5.000000e-01> : tensor} + // CHECK-DAG: [[CMP:%.+]] = "tf.Less"([[SUB]], [[HALF]]) + // CHECK-DAG: [[ONE:%.+]] = "tf.Const"() {value = dense<1.000000e+00> : tensor} + // CHECK-DAG: [[ADD:%.+]] = "tf.Add"([[ONE]], [[FLOOR]]) + // CHECK-DAG: [[SELECT:%.+]] = "tf.Select"([[CMP]], [[FLOOR]], [[ADD]]) + %0 = "tf.Round"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + + // CHECK: return [[SELECT]] + return %0 : tensor<2xf32> +} + +// CHECK-LABEL: func @round_dynamic +func @round_dynamic(%arg0: tensor) -> tensor { + // CHECK-DAG: [[FLOOR:%.+]] = "tf.Floor"(%arg0) + // CHECK-DAG: [[SUB:%.+]] = "tf.Sub"(%arg0, [[FLOOR]]) + // CHECK-DAG: [[HALF:%.+]] = "tf.Const"() {value = dense<5.000000e-01> : tensor} + // CHECK-DAG: [[CMP:%.+]] = "tf.Less"([[SUB]], [[HALF]]) + // CHECK-DAG: [[ONE:%.+]] = "tf.Const"() {value = dense<1.000000e+00> : tensor} + // CHECK-DAG: [[ADD:%.+]] = "tf.Add"([[ONE]], [[FLOOR]]) + // CHECK-DAG: [[SELECT:%.+]] = "tf.Select"([[CMP]], [[FLOOR]], [[ADD]]) + %0 = "tf.Round"(%arg0) : (tensor) -> tensor + + // CHECK: return [[SELECT]] + return %0 : tensor +} + +// CHECK-LABEL: func @round_unranked +func @round_unranked(%arg0: 
tensor<*xf32>) -> tensor<*xf32> { + %0 = "tf.Round"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// CHECK-LABEL: func @lgamma +func @lgamma(%arg0: tensor<4xf32>) -> tensor<4xf32> { + // The lowering for lgamma is complicated, which makes it awkward to write a + // complete test for it here. Instead we test that Lgamma is at least being + // lowered here and rely on UnaryOpsTest.testFloatOps and other TensorFlow + // tests to check it is lowered correctly and with sufficient precision. + // CHECK-NOT: tf.Lgamma + %0 = "tf.Lgamma"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir index dc99d9d6343..c8a6d5489c3 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir @@ -74,6 +74,17 @@ func @ignore_embedding_ops() -> () { return } +// CHECK-LABEL: func @ignore_stack_ops +func @ignore_stack_ops(%arg0: tensor) -> () { + "tf_device.cluster"() ( { + // CHECK: "tf.StackV2" + // CHECK-NOT: _xla_outside_compilation + %0 = "tf.StackV2"(%arg0) {elem_type = f32, stack_name = "s"} : (tensor) -> tensor + tf_device.return + }) {allow_soft_placement = true, num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> () + return +} + // CHECK-LABEL: func @op_string_result func @op_string_result() -> tensor { %0 = "tf_device.cluster"() ( { @@ -127,17 +138,17 @@ func @op_string_operand_string_result(%arg0: tensor) -> tensor return %0 : tensor } -// Test that a tf.IfRegion op with a captured string operand is marked for outside compilation. +// Test that operations inside tf.IfRegion op are corrected marked for outside +// compilation. -// CHECK-LABEL: func @if_region_captured_string -func @if_region_captured_string(%arg0: tensor, %arg1: tensor) -> tensor { +// CHECK-LABEL: func @ops_inside_tf_if_outside_compiled +func @ops_inside_tf_if_outside_compiled(%arg0: tensor, %arg1: tensor) -> tensor { %0 = "tf_device.cluster"() ( { - // CHECK: "tf.Const"() {value = dense<1> : tensor} - // CHECK-NOT: _xla_outside_compilation - // CHECK: "tf.IfRegion" - // CHECK: "tf.StringToNumber" - // CHECK-NOT: _xla_outside_compilation - // CHECK: _xla_outside_compilation = "auto1", is_stateless = true + // CHECK: "tf.Const"() {value = dense<1> : tensor} + // CHECK-NOT: _xla_outside_compilation + // CHECK: "tf.IfRegion" + // CHECK: "tf.StringToNumber" + // CHECK-SAME: _xla_outside_compilation %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor %2 = "tf.IfRegion"(%arg0) ( { %3 = "tf.StringToNumber"(%arg1) {out_type = f32} : (tensor) -> tensor @@ -152,7 +163,8 @@ func @if_region_captured_string(%arg0: tensor, %arg1: tensor) -> return %0 : tensor } -// Test that ops with string results/operands inside a tf.IfRegion branch are marked for outside compilation. +// Test that ops with string results/operands inside a tf.IfRegion branch are +// marked for outside compilation. // CHECK-LABEL: func @if_region_string_op func @if_region_string_op(%arg0: tensor, %arg1: tensor) -> tensor { @@ -180,7 +192,8 @@ func @if_region_string_op(%arg0: tensor, %arg1: tensor) -> tensor } -// Test that ops with string results/operands inside a nested tf.IfRegion branch are marked for outside compilation. 
+// Test that ops with string results/operands inside a nested tf.IfRegion branch +// are marked for outside compilation. // CHECK-LABEL: func @nested_if_region_string_op func @nested_if_region_string_op(%arg0: tensor, %arg1: tensor) -> tensor { @@ -220,16 +233,17 @@ func @nested_if_region_string_op(%arg0: tensor, %arg1: tensor) -> ten return %0 : tensor } -// Test that a tf.WhileRegion op with a captured string operand is marked for outside compilation. +// Test that ops inside tf.WhileRegion op are correct marked for outside +// compilation. -// CHECK-LABEL: func @while_region_captured_string -func @while_region_captured_string(%arg0: tensor, %arg1: tensor) -> tensor { +// CHECK-LABEL: func @ops_inside_while_outside_compiled +func @ops_inside_while_outside_compiled(%arg0: tensor, %arg1: tensor) -> tensor { %0 = "tf_device.cluster"() ( { - // CHECK: "tf.Const"() {value = dense<1.000000e+00> : tensor} + // CHECK: "tf.Const"() {value = dense<1.000000e+00> : tensor} // CHECK-NOT: _xla_outside_compilation - // CHECK: "tf.WhileRegion" - // CHECK: "tf.StringToNumber" - // CHECK: _xla_outside_compilation = "auto1", is_stateless = true + // CHECK: "tf.WhileRegion" + // CHECK: "tf.StringToNumber" + // CHECK-SAME: _xla_outside_compilation %1 = "tf.Const"() {value = dense<1.0> : tensor} : () -> tensor %2:2 = "tf.WhileRegion"(%1, %arg0) ( { ^bb0(%carg0: tensor, %carg1: tensor): @@ -284,3 +298,31 @@ func @while_region_unsupported_op(%arg0: tensor, %arg1: tensor) }) {allow_soft_placement = true, num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor return %0 : tensor } + +// Checks that ops with inputs and outputs with string subtypes are marked +// for outside compilation. + +// CHECK-LABEL: func @check_op_with_variant_string_subtypes_outside_compiled +func @check_op_with_variant_string_subtypes_outside_compiled(%arg0: tensor, %arg1: tensor, %arg2: tensor<3xi32>) -> () { + "tf_device.cluster"() ( { + // CHECK: "tf.TensorListReserve" + // CHECK-SAME: _xla_outside_compilation + // CHECK: "tf.TensorListGetItem" + // CHECK-SAME: _xla_outside_compilation + %0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor, tensor) -> tensor>> + "tf.TensorListGetItem"(%0, %arg1, %arg2) : (tensor>>, tensor, tensor<3xi32>) -> tensor<24x24x64xui8> + tf_device.return + }) {allow_soft_placement = true, num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> () + return +} +// CHECK-LABEL: func @check_op_with_resource_string_subtypes_outside_compiled +func @check_op_with_resource_string_subtypes_outside_compiled(%arg0: tensor, %arg1: tensor, %arg2: tensor>>) -> () { + "tf_device.cluster"() ( { + // CHECK: "tf.VarHandleOp" + // CHECK-SAME: _xla_outside_compilation + "tf.VarHandleOp"() {allowed_devices = [], container = "", device = "", shared_name = ""} : () -> tensor>> + tf_device.return + }) {allow_soft_placement = true, num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> () + return +} + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/BUILD index cbdf5d96d0e..b98ed445e86 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow:tensorflow.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") licenses(["notice"]) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/parallel_execute_to_islands.mlir 
b/tensorflow/compiler/mlir/tensorflow/tests/parallel_execute_to_islands.mlir index 52dc06cd393..03cac7dbd33 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/parallel_execute_to_islands.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/parallel_execute_to_islands.mlir @@ -1,13 +1,12 @@ // RUN: tf-opt %s -tf-parallel-execute-to-islands | FILECHECK_OPTS="" FileCheck %s -// CHECK-LABEL: func @check_regions_to_islands -func @check_regions_to_islands() { +// CHECK-LABEL: func @testEmptyRegions +func @testEmptyRegions() { tf_executor.graph { tf_executor.island() { "tf_device.parallel_execute"() ({ tf_device.return - }, - { + }, { tf_device.return }) {} : () -> () tf_executor.yield @@ -17,210 +16,133 @@ func @check_regions_to_islands() { return } -// CHECK: %[[ISLAND_1_CTL:[a-z_0-9]*]] = tf_executor.island { +// CHECK: [[ISLAND_0_CTRL:%.+]] = tf_executor.island { // CHECK: tf_executor.yield -// CHECK: %[[ISLAND_2_CTL:[a-z_0-9]*]] = tf_executor.island { +// CHECK: [[ISLAND_1_CTRL:%.+]] = tf_executor.island { // CHECK: tf_executor.yield -// CHECK: %{{.*}} = tf_executor.island(%[[ISLAND_1_CTL]], %[[ISLAND_2_CTL]]) { -// CHECK-NEXT: tf_executor.yield +// CHECK: tf_executor.fetch [[ISLAND_0_CTRL]], [[ISLAND_1_CTRL]] : -// CHECK-LABEL: func @check_regions_to_islands_with_inputs -// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) -func @check_regions_to_islands_with_inputs(%arg0 : tensor) { - tf_executor.graph { +// CHECK-LABEL: func @testDataOperandsAndResults +// CHECK-SAME: ([[ARG_0:%.+]]: tensor) +func @testDataOperandsAndResults(%arg0 : tensor) { + %0:2 = tf_executor.graph { %1:2 = tf_executor.island { %2 = "tf.opA"(%arg0) : (tensor) -> tensor tf_executor.yield %2 : tensor } - tf_executor.island() { - "tf_device.parallel_execute"() ({ - %3 = "tf.opB"(%1#0) : (tensor) -> tensor - tf_device.return %3 : tensor - }, - { + %3:3 = tf_executor.island() { + %4:2 = "tf_device.parallel_execute"() ({ + %5 = "tf.opB"(%1#0) : (tensor) -> tensor + tf_device.return %5 : tensor + }, { %5 = "tf.opC"(%1#0) : (tensor) -> tensor tf_device.return %5 : tensor }) {} : () -> (tensor, tensor) - tf_executor.yield + tf_executor.yield %4#0, %4#1 : tensor, tensor } - tf_executor.fetch + tf_executor.fetch %3#0, %3#1 : tensor, tensor } return } -// CHECK: %[[INPUT_A:[a-z_0-9]*]], %{{.*}} = tf_executor.island { -// CHECK-NEXT: %[[OP_A_OUTPUT:[a-z_0-9]*]] = "tf.opA"(%[[ARG_0]]) : (tensor) -> tensor -// CHECK-NEXT: tf_executor.yield %[[OP_A_OUTPUT]] : tensor -// CHECK: %[[INPUT_0:[a-z_0-9]*]], %[[INPUT_CONTROL:[a-z_0-9]*]] = tf_executor.island { -// CHECK-NEXT: tf_executor.yield %[[INPUT_A]] : tensor -// CHECK: %[[ISLAND_1_OUTPUT:[a-z_0-9]*]], %[[ISLAND_1_CTL:[a-z_0-9]*]] = tf_executor.island { -// CHECK-NEXT: %[[OP_B_OUTPUT:[a-z_0-9]*]] = "tf.opB"(%[[INPUT_0]]) : (tensor) -> tensor -// CHECK: tf_executor.yield %[[OP_B_OUTPUT]] : tensor -// CHECK: %[[ISLAND_2_OUTPUT:[a-z_0-9]*]], %[[ISLAND_2_CTL:[a-z_0-9]*]] = tf_executor.island { -// CHECK-NEXT: %[[OP_C_OUTPUT:[a-z_0-9]*]] = "tf.opC"(%outputs_0) : (tensor) -> tensor -// CHECK: tf_executor.yield %[[OP_C_OUTPUT]] : tensor -// CHECK: %{{.*}} = tf_executor.island(%[[ISLAND_1_CTL]], %[[ISLAND_2_CTL]]) { -// CHECK-NEXT: tf_executor.yield +// CHECK: [[INPUT_A:%.+]], {{%.+}} = tf_executor.island { +// CHECK-NEXT: [[OP_A_OUTPUT:%.+]] = "tf.opA"([[ARG_0]]) +// CHECK-NEXT: tf_executor.yield [[OP_A_OUTPUT]] : +// CHECK: [[ISLAND_0_OUTPUT:%.+]], {{%.+}} = tf_executor.island { +// CHECK-NEXT: [[OP_B_OUTPUT:%.+]] = "tf.opB"([[INPUT_A]]) +// CHECK: tf_executor.yield [[OP_B_OUTPUT]] : +// 
CHECK: [[ISLAND_1_OUTPUT:%.+]], {{%.+}} = tf_executor.island { +// CHECK-NEXT: [[OP_C_OUTPUT:%.+]] = "tf.opC"([[INPUT_A]]) +// CHECK: tf_executor.yield [[OP_C_OUTPUT]] : +// CHECK: tf_executor.fetch [[ISLAND_0_OUTPUT]], [[ISLAND_1_OUTPUT]] : -// CHECK-LABEL: func @check_input_sink_island_forwards_control_inputs -// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) -func @check_input_sink_island_forwards_control_inputs(%arg0 : tensor) { - tf_executor.graph { - %1:2 = tf_executor.island { - %2 = "tf.opA"(%arg0) : (tensor) -> tensor - tf_executor.yield %2 : tensor - } - %7 = tf_executor.ControlTrigger {} - %8 = tf_executor.ControlTrigger {} - tf_executor.island(%7, %8) { - "tf_device.parallel_execute"() ({ - %3 = "tf.opB"(%1#0) : (tensor) -> tensor - tf_device.return %3 : tensor - }, - { - %5 = "tf.opC"() : () -> tensor - tf_device.return %5 : tensor - }) {} : () -> (tensor, tensor) +// CHECK-LABEL: func @testControlOperands +func @testControlOperands() { + %0:2 = tf_executor.graph { + %1 = tf_executor.island { tf_executor.yield } - tf_executor.fetch - } - return -} - -// CHECK: %[[INPUT_A:[a-z_0-9]*]], %{{.*}} = tf_executor.island { -// CHECK-NEXT: %[[OP_A_OUTPUT:[a-z_0-9]*]] = "tf.opA"(%[[ARG_0]]) : (tensor) -> tensor -// CHECK-NEXT: tf_executor.yield %[[OP_A_OUTPUT]] : tensor -// CHECK: %[[CT_0:[0-9]*]] = tf_executor.ControlTrigger -// CHECK: %[[CT_1:[0-9]*]] = tf_executor.ControlTrigger -// CHECK: %[[INPUT_0:[a-z_0-9]*]], %[[INPUT_CONTROL:[a-z_0-9]*]] = tf_executor.island(%[[CT_0]], %[[CT_1]]) { -// CHECK-NEXT: tf_executor.yield %[[INPUT_A]] : tensor -// CHECK: %[[ISLAND_1_OUTPUT:[a-z_0-9]*]], %[[ISLAND_1_CTL:[a-z_0-9]*]] = tf_executor.island { -// CHECK-NEXT: %[[OP_B_OUTPUT:[a-z_0-9]*]] = "tf.opB"(%[[INPUT_0]]) : (tensor) -> tensor -// CHECK: tf_executor.yield %[[OP_B_OUTPUT]] : tensor -// CHECK: %[[ISLAND_2_OUTPUT:[a-z_0-9]*]], %[[ISLAND_2_CTL:[a-z_0-9]*]] = tf_executor.island(%[[INPUT_CONTROL]]) { -// CHECK-NEXT: %[[OP_C_OUTPUT:[a-z_0-9]*]] = "tf.opC"() : () -> tensor -// CHECK: tf_executor.yield %[[OP_C_OUTPUT]] : tensor -// CHECK: %{{.*}} = tf_executor.island(%[[ISLAND_1_CTL]], %[[ISLAND_2_CTL]]) { -// CHECK-NEXT: tf_executor.yield - - -// CHECK-LABEL: func @check_control_dep_added_when_region_does_not_have_inputs -// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) -func @check_control_dep_added_when_region_does_not_have_inputs(%arg0 : tensor) { - tf_executor.graph { - %1:2 = tf_executor.island { - %2 = "tf.opA"(%arg0) : (tensor) -> tensor - tf_executor.yield %2 : tensor - } - %7:3 = tf_executor.island() { - %8:2 = "tf_device.parallel_execute"() ( - { - %3 = "tf.opB"() : () -> tensor - tf_device.return %3 : tensor - }, - { - %5 = "tf.opC"(%1#0) : (tensor) -> tensor - tf_device.return %5 : tensor - } - ) {} : () -> (tensor, tensor) - - tf_executor.yield %8#0, %8#1 : tensor, tensor - } - - tf_executor.island { - "tf.opD"(%7#0, %7#1) : (tensor, tensor) -> () - tf_executor.yield - } - tf_executor.fetch - } - return -} - -// CHECK: %[[INPUT_A:[a-z_0-9]*]], %{{.*}} = tf_executor.island { -// CHECK-NEXT: %[[OP_A_OUTPUT:[a-z_0-9]*]] = "tf.opA"(%[[ARG_0]]) : (tensor) -> tensor -// CHECK-NEXT: tf_executor.yield %[[OP_A_OUTPUT]] : tensor -// CHECK: %[[INPUT_0:[a-z_0-9]*]], %[[INPUT_CTL:[a-z_0-9]*]] = tf_executor.island { -// CHECK-NEXT: tf_executor.yield %[[INPUT_A]] : tensor -// CHECK: %[[ISLAND_1_OUTPUT:[a-z_0-9]*]], %{{.*}} = tf_executor.island(%[[INPUT_CTL]]) { -// CHECK-NEXT: %[[OP_B_OUTPUT:[a-z_0-9]*]] = "tf.opB"() : () -> tensor -// CHECK: tf_executor.yield %[[OP_B_OUTPUT]] : tensor -// CHECK: 
%[[ISLAND_2_OUTPUT:[a-z_0-9]*]], %{{.*}} = tf_executor.island { -// CHECK-NEXT: %[[OP_C_OUTPUT:[a-z_0-9]*]] = "tf.opC"(%outputs_0) : (tensor) -> tensor -// CHECK: tf_executor.yield %[[OP_C_OUTPUT]] : tensor -// CHECK: %{{.*}} = tf_executor.island { -// CHECK-NEXT: tf_executor.yield %[[ISLAND_1_OUTPUT]], %[[ISLAND_2_OUTPUT]] - - -// CHECK-LABEL: func @check_output_barrier_correctly_forwards_outputs -func @check_output_barrier_correctly_forwards_outputs(%arg0 : tensor) -> tensor { - %0 = tf_executor.graph { - %1:2 = tf_executor.island { - %2 = "tf.opA"(%arg0) : (tensor) -> tensor - tf_executor.yield %2 : tensor - } - %8:3 = tf_executor.island() { - %7:2 = "tf_device.parallel_execute"() ({ - %3 = "tf.opB"() : () -> tensor - tf_device.return %3 : tensor - }, - { - %5 = "tf.opC"(%1#0) : (tensor) -> tensor - tf_device.return %5 : tensor - }) {} : () -> (tensor, tensor) - tf_executor.yield %7#0, %7#1 : tensor, tensor - } - tf_executor.fetch %8#0 : tensor - } - return %0 : tensor -} - -// CHECK: %[[INPUT_A:[a-z_0-9]*]], %{{.*}} = tf_executor.island { -// CHECK-NEXT: %[[OP_A_OUTPUT:[a-z_0-9]*]] = "tf.opA"(%[[ARG_0]]) : (tensor) -> tensor -// CHECK-NEXT: tf_executor.yield %[[OP_A_OUTPUT]] : tensor -// CHECK: %[[INPUT_0:[a-z_0-9]*]], %[[INPUT_CTL:[a-z_0-9]*]] = tf_executor.island { -// CHECK-NEXT: tf_executor.yield %[[INPUT_A]] : tensor -// CHECK: %[[ISLAND_1_OUTPUT:[a-z_0-9]*]], %{{.*}} = tf_executor.island(%[[INPUT_CTL]]) { -// CHECK-NEXT: %[[OP_B_OUTPUT:[a-z_0-9]*]] = "tf.opB"() : () -> tensor -// CHECK: tf_executor.yield %[[OP_B_OUTPUT]] : tensor -// CHECK: %[[ISLAND_2_OUTPUT:[a-z_0-9]*]], %{{.*}} = tf_executor.island { -// CHECK-NEXT: %[[OP_C_OUTPUT:[a-z_0-9]*]] = "tf.opC"(%[[INPUT_0]]) : (tensor) -> tensor -// CHECK: tf_executor.yield %[[OP_C_OUTPUT]] : tensor -// CHECK: %[[OUTPUT_SINK_OUTPUT:[a-z_0-9]*]]:2, %[[OUTPUT_SINK_CTL:[a-z_0-9]*]] = tf_executor.island { -// CHECK-NEXT: tf_executor.yield %[[ISLAND_1_OUTPUT]], %[[ISLAND_2_OUTPUT]] : tensor, tensor - -// CHECK-LABEL: func @check_parallel_execute_using_args -// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) -func @check_parallel_execute_using_args(%arg0 : tensor) { - tf_executor.graph { - %1:2 = tf_executor.island { - %2 = "tf.opA"(%arg0) : (tensor) -> tensor - tf_executor.yield %2 : tensor - } - %2:2 = tf_executor.island { - %3 = "tf.opB"(%arg0) : (tensor) -> tensor - tf_executor.yield %3 : tensor - } - tf_executor.island() { - "tf_device.parallel_execute"() ({ - %4 = "tf.opC"(%arg0, %1#0) : (tensor, tensor) -> tensor + %2:3 = tf_executor.island(%1) { + %3:2 = "tf_device.parallel_execute"() ({ + %4 = "tf.opA"() : () -> tensor tf_device.return %4 : tensor - }, - { - %5 = "tf.opD"(%arg0, %2#0) : (tensor, tensor) -> tensor - tf_device.return %5 : tensor + }, { + %4 = "tf.opB"() : () -> tensor + tf_device.return %4 : tensor }) {} : () -> (tensor, tensor) - tf_executor.yield + tf_executor.yield %3#0, %3#1 : tensor, tensor } - tf_executor.fetch + tf_executor.fetch %2#0, %2#1 : tensor, tensor } return } -// Verify that args are directly accessed in newly created island without alias -// through entry barrier. 
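+// The lines that follow verify that the control operand of the original
+// island is forwarded directly to each per-region island (both islands take
+// [[INPUT_CTRL]] as an operand), rather than being routed through a separate
+// entry-barrier island as in the replaced test above.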
+// CHECK: [[INPUT_CTRL:%.+]] = tf_executor.island { +// CHECK: [[ISLAND_0_OUTPUT:%.+]], {{%.+}} = tf_executor.island([[INPUT_CTRL]]) { +// CHECK-NEXT: [[OP_A_OUTPUT:%.+]] = "tf.opA"() +// CHECK: tf_executor.yield [[OP_A_OUTPUT]] : +// CHECK: [[ISLAND_1_OUTPUT:%.+]], {{%.+}} = tf_executor.island([[INPUT_CTRL]]) { +// CHECK-NEXT: [[OP_B_OUTPUT:%.+]] = "tf.opB"() +// CHECK: tf_executor.yield [[OP_B_OUTPUT]] : +// CHECK: tf_executor.fetch [[ISLAND_0_OUTPUT]], [[ISLAND_1_OUTPUT]] : -// CHECK: "tf.opC"(%[[ARG_0]] -// CHECK: "tf.opD"(%[[ARG_0]] + +// CHECK-LABEL: func @testControlResults +func @testControlResults() { + tf_executor.graph { + %0:3 = tf_executor.island { + %1:2 = "tf_device.parallel_execute"() ({ + %2 = "tf.opA"() : () -> tensor + tf_device.return %2 : tensor + }, { + %2 = "tf.opB"() : () -> tensor + tf_device.return %2 : tensor + }) {} : () -> (tensor, tensor) + tf_executor.yield %1#0, %1#1 : tensor, tensor + } + %3 = tf_executor.island(%0#2) { + tf_executor.yield + } + tf_executor.fetch %3 : !tf_executor.control + } + return +} + +// CHECK: {{%.+}}, [[ISLAND_0_CTRL:%.+]] = tf_executor.island { +// CHECK-NEXT: [[OP_A_OUTPUT:%.+]] = "tf.opA"() +// CHECK: tf_executor.yield [[OP_A_OUTPUT]] : +// CHECK: {{%.+}}, [[ISLAND_1_CTRL:%.+]] = tf_executor.island { +// CHECK-NEXT: [[OP_B_OUTPUT:%.+]] = "tf.opB"() +// CHECK: tf_executor.yield [[OP_B_OUTPUT]] : +// CHECK: [[OUTPUT_CTRL:%.+]] = tf_executor.island([[ISLAND_0_CTRL]], [[ISLAND_1_CTRL]]) { +// CHECK: [[FETCH_ISLAND:%.+]] = tf_executor.island([[OUTPUT_CTRL]]) { +// CHECK: tf_executor.fetch [[FETCH_ISLAND]] : !tf_executor.control + + +// CHECK-LABEL: func @testSomeRegionNoUsers +func @testSomeRegionNoUsers() { + %0 = tf_executor.graph { + %1:3 = tf_executor.island { + %2:2 = "tf_device.parallel_execute"() ({ + %3 = "tf.opA"() : () -> tensor + tf_device.return %3 : tensor + }, { + %3 = "tf.opB"() : () -> tensor + tf_device.return %3 : tensor + }) {} : () -> (tensor, tensor) + tf_executor.yield %2#0, %2#1 : tensor, tensor + } + tf_executor.fetch %1#0 : tensor + } + return +} + +// CHECK: [[ISLAND_0_OUTPUT:%.+]], {{%.+}} = tf_executor.island { +// CHECK-NEXT: [[OP_A_OUTPUT:%.+]] = "tf.opA"() +// CHECK: tf_executor.yield [[OP_A_OUTPUT]] : +// CHECK: {{%.+}}, [[ISLAND_1_CTRL:%.+]] = tf_executor.island { +// CHECK-NEXT: [[OP_B_OUTPUT:%.+]] = "tf.opB"() +// CHECK: tf_executor.yield [[OP_B_OUTPUT]] : +// CHECK: tf_executor.fetch [[ISLAND_0_OUTPUT]], [[ISLAND_1_CTRL]] : diff --git a/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir b/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir index 3e6d4f37bac..0813ee8db90 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir @@ -258,6 +258,19 @@ func @main(%arg0: tensor>>, %arg1: tensor) { // ----- +// Tests removal of dead local variables. + +// CHECK-LABEL: func @main +func @main(%arg0: tensor<2xf32>) { + // CHECK-NOT: tf.MlirLocalVarOp + // CHECK-NOT: tf.AssignVariableOp + %0 = "tf.MlirLocalVarOp"() : () -> tensor>> + "tf.AssignVariableOp"(%0, %arg0) : (tensor>>, tensor<2xf32>) -> () + return +} + +// ----- + // Tests first read of one resource is used as a value to write to another // resource. 
@@ -272,6 +285,26 @@ func @main(%arg0: tensor>>, %arg1: tensor) -> tensor<2xf32> { + // CHECK-NOT: tf.MlirLocalVarOp + // CHECK-NOT: tf.AssignVariableOp + %0 = "tf.MlirLocalVarOp"() : () -> tensor>> + %1 = "tf._SomeOp"() : () -> tensor<2xf32> + "tf.AssignVariableOp"(%0, %1) : (tensor>>, tensor<2xf32>) -> () + %2 = "tf.PartitionedCall"(%0) {config = "", config_proto = "", executor_type = "", f = @callee} : (tensor>>) -> tensor<2xf32> + return %2 : tensor<2xf32> +} +func @callee(%arg0: tensor>>) -> tensor<2xf32> attributes {sym_visibility = "private"} { + %0 = "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + + // ----- // Tests main function with multiple blocks. diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource-alias-analysis-test.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource-alias-analysis-test.mlir index da0a2df9e6a..e857831e6be 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/resource-alias-analysis-test.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/resource-alias-analysis-test.mlir @@ -106,6 +106,52 @@ func @if_else(%arg0: !tf_res, %arg1: !tf_res) -> (!tf_res, !tf_res, !tf_res) { return %id0, %id0, %arg1 : !tf_res, !tf_res, !tf_res } +// ----- +// Test aliasing through CaseOp + +!tf_res = type tensor<*x!tf.resource>> + +// CHECK-LABEL: func @case_op_aliasing +// expected-remark@below {{Region #0, Arg #0, ID 4 : 1, 4}} +// expected-remark@below {{Region #0, Arg #1, ID 5 : 1, 2, 3, 5}} +func @case_op_aliasing(%arg0: !tf_res, %arg1: !tf_res) { + // expected-remark@below {{Result #0, ID 0 : 0}} + %vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> !tf_res + %read0 = "tf.ReadVariableOp"(%vh0) : (!tf_res) -> tensor + // expected-remark@below {{Result #0, ID 1 : Unknown}} + // expected-remark@below {{Result #1, ID 2 : 1, 2, 3, 5}} + // expected-remark@below {{Result #2, ID 3 : 0, 1, 2, 3, 5}} + %if:3 = "tf.Case"(%read0, %arg1, %vh0) { + branches = [@case_branch0, @case_branch1, @case_branch2], + is_stateless = true + } : (tensor, !tf_res, !tf_res) -> (!tf_res, !tf_res, !tf_res) + return +} + +// expected-remark@below {{Region #0, Arg #0, ID 2 : 0, 1, 2}} +// expected-remark@below {{Region #0, Arg #1, ID 3 : 0, 3}} +func @case_branch0(%arg0: !tf_res, %arg1: !tf_res) -> (!tf_res, !tf_res, !tf_res) { + // expected-remark@below {{Result #0, ID 0 : Unknown}} + %u0 = "tf._UnknownSideEffectingOp_"() : () -> !tf_res + // expected-remark@below {{Result #0, ID 1 : 0, 1, 2}} + %id0 = "tf.Identity"(%arg0) : (!tf_res) -> !tf_res + return %u0, %id0, %id0 : !tf_res, !tf_res, !tf_res +} + +// expected-remark@below {{Region #0, Arg #0, ID 1 : 0, 1}} +// expected-remark@below {{Region #0, Arg #1, ID 2 : 2}} +func @case_branch1(%arg0: !tf_res, %arg1: !tf_res) -> (!tf_res, !tf_res, !tf_res) { + // expected-remark@below {{Result #0, ID 0 : 0, 1}} + %id0 = "tf.Identity"(%arg0) : (!tf_res) -> !tf_res + return %id0, %id0, %arg1 : !tf_res, !tf_res, !tf_res +} + +// expected-remark@below {{Region #0, Arg #0, ID 0 : 0}} +// expected-remark@below {{Region #0, Arg #1, ID 1 : 1}} +func @case_branch2(%arg0: !tf_res, %arg1: !tf_res) -> (!tf_res, !tf_res, !tf_res) { + return %arg0, %arg0, %arg1 : !tf_res, !tf_res, !tf_res +} + // ----- // Test aliasing through WhileOp !tf_res = type tensor<*x!tf.resource>> @@ -199,6 +245,37 @@ func @if_region_aliasing(%arg0: !tf_res, %arg1: !tf_res) { return } +// ----- +// Test aliasing through CaseRegion + +!tf_res = type tensor<*x!tf.resource>> + +// CHECK-LABEL: func @case_region_aliasing 
+// expected-remark@below {{Region #0, Arg #0, ID 7 : 1, 4, 6, 7}} +// expected-remark@below {{Region #0, Arg #1, ID 8 : 1, 2, 4, 5, 6, 8}} +func @case_region_aliasing(%arg0: !tf_res, %arg1: !tf_res) { + // expected-remark@below {{Result #0, ID 0 : 0, 1, 3, 4, 5}} + %vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> !tf_res + %read0 = "tf.ReadVariableOp"(%vh0) : (!tf_res) -> tensor + // expected-remark@below {{Result #0, ID 4 : Unknown}} + // expected-remark@below {{Result #1, ID 5 : 0, 1, 2, 3, 4, 5, 6, 8}} + // expected-remark@below {{Result #2, ID 6 : 1, 2, 4, 5, 6, 7, 8}} + %if:3 = "tf.CaseRegion"(%read0) ({ + // expected-remark@below {{Result #0, ID 1 : Unknown}} + %u0 = "tf._UnknownSideEffectingOp_"() : () -> !tf_res + // expected-remark@below {{Result #0, ID 2 : 1, 2, 4, 5, 6, 8}} + %id0 = "tf.Identity"(%arg1) : (!tf_res) -> !tf_res + "tf.Yield"(%u0, %id0, %id0) : (!tf_res, !tf_res, !tf_res) -> () + }, { + // expected-remark@below {{Result #0, ID 3 : 0, 1, 3, 4, 5}} + %id0 = "tf.Identity"(%vh0) : (!tf_res) -> !tf_res + "tf.Yield"(%id0, %id0, %arg0) : (!tf_res, !tf_res, !tf_res) -> () + }, { + "tf.Yield"(%vh0, %arg1, %arg1) : (!tf_res, !tf_res, !tf_res) -> () + }) {is_stateless = true} : (tensor) -> (!tf_res, !tf_res, !tf_res) + return +} + // ----- // Test aliasing through WhileRegion !tf_res = type tensor<*x!tf.resource>> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir index 213ca402f56..79b90b67956 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir @@ -8,7 +8,7 @@ func @only_resource_load() -> tensor<*xi32> { // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp" %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> - // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = i32} + // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) // CHECK: "tf_device.cluster" // CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]]) // CHECK: tf_device.return %[[COMPUTE_RES]] @@ -39,7 +39,7 @@ func @only_resource_store() -> tensor<*xi32> { // CHECK: tf_device.return %[[COMPUTE_RES]], %[[COMPUTE_RES]] // CHECK: {cluster_attr = "cluster_attr"} // CHECK-SAME: () -> (tensor<*xi32>, tensor<*xi32>) - // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[CLUSTER_RES]]#1) {dtype = i32} + // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[CLUSTER_RES]]#1) %1 = "tf_device.cluster"() ( { %2 = "tf.SomeComputation"() : () -> (tensor<*xi32>) @@ -61,13 +61,13 @@ func @same_resource_load_and_store() -> tensor<*xi32> { // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp" %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> - // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = i32} + // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) // CHECK: %[[CLUSTER_RES:[0-9]*]]:2 = "tf_device.cluster" // CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]]) // CHECK: tf_device.return %[[COMPUTE_RES]], %[[COMPUTE_RES]] // CHECK: {cluster_attr = "cluster_attr"} // CHECK-SAME: () -> (tensor<*xi32>, tensor<*xi32>) - // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[CLUSTER_RES]]#1) {dtype = i32} + // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[CLUSTER_RES]]#1) %1 = "tf_device.cluster"() ( { %2 
= "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32> @@ -308,6 +308,7 @@ func @while_cond1(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!t func @cluster_with_loop() -> () { %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> "tf_device.cluster"() ( { + // expected-error@+1 {{result #0 not tied to function argument for branch @while_body}} %1 = "tf.While"(%0) { body = @while_body, cond = @while_cond, device = "", is_stateless = false} : (tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) @@ -317,7 +318,6 @@ func @cluster_with_loop() -> () { } func @while_body(%arg0: tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) { %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource>> - // expected-error @+1 {{resource used in while loop is only supported when the resource input and output alias each other in the loop body}} return %0 : tensor<*x!tf.resource>> } func @while_cond(%arg0: tensor<*x!tf.resource>>) -> tensor { @@ -332,6 +332,7 @@ func @while_cond(%arg0: tensor<*x!tf.resource>>) -> tensor { func @cluster_with_loop() -> () { %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> "tf_device.cluster"() ( { + // expected-error@+1 {{found resource write in loop condition.}} %1 = "tf.While"(%0) { body = @while_body, cond = @while_cond, device = "", is_stateless = false} : (tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) @@ -347,7 +348,6 @@ func @while_body(%arg0: tensor<*x!tf.resource>>) -> (tensor<*x!tf.re func @while_cond(%arg0: tensor<*x!tf.resource>>) -> tensor { %read = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>>) -> tensor %constant = "tf.Const"() {value = dense<0.0> : tensor} : () -> tensor - // expected-error @+1 {{found resource write in loop condition.}} "tf.AssignVariableOp"(%arg0, %constant) : (tensor<*x!tf.resource>>, tensor) -> () return %read : tensor } @@ -527,7 +527,7 @@ func @cluster_with_if(%arg0: tensor) -> tensor<4xf32> { %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource>> %2 = "tf_device.cluster"() ( { - // expected-error @+1 {{unsupported output: resource does not alias a single input}} + // expected-error @+1 {{result #0 is not tied to the same argument across all branches}} %3 = "tf.If"(%arg0, %0, %1) {then_branch = @if_then, else_branch = @if_else, is_stateless = false} : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) @@ -554,7 +554,7 @@ func @cluster_with_if(%arg0: tensor) -> tensor<4xf32> { %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource>> %2 = "tf_device.cluster"() ( { - // expected-error @+1 {{unsupported output: resource does not alias input}} + // expected-error @+1 {{result #0 not tied to function argument for branch @if_then}} %3 = "tf.If"(%arg0, %0, %1) {then_branch = @if_then, else_branch = @if_else, is_stateless = false} : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) @@ -713,3 +713,507 @@ func @callee(%arg0: tensor<*x!tf.resource>>) -> tensor { // CHECK: func @callee_resource_lifted(%[[A0:.*]]: tensor) -> tensor // CHECK-NEXT: return %[[A0]] + +// ----- + +// Test that the pass can lift resources out of IfRegion +// CHECK: func @cluster_with_ifregion(%[[ARG0:.*]]: tensor) -> tensor<4xf32> +func 
@cluster_with_ifregion(%arg0: tensor) -> tensor<4xf32> { + // CHECK: %[[VH0:.*]] = "tf.VarHandleOp"() + // CHECK: %[[VH1:.*]] = "tf.VarHandleOp"() + // CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[VH1]]) + // CHECK: %[[CLUSTER:.*]]:2 = "tf_device.cluster"() + // CHECK: %[[IF:.*]]:2 = "tf.IfRegion"(%[[ARG0]]) + %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> + %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource>> + %2 = "tf_device.cluster"() ( { + %3:2 = "tf.IfRegion"(%arg0) ({ + // CHECK-NEXT: %[[CONST:.*]] = "tf.Const"() + // CHECK-NEXT: "tf.Yield"(%[[CONST]], %[[CONST]]) + %constant = "tf.Const"() {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> + "tf.AssignVariableOp"(%0, %constant) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + "tf.Yield"(%0, %constant) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + }, { + // CHECK: "tf.Yield"(%[[READ1]], %[[READ1]]) + %id = "tf.Identity"(%1) : (tensor<*x!tf.resource>>) -> tensor<*x!tf.resource>> + %read = "tf.ReadVariableOp"(%id) : (tensor<*x!tf.resource>>) -> tensor<4xf32> + "tf.AssignVariableOp"(%0, %read) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + "tf.Yield"(%0, %read) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + }) {is_stateless = false} : (tensor) -> (tensor<*x!tf.resource>>, tensor<4xf32>) + // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[IF]]#1, %[[IF]]#0) + // CHECK-NEXT: tf_device.return %[[ADD]], %[[IF]]#1 + %4 = "tf.ReadVariableOp"(%3#0) : (tensor<*x!tf.resource>>) -> tensor<4xf32> + %5 = "tf.AddV2"(%4, %3#1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + tf_device.return %5 : tensor<4xf32> + }) {cluster_attr = "cluster_attr"} : () -> tensor<4xf32> + // CHECK: {cluster_attr = "cluster_attr"} : () -> (tensor<4xf32>, tensor<4xf32>) + // CHECK: "tf.AssignVariableOp"(%[[VH0]], %[[CLUSTER]]#1) + // CHECK: return %[[CLUSTER]]#0 + return %2 : tensor<4xf32> +} + +// Test that the pass can lift resources out of CaseRegion +// CHECK: func @cluster_with_caseregion(%[[ARG0:.*]]: tensor) -> tensor<4xf32> +func @cluster_with_caseregion(%arg0: tensor) -> tensor<4xf32> { + // CHECK: %[[VH0:.*]] = "tf.VarHandleOp"() + // CHECK: %[[VH1:.*]] = "tf.VarHandleOp"() + // CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[VH1]]) + // CHECK: %[[CLUSTER:.*]]:2 = "tf_device.cluster"() + // CHECK: %[[CASE:.*]]:2 = "tf.CaseRegion"(%[[ARG0]]) + %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> + %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource>> + %2 = "tf_device.cluster"() ( { + %3:2 = "tf.CaseRegion"(%arg0) ({ + // CHECK-NEXT: %[[CONST:.*]] = "tf.Const"() + // CHECK-NEXT: "tf.Yield"(%[[CONST]], %[[CONST]]) + %constant = "tf.Const"() {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> + "tf.AssignVariableOp"(%0, %constant) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + "tf.Yield"(%0, %constant) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + }, { + // CHECK: "tf.Yield"(%[[READ1]], %[[READ1]]) + %id = "tf.Identity"(%1) : (tensor<*x!tf.resource>>) -> tensor<*x!tf.resource>> + %read = "tf.ReadVariableOp"(%id) : (tensor<*x!tf.resource>>) -> tensor<4xf32> + "tf.AssignVariableOp"(%0, %read) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + "tf.Yield"(%0, %read) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + }, { + // CHECK: %[[CONST1:.*]] = "tf.Const" + // CHECK: %[[SUB:.*]] = "tf.Sub"(%[[READ1]], %[[CONST1]]) + // CHECK: "tf.Yield"(%[[READ1]], %[[SUB]]) + %id = 
"tf.Identity"(%1) : (tensor<*x!tf.resource>>) -> tensor<*x!tf.resource>> + %read = "tf.ReadVariableOp"(%id) : (tensor<*x!tf.resource>>) -> tensor<4xf32> + %constant = "tf.Const"() {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> + %sub = "tf.Sub"(%read, %constant) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + "tf.AssignVariableOp"(%0, %sub) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + "tf.Yield"(%0, %read) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + }) {is_stateless = false} : (tensor) -> (tensor<*x!tf.resource>>, tensor<4xf32>) + // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[CASE]]#1, %[[CASE]]#0) + // CHECK-NEXT: tf_device.return %[[ADD]], %[[CASE]]#1 + %4 = "tf.ReadVariableOp"(%3#0) : (tensor<*x!tf.resource>>) -> tensor<4xf32> + %5 = "tf.AddV2"(%4, %3#1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + tf_device.return %5 : tensor<4xf32> + }) {cluster_attr = "cluster_attr"} : () -> tensor<4xf32> + // CHECK: {cluster_attr = "cluster_attr"} : () -> (tensor<4xf32>, tensor<4xf32>) + // CHECK: "tf.AssignVariableOp"(%[[VH0]], %[[CLUSTER]]#1) + // CHECK: return %[[CLUSTER]]#0 + return %2 : tensor<4xf32> +} + +// ----- + +// Test that the pass can lift resources out of WhileRegion +// CHECK-LABEL: func @cluster_with_whileregion +func @cluster_with_whileregion() -> () { + // CHECK: %[[COUNT:.*]] = "tf.Const"() {value = dense<10> : tensor} + // CHECK: %[[VH:.*]] = "tf.VarHandleOp"() + // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VH]]) + // CHECK: %[[CLUSTER:.*]] = "tf_device.cluster"() + // CHECK: %[[WHILE:.*]]:2 = "tf.WhileRegion"(%[[COUNT]], %[[READ]]) + %0 = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> + %unused = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource>> + "tf_device.cluster"() ( { + %2:3 = "tf.WhileRegion"(%0, %1, %unused) ({ + // CHECK: (%[[CARG0:.+]]: tensor, %[[CARG1:.+]]: tensor): + // CHECK: %[[CAST:.+]] = "tf.Cast"(%[[CARG1]]) + // CHECK: "tf.Less"(%[[CARG0]], %[[CAST]]) + // CHECK: "tf.Yield" + ^bb0(%carg0: tensor, %carg1:tensor<*x!tf.resource>>, %carg2: tensor<*x!tf.resource>>): + %read0 = "tf.ReadVariableOp"(%carg1) : (tensor<*x!tf.resource>>) -> tensor + %cast = "tf.Cast"(%read0) : (tensor) -> tensor + %cond = "tf.Less"(%carg0, %cast) : (tensor, tensor) -> tensor + "tf.Yield"(%cond) : (tensor) -> () + }, { + // CHECK: (%[[BARG0:.+]]: tensor, %[[BARG1:.+]]: tensor): + // CHECK: %[[ADD0:.*]] = "tf.AddV2"(%[[BARG1]], %[[BARG1]]) + // CHECK-NEXT: %[[ADD1:.*]] = "tf.AddV2"(%[[ADD0]], %[[ADD0]]) + // CHECK-NEXT: %[[DELTA:.*]] = "tf.Const"() {value = dense<-1> : tensor} + // CHECK-NEXT: %[[ADD2:.*]] = "tf.AddV2"(%[[BARG0]], %[[DELTA]]) + // CHECK-NEXT: "tf.Yield"(%[[ADD2]], %[[ADD1]]) + ^bb1(%barg0: tensor, %barg1:tensor<*x!tf.resource>>, %barg2: tensor<*x!tf.resource>>): + %read0 = "tf.ReadVariableOp"(%barg1) : (tensor<*x!tf.resource>>) -> tensor + %add0 = "tf.AddV2"(%read0, %read0) : (tensor, tensor) -> tensor + "tf.AssignVariableOp"(%barg1, %add0) : (tensor<*x!tf.resource>>, tensor) -> () + %read1 = "tf.ReadVariableOp"(%barg1) : (tensor<*x!tf.resource>>) -> tensor + %add1 = "tf.AddV2"(%read1, %read1) : (tensor, tensor) -> tensor + "tf.AssignVariableOp"(%barg1, %add1) : (tensor<*x!tf.resource>>, tensor) -> () + %constant = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %add2 = "tf.AddV2"(%barg0, %constant) : (tensor, tensor) -> tensor + %id = "tf.Identity"(%barg2) : (tensor<*x!tf.resource>>) -> 
tensor<*x!tf.resource>> + "tf.Yield"(%add2, %barg1, %id) : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) -> () + }) {device = "", is_stateless = false} + : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) + -> (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + // CHECK: tf_device.return %[[WHILE]]#1 : tensor + // CHECK: {cluster_attr = "cluster_attr"} : () -> tensor + // CHECK: "tf.AssignVariableOp"(%[[VH]], %[[CLUSTER]]) + // CHECK: return + return +} + +// ----- + +// Test that the pass can lift out recursively (If with another if it its body) +// CHECK: func @cluster_with_if_within_if(%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) -> tensor<4xf32> +func @cluster_with_if_within_if(%arg0: tensor, %arg1: tensor) -> tensor<4xf32> { + // CHECK: %[[VH0:.*]] = "tf.VarHandleOp"() + // CHECK: %[[VH1:.*]] = "tf.VarHandleOp"() + // CHECK: %[[READ0:.*]] = "tf.ReadVariableOp"(%[[VH0]]) + // CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[VH1]]) + // CHECK: %[[CLUSTER:.*]]:2 = "tf_device.cluster"() + // CHECK: %[[IF:.*]]:2 = "tf.IfRegion"(%[[ARG0]]) + %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> + %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource>> + %2 = "tf_device.cluster"() ( { + %3:2 = "tf.IfRegion"(%arg0) ({ + // CHECK-NEXT: %[[CONST:.*]] = "tf.Const"() + // CHECK-NEXT: "tf.Yield"(%[[CONST]], %[[CONST]]) + %constant = "tf.Const"() {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> + "tf.AssignVariableOp"(%0, %constant) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + "tf.Yield"(%0, %constant) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + }, { + // CHECK: %[[IF1:.*]] = "tf.IfRegion" + // CHECK: "tf.Yield"(%[[READ1]]) + // CHECK: "tf.Yield"(%[[READ0]]) + // CHECK: "tf.Yield"(%[[IF1]], %[[IF1]]) + %id = "tf.Identity"(%1) : (tensor<*x!tf.resource>>) -> tensor<*x!tf.resource>> + %read = "tf.IfRegion"(%arg1) ({ + %read_then = "tf.ReadVariableOp"(%id) : (tensor<*x!tf.resource>>) -> tensor<4xf32> + "tf.Yield"(%read_then) : (tensor<4xf32>) -> () + }, { + %read_else = "tf.ReadVariableOp"(%0) : (tensor<*x!tf.resource>>) -> tensor<4xf32> + "tf.Yield"(%read_else) : (tensor<4xf32>) -> () + }) {is_stateless = false} : (tensor) -> tensor<4xf32> + "tf.AssignVariableOp"(%0, %read) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + "tf.Yield"(%0, %read) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + }) {is_stateless = false} : (tensor) -> (tensor<*x!tf.resource>>, tensor<4xf32>) + // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[IF]]#1, %[[IF]]#0) + // CHECK-NEXT: tf_device.return %[[ADD]], %[[IF]]#1 + %4 = "tf.ReadVariableOp"(%3#0) : (tensor<*x!tf.resource>>) -> tensor<4xf32> + %5 = "tf.AddV2"(%4, %3#1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + tf_device.return %5 : tensor<4xf32> + }) {cluster_attr = "cluster_attr"} : () -> tensor<4xf32> + // CHECK: {cluster_attr = "cluster_attr"} : () -> (tensor<4xf32>, tensor<4xf32>) + // CHECK: "tf.AssignVariableOp"(%[[VH0]], %[[CLUSTER]]#1) + // CHECK: return %[[CLUSTER]]#0 + return %2 : tensor<4xf32> +} + +// ----- + +// IfRegion with store in just one branch + +// CHECK: func @if_region_with_store_in_then(%[[ARG0:.*]]: tensor) +func @if_region_with_store_in_then(%arg0: tensor) { + // CHECK: %[[VH:.*]] = "tf.VarHandleOp"() + // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VH]]) + // CHECK: %[[CLUSTER:.*]] = "tf_device.cluster"() + // CHECK: %[[IF:.*]] = 
"tf.IfRegion"(%[[ARG0]]) + %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> + "tf_device.cluster"() ({ + "tf.IfRegion"(%arg0) ({ + // CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<0.000000e+00> + // CHECK: "tf.Yield"(%[[CONST]]) + %constant = "tf.Const"() {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> + "tf.AssignVariableOp"(%0, %constant) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + "tf.Yield"() : () -> () + }, { + // CHECK: "tf.Yield"(%[[READ]]) + "tf.Yield"() : () -> () + }) { is_stateless = true} : (tensor) -> () + tf_device.return + }) { cluster_attr = "cluster_attr" } : () -> () + // CHECK: tf_device.return %[[IF]] + // CHECK: "tf.AssignVariableOp"(%[[VH0]], %[[CLUSTER]]) + return +} + +// ----- + +// IfRegion with store in both branches + +// CHECK: func @if_region_with_store_in_both(%[[ARG0:.*]]: tensor) +func @if_region_with_store_in_both(%arg0: tensor) { + // CHECK: %[[VH:.*]] = "tf.VarHandleOp"() + // CHECK: %[[CLUSTER:.*]] = "tf_device.cluster"() + // CHECK: %[[IF:.*]] = "tf.IfRegion"(%[[ARG0]]) + %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> + "tf_device.cluster"() ({ + "tf.IfRegion"(%arg0) ({ + // CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<0.000000e+00> + // CHECK: "tf.Yield"(%[[CONST]]) + %constant = "tf.Const"() {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> + "tf.AssignVariableOp"(%0, %constant) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + "tf.Yield"() : () -> () + }, { + // CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<1.000000e+00> + // CHECK: "tf.Yield"(%[[CONST]]) + %constant = "tf.Const"() {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> + "tf.AssignVariableOp"(%0, %constant) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + "tf.Yield"() : () -> () + }) { is_stateless = true} : (tensor) -> () + tf_device.return + }) { cluster_attr = "cluster_attr" } : () -> () + // CHECK: tf_device.return %[[IF]] + // CHECK: "tf.AssignVariableOp"(%[[VH0]], %[[CLUSTER]]) + return +} + + +// Make sure unsupported resources are handled correctly. If a resource is used +// in an unsupported op, resource op lifting should skip lifting that resource. +// So for the below test, the IR should stay unchanged. +// CHECK-LABEL: func @test_unsupported_resource_op +func @test_unsupported_resource_op() -> tensor<*xi32> { + // CHECK: "tf.VarHandleOp" + // CHECK: "tf_device.cluster"() ( { + // CHECK: "tf.ReadVariableOp" + // CHECK: "tf.SomeResourceOperation" + // CHECK: "tf.SomeComputation" + // CHECK: tf_device.return + // CHECK: {cluster_attr = "cluster_attr"} + // CHECK: return + %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> + %1 = "tf_device.cluster"() ( { + %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32> + "tf.SomeResourceOperation"(%0) : (tensor<*x!tf.resource>) -> () + %3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>) + tf_device.return %3 : tensor<*xi32> + }) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32> + + return %1 : tensor<*xi32> +} + +// Test unsupported use of resource ops in functional control flow. In the test +// below, arg0 has an unsupported use whereas arg1 does not. So we expect arg0 +// to not be lifted and arg1 to be lifted. 
+// CHECK-LABEL: func @test_unsupported_resource_op_in_if +func @test_unsupported_resource_op_in_if(%arg0: tensor) -> tensor<*xi32> { + // CHECK: [[VH0:%.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "v"} + // CHECK: [[VH1:%.*]] = "tf.VarHandleOp"() {container = "d", shared_name = "w"} + // CHECK-NOT: "tf.ReadVariableOp"([[VH0]]) + // CHECK: [[READ1:%.*]] = "tf.ReadVariableOp"([[VH1]]) + // CHECK-NOT: "tf.ReadVariableOp"([[VH0]]) + // CHECK: "tf_device.cluster"() ( { + // CHECK: "tf.If"({{%.*}}, [[VH0]], [[READ1]]) + // CHECK-SAME: else_branch = @else_fn, is_stateless = true, then_branch = @then_fn + // CHECK: tf_device.return + // CHECK: return + %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> + %1 = "tf.VarHandleOp"() {container = "d", shared_name = "w"} : () -> tensor<*x!tf.resource> + %2 = "tf_device.cluster"() ( { + %3 = "tf.If"(%arg0, %0, %1) + { else_branch = @else_fn, then_branch = @then_fn, is_stateless = true} + : (tensor, tensor<*x!tf.resource>, tensor<*x!tf.resource>) -> tensor<*xi32> + tf_device.return %3 : tensor<*xi32> + }) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32> + return %2 : tensor<*xi32> +} + +// CHECK-LABEL: func @else_fn +// CHECK-SAME: (%{{.*}}: tensor<*x!tf.resource>, %{{.*}}: tensor<*xi32>) +func @else_fn(%arg0: tensor<*x!tf.resource>, %arg1: tensor<*x!tf.resource>) -> tensor<*xi32> { + %0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>) -> tensor<*xi32> + %1 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource>) -> tensor<*xi32> + %2 = "tf.Add"(%0, %1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> + return %2 : tensor<*xi32> +} + +// CHECK-LABEL: func @then_fn +// CHECK-SAME: (%{{.*}}: tensor<*x!tf.resource>, %{{.*}}: tensor<*xi32>) +func @then_fn(%arg0: tensor<*x!tf.resource>, %arg1: tensor<*x!tf.resource>) -> tensor<*xi32> { + %0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>) -> tensor<*xi32> + %1 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource>) -> tensor<*xi32> + %2 = "tf.Add"(%0, %1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> + "tf.UnsupportedResourceOp"(%arg0) : (tensor<*x!tf.resource>) -> () + return %2 : tensor<*xi32> +} + +// Test type refinement. If the resource has a single subtype, check that that +// type gets used when hoisting the read. None of the result types will change. 
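+// For example (illustrative value names, assuming a handle whose single
+// subtype is tensor<4xi32>), the hoisted read is expected to be emitted as
+//   %v = "tf.ReadVariableOp"(%handle) : (tensor<*x!tf.resource<tensor<4xi32>>>) -> tensor<4xi32>
+// even though the read inside the cluster produced tensor<*xi32>.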
+// CHECK-LABEL: func @type_refinement_use_subtype +func @type_refinement_use_subtype() -> tensor<*xi32> { + + // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp" + %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> + + // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) + // CHECK-SAME: -> tensor<4xi32> + // CHECK: %[[CLUSTER_RES:[0-9]*]]:2 = "tf_device.cluster" + // CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]]) : (tensor<4xi32>) -> tensor<*xi32> + // CHECK: tf_device.return %[[COMPUTE_RES]], %[[COMPUTE_RES]] + // CHECK-SAME: tensor<*xi32>, tensor<*xi32> + // CHECK: {cluster_attr = "cluster_attr"} + // CHECK-SAME: () -> (tensor<*xi32>, tensor<*xi32>) + // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[CLUSTER_RES]]#1) + + %1 = "tf_device.cluster"() ( { + %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>>) -> tensor<*xi32> + %3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>) + "tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource>>, tensor<*xi32>) -> () + tf_device.return %3 : tensor<*xi32> + }) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32> + + // CHECK: return %[[CLUSTER_RES]]#0 + // CHECK-SAME: tensor<*xi32> + return %1 : tensor<*xi32> +} + +// If multiple types are used across reads and writes, check that the read uses +// the most refined type. The first ReadVariable should refine the type from +// *xi32 to ?xi32 and the assign should refine it further to 4xi32. +// CHECK-LABEL: func @type_refinement_use_refined_type +func @type_refinement_use_refined_type() -> tensor<4xi32> { + + // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp" + %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> + + // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) + // CHECK-SAME: -> tensor<4xi32> + // CHECK: %[[CLUSTER_RES:[0-9]*]]:2 = "tf_device.cluster" + // CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]]) : (tensor<4xi32>) -> tensor<4xi32> + // CHECK: tf_device.return %[[COMPUTE_RES]], %[[COMPUTE_RES]] + // CHECK-SAME: tensor<4xi32>, tensor<4xi32> + // CHECK: {cluster_attr = "cluster_attr"} + // CHECK-SAME: () -> (tensor<4xi32>, tensor<4xi32>) + // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[CLUSTER_RES]]#1) + + %1 = "tf_device.cluster"() ( { + %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>>) -> tensor + %3 = "tf.SomeComputation"(%2) : (tensor) -> (tensor<4xi32>) + "tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource>>, tensor<4xi32>) -> () + tf_device.return %3 : tensor<4xi32> + }) {cluster_attr = "cluster_attr"} : () -> tensor<4xi32> + + // CHECK: return %[[CLUSTER_RES]]#0 + // CHECK-SAME: tensor<4xi32> + return %1 : tensor<4xi32> +} + +// ----- + +!tf_res = type tensor<*x!tf.resource>> + +// Test all tf.VarIsInitializedOp's are set to true. 
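+// The pass is expected to fold every tf.VarIsInitializedOp used inside the
+// cluster, including uses nested in Case/If/While regions and in called
+// functions, into a constant true value, e.g. (value name illustrative)
+//   %init = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>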
+// CHECK-LABEL: func @tpu_computation +func @tpu_computation(%arg0: !tf_res, %arg1: tensor, %arg2: tensor) { + %0 = "tf_device.cluster"() ( { + %1 = "tf.Case"(%arg2, %arg0) {branches = [@case_branch], is_stateless = false} : (tensor, !tf_res) -> tensor + + // CHECK: "tf.CaseRegion" + %2 = "tf.CaseRegion"(%arg2) ( { + // CHECK-NEXT: [[CASE_REGION_BRANCH:%.+]] = "tf.Const" + // CHECK-SAME: value = dense : tensor + %3 = "tf.VarIsInitializedOp"(%arg0) : (!tf_res) -> tensor + // CHECK-NEXT: "tf.Yield"([[CASE_REGION_BRANCH]]) + "tf.Yield"(%3) : (tensor) -> () + }) {is_stateless = false} : (tensor) -> tensor + + %4 = "tf.If"(%arg1, %arg0) {then_branch = @if_then, else_branch = @if_else, is_stateless = false} : (tensor, !tf_res) -> tensor + + // CHECK: "tf.IfRegion" + %5 = "tf.IfRegion"(%arg1) ( { + // CHECK-NEXT: [[IF_REGION_THEN:%.+]] = "tf.Const" + // CHECK-SAME: value = dense : tensor + %6 = "tf.VarIsInitializedOp"(%arg0) : (!tf_res) -> tensor + // CHECK-NEXT: "tf.Yield"([[IF_REGION_THEN]]) + "tf.Yield"(%6) : (tensor) -> () + // CHECK-NEXT: }, { + }, { + // CHECK-NEXT: [[IF_REGION_ELSE:%.+]] = "tf.Const" + // CHECK-SAME: value = dense : tensor + %7 = "tf.VarIsInitializedOp"(%arg0) : (!tf_res) -> tensor + // CHECK-NEXT: "tf.Yield"([[IF_REGION_ELSE]]) + "tf.Yield"(%7) : (tensor) -> () + }) {is_stateless = false} : (tensor) -> tensor + + %8:2 = "tf.While"(%arg0, %arg1) {body = @while_body, cond = @while_cond, is_stateless = false} : (!tf_res, tensor) -> (!tf_res, tensor) + + // CHECK: "tf.WhileRegion" + %9 = "tf.WhileRegion"(%arg1) ( { + // CHECK-NEXT: ^{{.+}}({{.+}}: tensor): + ^cond(%carg0: tensor): + // CHECK-NEXT: [[WHILE_REGION_COND:%.+]] = "tf.Const" + // CHECK-SAME: value = dense : tensor + %10 = "tf.VarIsInitializedOp"(%arg0) : (!tf_res) -> tensor + // CHECK-NEXT: "tf.Yield"([[WHILE_REGION_COND]]) + "tf.Yield"(%10) : (tensor) -> () + // CHECK-NEXT: }, { + }, { + // CHECK-NEXT: ^{{.+}}({{.+}}: tensor): + ^body(%barg0: tensor): + // CHECK-NEXT: [[WHILE_REGION_BODY:%.+]] = "tf.Const" + // CHECK-SAME: value = dense : tensor + %11 = "tf.VarIsInitializedOp"(%arg0) : (!tf_res) -> tensor + // CHECK-NEXT: "tf.Yield"([[WHILE_REGION_BODY]]) + "tf.Yield"(%11) : (tensor) -> () + }) {is_stateless = false} : (tensor) -> tensor + + %12 = "tf.StatefulPartitionedCall"(%arg0) {f = @callee, config = "", config_proto = "", executor_type = ""} : (!tf_res) -> tensor + + // CHECK: [[TRUE:%.+]] = "tf.Const" + // CHECK-SAME: value = dense : tensor + %13 = "tf.VarIsInitializedOp"(%arg0) : (!tf_res) -> tensor + + // CHECK: tf_device.return [[TRUE]] : + tf_device.return %13 : tensor + }) : () -> tensor + return +} + +// CHECK-LABEL: func @case_branch +func @case_branch(%arg0: !tf_res) -> tensor { + // CHECK: [[TRUE:%.+]] = "tf.Const" + // CHECK-SAME: value = dense : tensor + %0 = "tf.VarIsInitializedOp"(%arg0) : (!tf_res) -> tensor + // CHECK-NEXT: return [[TRUE]] : + return %0 : tensor +} + +// CHECK-LABEL: func @if_then +func @if_then(%arg0: !tf_res) -> tensor { + // CHECK: [[TRUE:%.+]] = "tf.Const" + // CHECK-SAME: value = dense : tensor + %0 = "tf.VarIsInitializedOp"(%arg0) : (!tf_res) -> tensor + // CHECK-NEXT: return [[TRUE]] : + return %0 : tensor +} + +// CHECK-LABEL: func @if_else +func @if_else(%arg0: !tf_res) -> tensor { + // CHECK: [[TRUE:%.+]] = "tf.Const" + // CHECK-SAME: value = dense : tensor + %0 = "tf.VarIsInitializedOp"(%arg0) : (!tf_res) -> tensor + // CHECK-NEXT: return [[TRUE]] : + return %0 : tensor +} + +// CHECK-LABEL: func @while_cond +// CHECK-SAME: ({{.+}}: tensor) +func 
@while_cond(%arg0: !tf_res, %arg1: tensor) -> tensor { + // CHECK: [[TRUE:%.+]] = "tf.Const" + // CHECK-SAME: value = dense : tensor + %0 = "tf.VarIsInitializedOp"(%arg0) : (!tf_res) -> tensor + // CHECK-NEXT: return [[TRUE]] : + return %0 : tensor +} + +// CHECK-LABEL: func @while_body +// CHECK-SAME: ({{.+}}: tensor) +func @while_body(%arg0: !tf_res, %arg1: tensor) -> (!tf_res, tensor) { + // CHECK: [[TRUE:%.+]] = "tf.Const" + // CHECK-SAME: value = dense : tensor + %0 = "tf.VarIsInitializedOp"(%arg0) : (!tf_res) -> tensor + // CHECK-NEXT: return [[TRUE]] : + return %arg0, %0 : !tf_res, tensor +} + +// CHECK-LABEL: func @callee +func @callee(%arg0: !tf_res) -> tensor { + // CHECK: [[TRUE:%.+]] = "tf.Const" + // CHECK-SAME: value = dense : tensor + %0 = "tf.VarIsInitializedOp"(%arg0) : (!tf_res) -> tensor + // CHECK-NEXT: return [[TRUE]] : + return %0 : tensor +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir index 3e613573d42..428af91f155 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir @@ -530,6 +530,21 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr return %3#0, %3#1 : tensor<*xf32>, tensor<*xf32> } + // CHECK-LABEL: infer_device_cluster + func @infer_device_cluster(%arg0: tensor<1x8x2xi32>) -> (tensor<*xf32>, tensor<*xf32>) { + %0 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %1 = "tf_device.cluster"() ({ + %2 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<1x8x2xi32>) -> tensor<1x8x2xf32> + tf_device.return %2 : tensor<1x8x2xf32> + // CHECK: () -> tensor<1x8x2xf32> + }) : () -> tensor<*xf32> + // CHECK: "tf.Cast"(%{{.*}}) {Truncate = false} : (tensor<1x8x2xf32>) -> tensor<*xf32> + // CHECK: (tensor, tensor<1x8x2xf32>) -> (tensor<1x8x1xf32>, tensor<1x8x1xf32>) + %3:2 = "tf.Split"(%0, %1) {device = ""} : (tensor, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) + %4 = addf %1, %1 : tensor<*xf32> + return %3#0, %3#1 : tensor<*xf32>, tensor<*xf32> + } + // CHECK-LABEL: func @tensor_cast(%arg0: tensor<1xi32>) -> tensor<1xi32> func @tensor_cast(%arg0: tensor<1xi32>) -> tensor<*xi32> { // CHECK: %[[RESULT:.*]] = tensor_cast @@ -560,4 +575,15 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr func @pcall_resource_result_func(%arg0: tensor<*x!tf.resource>>) -> tensor<*x!tf.resource>> { return %arg0 : tensor<*x!tf.resource>> } + + // Check that the fold for tf.Size does not crash with unranked output type. + // CHECK-LABEL: func @unranked_tf_size + func @unranked_tf_size() -> tensor<*xi32> { + %0 = "tf.Const"() {value = dense<[-1, 26]> : tensor<2xi32>} : () -> tensor<2xi32> + %add = "tf.AddV2"(%0, %0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<*xi32> + // CHECK: "tf.Size" + // CHECK-SAME: (tensor<2xi32>) -> tensor + %size = "tf.Size"(%add) {device = ""} : (tensor<*xi32>) -> tensor<*xi32> + return %size : tensor<*xi32> + } } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/stack_ops_decomposition.mlir b/tensorflow/compiler/mlir/tensorflow/tests/stack_ops_decomposition.mlir index e4fdad2eddb..17329050f3e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/stack_ops_decomposition.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/stack_ops_decomposition.mlir @@ -122,6 +122,108 @@ func @while_cond(%arg0: tensor, %arg1: tensor) -> tensor // ----- +// Tests WhileRegion Op. 
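+// A rough sketch of the decomposition exercised below (variable names are
+// illustrative): each tf.StackPushV2 is rewritten to read the buffer and size
+// local variables, write the element with tf.XlaDynamicUpdateSlice, and assign
+// the updated buffer and size back; tf.StackPopV2 becomes a tf.Slice of the
+// buffer plus an update of the size variable, and tf.StackCloseV2 is erased.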
+ +// CHECK-LABEL: func @main() +func @main() -> () { + %max_size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + // CHECK-NOT: tf.Stack + // CHECK: %[[BUFFER:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> + // CHECK: %[[SIZE:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> + // CHECK: tf.AssignVariableOp + // CHECK: tf.AssignVariableOp + %stack = "tf.StackV2"(%max_size) {elem_type = f32, stack_name = "s"} : (tensor) -> tensor + // CHECK: tf.WhileRegion + %while = "tf.WhileRegion"(%max_size) ({ + // CHECK: ^bb0(%[[BARG0:.*]]: tensor + ^bb0(%barg0: tensor): + // CHECK: "tf._SomeOp"(%[[BARG0]]) + %pred = "tf._SomeOp"(%barg0) : (tensor) -> tensor + "tf.Yield"(%pred) : (tensor) -> () + }, { + // CHECK: ^bb0(%[[BARG0:.*]]: tensor + ^bb0(%barg0: tensor): + // CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %const1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + // CHECK: %[[SUB:.*]] = "tf.Sub"(%[[BARG0]], %[[CONST1]]) + %sub = "tf.Sub"(%barg0, %const1) : (tensor, tensor) -> tensor + %elem = "tf._SomeOp"() : () -> tensor + // CHECK-NOT: "tf.StackPushV2" + // CHECK: %[[BUFFER_VAL:.*]] = "tf.ReadVariableOp"(%[[BUFFER]]) + // CHECK: %[[SIZE_VAL:.*]] = "tf.ReadVariableOp"(%[[SIZE]]) + // CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"(%[[BUFFER_VAL]] + // CHECK: "tf.AssignVariableOp"(%[[BUFFER]], %[[UPDATE]]) + // CHECK: "tf.AssignVariableOp"(%[[SIZE]] + // CHECK-NOT: "tf.StackPushV2" + %push = "tf.StackPushV2"(%stack, %elem) {swap_memory = false} : (tensor, tensor) -> tensor + // CHECK: "tf.Yield"(%[[SUB]]) + "tf.Yield"(%sub) : (tensor) -> () + }) {is_stateless = false} + : (tensor) -> tensor + // CHECK-NOT: tf.StackPopV2 + // CHECK: %[[BUFFER_VAL:.*]] = "tf.ReadVariableOp"(%[[BUFFER]]) + // CHECK: %[[SIZE_VAL:.*]] = "tf.ReadVariableOp"(%[[SIZE]]) + // CHECK: %[[POP_VAL:.*]] = "tf.Slice"(%[[BUFFER_VAL]] + // CHECK: "tf.AssignVariableOp"(%[[SIZE]] + %pop = "tf.StackPopV2"(%stack) : (tensor) -> tensor + // CHECK-NOT: tf.StackCloseV2 + "tf.StackCloseV2"(%stack) : (tensor) -> () + return +} + +// ----- + +// Test CaseRegionOp + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[BRANCH_INDEX:.*]]: tensor +func @main(%arg0: tensor) -> () { + %max_size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + // CHECK-NOT: tf.StackV2 + // CHECK: %[[BUFFER:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> + // CHECK: %[[SIZE:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> + // CHECK: tf.AssignVariableOp + // CHECK: tf.AssignVariableOp + %stack = "tf.StackV2"(%max_size) {elem_type = f32, stack_name = "s"} : (tensor) -> tensor + // CHECK: %[[CASE_OUTPUT:.*]] = "tf.CaseRegion"(%[[BRANCH_INDEX]]) ( { + %case_op = "tf.CaseRegion"(%arg0) ({ + %elem = "tf._SomeOp"() : () -> tensor + // CHECK-NOT: tf.StackPushV2 + // CHECK: %[[BUFFER_VAL:.*]] = "tf.ReadVariableOp"(%[[BUFFER]]) + // CHECK: %[[SIZE_VAL:.*]] = "tf.ReadVariableOp"(%[[SIZE]]) + // CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"(%[[BUFFER_VAL]] + // CHECK: "tf.AssignVariableOp"(%[[BUFFER]], %[[UPDATE]]) + // CHECK: "tf.AssignVariableOp"(%[[SIZE]] + %push = "tf.StackPushV2"(%stack, %elem) {swap_memory = false} : (tensor, tensor) -> tensor + "tf.Yield"(%elem) : (tensor) -> () + }, { + %elem = "tf._SomeOtherOp"() : () -> tensor + // CHECK-NOT: tf.StackPushV2 + // CHECK: %[[BUFFER_VAL:.*]] = "tf.ReadVariableOp"(%[[BUFFER]]) + // CHECK: %[[SIZE_VAL:.*]] = "tf.ReadVariableOp"(%[[SIZE]]) + // CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"(%[[BUFFER_VAL]] + // CHECK: "tf.AssignVariableOp"(%[[BUFFER]], 
%[[UPDATE]]) + // CHECK: "tf.AssignVariableOp"(%[[SIZE]] + %push = "tf.StackPushV2"(%stack, %elem) {swap_memory = false} : (tensor, tensor) -> tensor + "tf.Yield"(%elem) : (tensor) -> () + }, { + // CHECK-NOT: tf.StackPopV2 + // CHECK: %[[BUFFER_VAL:.*]] = "tf.ReadVariableOp"(%[[BUFFER]]) + // CHECK: %[[SIZE_VAL:.*]] = "tf.ReadVariableOp"(%[[SIZE]]) + // CHECK: %[[POP_VAL:.*]] = "tf.Slice"(%[[BUFFER_VAL]] + // CHECK: "tf.AssignVariableOp"(%[[SIZE]] + %pop = "tf.StackPopV2"(%stack) : (tensor) -> tensor + "tf.Yield"(%pop) : (tensor) -> () + }) {is_stateless = false} + : (tensor) -> tensor + // CHECK-NOT: tf.StackPopV2 + %pop = "tf.StackPopV2"(%stack) : (tensor) -> tensor + // CHECK-NOT: tf.StackCloseV2 + "tf.StackCloseV2"(%stack) : (tensor) -> () + return +} + +// ----- // Tests IfOp. // CHECK-LABEL: func @main @@ -308,3 +410,53 @@ func @if_else(%arg0: tensor, %arg1: tensor) -> tenso %push = "tf.StackPushV2"(%arg1, %elem) {swap_memory = false} : (tensor, tensor) -> tensor return %arg1 : tensor } + +// ----- + +// Tests that the pass returns meaningful error message when WhileRegion op has +// resource arguments. +func @main() -> () { + %max_size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + %stack = "tf.StackV2"(%max_size) {elem_type = f32, stack_name = "s"} : (tensor) -> tensor + %elem = "tf._SomeOp"() : () -> tensor + %push_0 = "tf.StackPushV2"(%stack, %elem) {swap_memory = false} : (tensor, tensor) -> tensor + // expected-error @+1 {{found unexpected type 'tensor>>' of operand #0, resource type operands are expected to have been canonicalized away for region based control flow ops}} + %1:2 = "tf.WhileRegion"(%stack, %max_size) ({ + ^bb0 (%carg0: tensor, %carg1: tensor): + %pred = "tf._SomeOp"(%carg1) : (tensor) -> tensor + "tf.Yield"(%pred) : (tensor) -> () + }, { + ^bb0 (%carg0: tensor, %carg1: tensor): + %const1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %sub = "tf.Sub"(%carg1, %const1) : (tensor, tensor) -> tensor + %push_1 = "tf.StackPushV2"(%carg0, %elem) {swap_memory = false} : (tensor, tensor) -> tensor + "tf.Yield"(%carg0, %sub) : (tensor, tensor) -> () + }) {is_stateless = false} + : (tensor, tensor) -> (tensor, tensor) + %pop = "tf.StackPopV2"(%1#0) : (tensor) -> tensor + "tf.StackCloseV2"(%stack) : (tensor) -> () + return +} + +// ----- + +// Tests that the pass returns meaningful error message when IfRegion op has +// resource returns. 
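The pass expects resource operands and results of region-based control flow to have been canonicalized away before it runs, so yielding the stack handle out of a tf.IfRegion is rejected with the error matched below. For contrast, a canonical form captures the resource implicitly from the enclosing scope and yields only non-resource values; a rough sketch (hypothetical names, assuming an untyped !tf.resource handle):

%elem = "tf.IfRegion"(%pred) ({
  %e = "tf._SomeOp"() : () -> tensor<f32>
  %flow = "tf.StackPushV2"(%stack, %e) {swap_memory = false} : (tensor<!tf.resource>, tensor<f32>) -> tensor<f32>
  "tf.Yield"(%e) : (tensor<f32>) -> ()
}, {
  %p = "tf.StackPopV2"(%stack) : (tensor<!tf.resource>) -> tensor<f32>
  "tf.Yield"(%p) : (tensor<f32>) -> ()
}) {is_stateless = false} : (tensor<i1>) -> tensor<f32>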
+ +func @main(%arg0: tensor) -> () { + %max_size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + %stack = "tf.StackV2"(%max_size) {elem_type = f32, stack_name = "s"} : (tensor) -> tensor + // expected-error @+1 {{found unexpected type 'tensor' of result #0, resource type results are expected to have been canonicalized away for region based control flow ops}} + %if_op = "tf.IfRegion"(%arg0) ({ + %elem = "tf._SomeOp"() : () -> tensor + %push = "tf.StackPushV2"(%stack, %elem) {swap_memory = false} : (tensor, tensor) -> tensor + "tf.Yield"(%stack) : (tensor) -> () + }, { + %pop = "tf.StackPopV2"(%stack) : (tensor) -> tensor + "tf.Yield"(%stack) : (tensor) -> () + }) {is_stateless = false} + : (tensor) -> tensor + %pop = "tf.StackPopV2"(%if_op) : (tensor) -> tensor + "tf.StackCloseV2"(%stack) : (tensor) -> () + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tensor_array_ops_decomposition.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tensor_array_ops_decomposition.mlir index b65e88c589a..0c4dc77cf69 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tensor_array_ops_decomposition.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tensor_array_ops_decomposition.mlir @@ -54,6 +54,22 @@ func @main() -> tensor { // ----- +// Test inferring shape from the first scatter. + +// CHECK-LABEL: func @main +func @main() -> tensor { + %size = "tf.Const"() {value = dense<5> : tensor} : () -> tensor + // CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> + %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf.shape<*>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) + %indices = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32> + %values = "tf.Const"() {value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %write = "tf.TensorArrayScatterV3"(%ta#0, %indices, %values, %ta#1) : (tensor, tensor<2xi32>, tensor<2x3xf32>, tensor) -> tensor + %size_out = "tf.TensorArraySizeV3"(%ta#0, %write) : (tensor, tensor) -> tensor + return %size_out : tensor +} + +// ----- + // Test tensor array concat and split. 
// CHECK-LABEL: func @main @@ -259,6 +275,13 @@ func @main() -> () { // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor>>) -> tensor<5x3xf32> // CHECK: "tf.Slice"(%[[READ]], %read = "tf.TensorArrayReadV3"(%1, %index, %ta#1) : (tensor, tensor, tensor) -> tensor<3xf32> + // CHECK: %[[READ_GVAR1:.*]] = "tf.ReadVariableOp"(%[[GVAR1]]) + // CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ_GVAR1]], + // CHECK: "tf.AssignVariableOp"(%[[GVAR1]], %[[UPDATE]]) + %const = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %elem = "tf._SomeOp"() : () -> tensor<3xf32> + %grad:2 = "tf.TensorArrayGradV3"(%ta#0, %ta#1) {source = "a"} : (tensor, tensor) -> (tensor, tensor) + %gwrite = "tf.TensorArrayWriteV3"(%grad#0, %const, %elem, %grad#1) : (tensor, tensor, tensor<3xf32>, tensor) -> tensor return } // CHECK: func @then_branch(%[[TARG0:.*]]: tensor>>, %[[TARG1:.*]]: tensor>>, %[[TARG2:.*]]: tensor>>) @@ -412,6 +435,32 @@ func @callee() -> tensor attributes {sym_visibility = "public"} { // ----- +// CHECK-LABEL: func @main +func @main() -> () { + // CHECK: "tf.PartitionedCall"() {config = "", config_proto = "", executor_type = "", f = @callee} : () -> tensor<*xf32> + %call = "tf.PartitionedCall"() {config = "", config_proto = "", executor_type = "", f = @callee} : () -> (tensor<*xf32>) + return +} +func @callee() -> (tensor<*xf32>) attributes {sym_visibility = "private"} { + %size = "tf.Const"() {value = dense<5> : tensor} : () -> tensor + // CHECK: %[[LOCAL_VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> + %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf.shape<*>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor>>, tensor) + %index = "tf.Const"() {value = dense<2> : tensor} : () -> tensor + %value = "tf.Const"() {value = dense<[1.0, 2.0, 3.0]> : tensor<3xf32>} : () -> tensor<3xf32> + // CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice" + // CHECK: "tf.AssignVariableOp"(%[[LOCAL_VAR]], %[[UPDATE]]) : (tensor>>, tensor<5x3xf32>) -> () + %flow = "tf.TensorArrayWriteV3"(%ta#0, %index, %value, %ta#1) : (tensor>>, tensor, tensor<3xf32>, tensor) -> tensor + // CHECK: %[[SLICE:.*]] = "tf.Slice" + // CHECK: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<3> : tensor<1xi32>} + // CHECK: %[[ELEM:.*]] = "tf.Reshape"(%[[SLICE]], %[[ELEM_SHAPE]]) + %val = "tf.TensorArrayReadV3"(%ta#0, %index, %ta#1) : (tensor>>, tensor, tensor) -> tensor<*xf32> + // CHECK: %[[CAST:.*]] = tensor_cast %[[ELEM]] : tensor<3xf32> to tensor<*xf32> + // CHECK: return %[[CAST]] : tensor<*xf32> + return %val : tensor<*xf32> +} + +// ----- + // Test the pass reports failure on unknown size. 
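The decomposition materializes a fixed-size backing buffer for the array, so it needs the size operand to be a compile-time constant; when the size arrives as a runtime tensor argument there is nothing to allocate and the pass is expected to fail, which is what the case below checks. For contrast, a constant-size array like the following is decomposable (illustrative sketch; the element_shape value and result types are assumptions):

%size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
%ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf.shape<3>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)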
func @main(%arg0: tensor) -> () { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir index 92cb0458bf9..09a2dcb6713 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir @@ -34,7 +34,7 @@ func @main() -> (tensor, tensor) { // CHECK-NEXT: %[[SCALAR_SHAPE:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} // CHECK-NEXT: %[[LENGTH:.*]] = "tf.Reshape"(%[[NEW_SIZE]], %[[SCALAR_SHAPE]]) %length = "tf.TensorListLength"(%push) : (tensor>>) -> tensor - // CHECK-NEXT: return %[[ELEM]], %[[LENGTH]] : tensor, tensor + // CHECK-NEXT: return %[[ELEM]], %[[LENGTH]] : tensor, tensor return %pop#1, %length: tensor, tensor } @@ -81,7 +81,7 @@ func @main(%arg0: tensor) -> (tensor, tensor<10xf32>, tensor) { %stack = "tf.TensorListStack"(%addn2, %elem_shape) : (tensor>>, tensor<0xi32>) -> tensor<10xf32> // CHECK-NEXT: %[[LEN:.*]] = "tf.Const"() {value = dense<10> : tensor} : () -> tensor %length = "tf.TensorListLength"(%addn2) : (tensor>>) -> tensor - // CHECK-NEXT: return %[[ELEM]], %[[ADDN2]], %[[LEN]] : tensor, tensor<10xf32>, tensor + // CHECK-NEXT: return %[[ELEM]], %[[ADDN2]], %[[LEN]] : tensor, tensor<10xf32>, tensor return %get, %stack, %length : tensor, tensor<10xf32>, tensor } @@ -104,7 +104,7 @@ func @main(%arg0: tensor, %arg1: tensor<10xf32>) -> tensor { // CHECK-NEXT: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> // CHECK-NEXT: %[[ELEM:.*]] = "tf.Reshape"(%[[SLICE]], %[[ELEM_SHAPE]]) : (tensor<1xf32>, tensor<0xi32>) -> tensor %get = "tf.TensorListGetItem"(%tl, %arg0, %elem_shape) : (tensor>>, tensor, tensor<0xi32>) -> tensor - // CHECK-NEXT: return %[[ELEM]] : tensor + // CHECK-NEXT: return %[[ELEM]] : tensor return %get: tensor } @@ -118,7 +118,7 @@ func @main(%arg0: tensor<10x8x9xf32>) -> tensor<2xi64> { %tl = "tf.TensorListFromTensor"(%arg0, %elem_shape) : (tensor<10x8x9xf32>, tensor<2xi32>) -> tensor>> // CHECK: %[[SHAPE:.*]] = "tf.Const"() {value = dense<[8, 9]> : tensor<2xi64>} : () -> tensor<2xi64> %shape = "tf.TensorListElementShape"(%tl) : (tensor>>) -> tensor<2xi64> - // CHECK-NEXT: return %[[SHAPE]] : tensor<2xi64> + // CHECK-NEXT: return %[[SHAPE]] : tensor<2xi64> return %shape: tensor<2xi64> } @@ -135,7 +135,7 @@ func @main(%arg0: tensor<10x8x9xf32>, %arg1: tensor<3xi32>) -> tensor<3x8x9xf32> // CHECK: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor // CHECK: %[[GATHER:.*]] = "tf.GatherV2"(%[[BUFFER]], %[[ARG1]], %[[AXIS]]) : (tensor<10x8x9xf32>, tensor<3xi32>, tensor) -> tensor<3x8x9xf32> %gather = "tf.TensorListGather"(%tl, %arg1, %elem_shape) : (tensor>>, tensor<3xi32>, tensor<2xi32>) -> tensor<3x8x9xf32> - // CHECK-NEXT: return %[[GATHER]] : tensor<3x8x9xf32> + // CHECK-NEXT: return %[[GATHER]] : tensor<3x8x9xf32> return %gather: tensor<3x8x9xf32> } @@ -173,7 +173,7 @@ func @main() -> () { : (tensor>>, tensor) -> (tensor>>, tensor) // CHECK: "tf.Slice" %pop:2 = "tf.TensorListPopBack"(%1#0, %elem_shape) : (tensor>>, tensor<0xi32>) -> (tensor>>, tensor) - // CHECK-NOT: tf.EmptyTensorList + // CHECK-NOT: tf.TensorListPopBack // CHECK: return return } @@ -242,7 +242,7 @@ func @if_else(%arg0: tensor>>) -> tensor, tensor<0xi32>) -> tensor // CHECK-NOT: "tf.TensorListPopBack" %pop:2 = "tf.TensorListPopBack"(%arg0, %elem_shape) : (tensor>>, tensor<0xi32>) -> (tensor>>, tensor) - // 
CHECK: return %[[COPY]], %[[SUB]] + // CHECK: return %[[COPY]], %[[SUB]] return %pop#0 : tensor>> } @@ -289,7 +289,7 @@ func @branch_1(%arg0: tensor>>) -> tensor, tensor<0xi32>) -> tensor // CHECK-NOT: "tf.TensorListPopBack" %pop:2 = "tf.TensorListPopBack"(%arg0, %elem_shape) : (tensor>>, tensor<0xi32>) -> (tensor>>, tensor) - // CHECK: return %[[COPY]], %[[SUB]] + // CHECK: return %[[COPY]], %[[SUB]] return %pop#0 : tensor>> } // CHECK: func @branch_2(%[[EARG0:.*]]: tensor<10xf32>, %[[EARG1:.*]]: tensor<1xi32>) -> (tensor<10xf32>, tensor<1xi32>) @@ -305,9 +305,145 @@ func @branch_2(%arg0: tensor>>) -> tensor, tensor<0xi32>) -> tensor // CHECK-NOT: "tf.TensorListPopBack" %pop:2 = "tf.TensorListPopBack"(%arg0, %elem_shape) : (tensor>>, tensor<0xi32>) -> (tensor>>, tensor) - // CHECK: return %[[COPY]], %[[SUB]] + // CHECK: return %[[COPY]], %[[SUB]] return %pop#0 : tensor>> } + +// ----- + +// CHECK-LABEL: func @main +func @main() -> tensor { + %elem_shape = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> + %size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + // CHECK-NOT: tf.EmptyTensorList + %tl = "tf.EmptyTensorList"(%elem_shape, %size) : (tensor<0xi32>, tensor) -> tensor>> + %while_op:2 = "tf.WhileRegion"(%tl, %size) ( { + // CHECK: ^bb0(%[[CARG0:.*]]: tensor<10xf32>, %[[CARG1:.*]]: tensor, %[[CARG2:.*]]: tensor<1xi32>): + ^bb0(%arg0: tensor>>, %arg1: tensor): // no predecessors + // CHECK: %[[PRED:.*]] = "tf._SomeOp"() + // CHECK: "tf.Yield"(%[[PRED]]) + %pred = "tf._SomeOp"() : () -> tensor + "tf.Yield"(%pred) : (tensor) -> () + }, { + // CHECK: ^bb0(%[[CARG0:.*]]: tensor<10xf32>, %[[CARG1:.*]]: tensor, %[[CARG2:.*]]: tensor<1xi32>): + ^bb0(%arg0: tensor>>, %arg1: tensor): // no predecessors + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + // CHECK: %[[SUB:.*]] = "tf.Sub"(%[[CARG1]], %[[CST]]) + // CHECK: %[[ELEM:.*]] = "tf._SomeOp"() : () -> tensor + %cst = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %sub = "tf.Sub"(%arg1, %cst) : (tensor, tensor) -> tensor + %elem = "tf._SomeOp"() : () -> tensor + // CHECK-NOT: "tf.TensorListPushBack" + // CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"(%[[CARG0]] + // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} + // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[CARG2]], %[[ONE]]) + // CHECK-NOT: "tf.TensorListPushBack" + // CHECK: "tf.Yield"(%[[UPDATE]], %[[SUB]], %[[ADD]]) + // CHECK: }) {is_stateless = false} + %push = "tf.TensorListPushBack"(%arg0, %elem) : (tensor>>, tensor) -> tensor>> + "tf.Yield"(%push, %sub) : (tensor>>, tensor) -> () + }) {is_stateless = false} : (tensor>>, tensor) -> (tensor>>, tensor) + // CHECK: "tf.Slice" + // CHECK-NOT: tf.TensorListPopBack + %pop:2 = "tf.TensorListPopBack"(%while_op#0, %elem_shape) : (tensor>>, tensor<0xi32>) -> (tensor>>, tensor) + // CHECK: return + return %pop#1 : tensor +} +// ----- + +// CHECK-LABEL: func @main +func @main(%arg0: tensor) -> () { + %elem_shape = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> + %max_size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + %tl = "tf.EmptyTensorList"(%elem_shape, %max_size) : (tensor<0xi32>, tensor) -> tensor>> + // CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0> : tensor} + // CHECK: %[[ZERO_F32:.*]] = "tf.Cast"(%[[ZERO]]) + // CHECK: %[[MAX_SIZE:.*]] = "tf.Const"() {value = dense<10> : tensor<1xi32>} + // CHECK: %[[BUFFER:.*]] = "tf.BroadcastTo"(%[[ZERO_F32]], %[[MAX_SIZE]]) + // CHECK: %[[BUFFER_SIZE:.*]] = 
"tf.Const"() {value = dense<0> : tensor<1xi32>} + // CHECK-NOT: tf.EmptyTensorList + %if_op = "tf.IfRegion"(%arg0) ({ + %elem = "tf._SomeOp"() : () -> tensor + // CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice" + // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[BUFFER_SIZE]], %[[ONE]]) + // CHECK-NOT: "tf.TensorListPushBack" + %push = "tf.TensorListPushBack"(%tl, %elem) : (tensor>>, tensor) -> tensor>> + "tf.Yield" (%push) : (tensor>>) -> () + }, { + // CHECK: %[[COPY:.*]] = "tf.Identity"(%[[BUFFER]]) + // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} + // CHECK: %[[SUB:.*]] = "tf.Sub"(%[[BUFFER_SIZE]], %[[ONE]]) + // CHECK: %[[SLICE_SIZE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} + // CHECK: %[[SLICE:.*]] = "tf.Slice"(%[[COPY]], %[[SUB]], %[[SLICE_SIZE]]) + // CHECK: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} + // CHECK: %[[ELEM:.*]] = "tf.Reshape"(%[[SLICE]], %[[ELEM_SHAPE]]) + // CHECK-NOT: "tf.TensorListPopBack" + %pop:2 = "tf.TensorListPopBack"(%tl, %elem_shape) : (tensor>>, tensor<0xi32>) -> (tensor>>, tensor) + // CHECK: "tf.Yield"(%[[COPY]], %[[SUB]]) + "tf.Yield" (%pop#0) : (tensor>>) -> () + }) + {is_stateless = false} + : (tensor) -> tensor>> + // CHECK: "tf.Slice" + // CHECK-NOT: tf.TensorListPopBack + %pop:2 = "tf.TensorListPopBack"(%if_op, %elem_shape) : (tensor>>, tensor<0xi32>) -> (tensor>>, tensor) + return +} + +// ----- + +// CHECK-LABEL: func @main +func @main(%arg0: tensor) -> () { + %elem_shape = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> + %max_size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + // CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0> : tensor} + // CHECK: %[[ZERO_F32:.*]] = "tf.Cast"(%[[ZERO]]) + // CHECK: %[[MAX_SIZE:.*]] = "tf.Const"() {value = dense<10> : tensor<1xi32>} + // CHECK: %[[BUFFER:.*]] = "tf.BroadcastTo"(%[[ZERO_F32]], %[[MAX_SIZE]]) + // CHECK: %[[BUFFER_SIZE:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} + // CHECK-NOT: tf.EmptyTensorList + %tl = "tf.EmptyTensorList"(%elem_shape, %max_size) : (tensor<0xi32>, tensor) -> tensor>> + %case_op = "tf.CaseRegion"(%arg0) ({ + %elem = "tf._SomeOp"() : () -> tensor + // CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice" + // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[BUFFER_SIZE]], %[[ONE]]) + // CHECK-NOT: "tf.TensorListPushBack" + %push = "tf.TensorListPushBack"(%tl, %elem) : (tensor>>, tensor) -> tensor>> + "tf.Yield" (%push) : (tensor>>) -> () + }, { + // CHECK: %[[COPY:.*]] = "tf.Identity"(%[[BUFFER]]) + // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} + // CHECK: %[[SUB:.*]] = "tf.Sub"(%[[BUFFER_SIZE]], %[[ONE]]) + // CHECK: %[[SLICE_SIZE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} + // CHECK: %[[SLICE:.*]] = "tf.Slice"(%[[COPY]], %[[SUB]], %[[SLICE_SIZE]]) + // CHECK: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} + // CHECK: %[[ELEM:.*]] = "tf.Reshape"(%[[SLICE]], %[[ELEM_SHAPE]]) + // CHECK-NOT: "tf.TensorListPopBack" + %pop:2 = "tf.TensorListPopBack"(%tl, %elem_shape) : (tensor>>, tensor<0xi32>) -> (tensor>>, tensor) + // CHECK: "tf.Yield"(%[[COPY]], %[[SUB]]) + "tf.Yield" (%pop#0) : (tensor>>) -> () + }, { + // CHECK: %[[COPY:.*]] = "tf.Identity"(%[[BUFFER]]) + // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} + // CHECK: 
%[[SUB:.*]] = "tf.Sub"(%[[BUFFER_SIZE]], %[[ONE]]) + // CHECK: %[[SLICE_SIZE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} + // CHECK: %[[SLICE:.*]] = "tf.Slice"(%[[COPY]], %[[SUB]], %[[SLICE_SIZE]]) + // CHECK: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} + // CHECK: %[[ELEM:.*]] = "tf.Reshape"(%[[SLICE]], %[[ELEM_SHAPE]]) + // CHECK-NOT: "tf.TensorListPopBack" + %pop:2 = "tf.TensorListPopBack"(%tl, %elem_shape) : (tensor>>, tensor<0xi32>) -> (tensor>>, tensor) + // CHECK: "tf.Yield"(%[[COPY]], %[[SUB]]) + "tf.Yield" (%pop#0) : (tensor>>) -> () + }) {is_stateless = false} + : (tensor) -> tensor>> + // CHECK: "tf.Slice" + // CHECK-NOT: tf.TensorListPopBack + %pop:2 = "tf.TensorListPopBack"(%case_op, %elem_shape) : (tensor>>, tensor<0xi32>) -> (tensor>>, tensor) + return +} + // ----- // Tests PartitionedCall/StatefulPartitionedCall. diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir index 9a8d97eddf1..8b97bfdad6d 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir @@ -183,6 +183,20 @@ func @testLeakyWrongAlphaType(tensor<16xf32>) -> tensor<16xf32> { // ----- +// Test tf.Min with complex numbers. +// Previous versions of tensorflow said complex numbers were allowed with +// tf.Min even though it doesn't make sense. The legalization of tf to xla +// requires that complex types are not allowed in tf.Min, so we have an +// explicit unit here to make sure that invariant is enforced. +func @testMinComplex(%arg0: tensor<4x8xcomplex>) -> tensor<4x1xcomplex> { + %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> + // expected-error@below {{'tf.Min' op operand #0 must be tensor of}} + %0 = "tf.Min"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xcomplex>, tensor<1xi64>) -> tensor<4x1xcomplex> + return %0 : tensor<4x1xcomplex> +} + +// ----- + // CHECK-LABEL: func @testMul func @testMul(%arg0: tensor<2xui16>) -> (tensor<2xui16>) { %0 = "tf.Mul"(%arg0, %arg0) {T = "tfdtype$DT_UINT16", device = "/device:CPU:0", name = "Mul"} : (tensor<2xui16>, tensor<2xui16>) -> tensor<2xui16> @@ -210,17 +224,17 @@ func @testIncompatibleElementTypes(%arg0: tensor<3x2xf32>, %arg1: tensor<3x2xf32 // ----- // CHECK-LABEL: func @testReshape(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<10000xf32>, %arg3: tensor<*xi32>) -func @testReshape(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<10000xf32>, %arg3: tensor<*xi32>) -> (tensor<100x100xf32>, tensor<*xf32>, tensor<10000xf32>, tensor<100x100xf32>, tensor<*xf32>, tensor<*xf32>) { +func @testReshape(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<10000xf32>, %arg3: tensor<*xi32>) -> (tensor<100x100xf32>, tensor<*xf32>, tensor<100x100xf32>, tensor<100x100xf32>, tensor<*xf32>, tensor<*xf32>) { %shape1 = constant dense<100> : tensor<2xi32> - %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<*xf32>, tensor<2xi32>) -> (tensor<100x100xf32>) - %shape2 = "tf.Shape"(%arg0) {device = "", name = "Shape", T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor<*xf32>) -> (tensor) - %r2 = "tf.Reshape"(%arg1, %shape2) {device = "", name = "Reshape_1", T = "tfdtype$DT_FLOAT", Tshape = "tfdtype$DT_INT32"} : (tensor<*xf32>, tensor) -> (tensor<*xf32>) - %r3 = "tf.Reshape"(%arg2, %shape1) {device = "", name = "Reshape_1", T = "tfdtype$DT_FLOAT", Tshape = "tfdtype$DT_INT32"} : (tensor<10000xf32>, tensor<2xi32>) -> (tensor<10000xf32>) + %r1 = 
"tf.Reshape" (%arg0, %shape1) : (tensor<*xf32>, tensor<2xi32>) -> tensor<100x100xf32> + %shape2 = "tf.Shape"(%arg0) : (tensor<*xf32>) -> tensor + %r2 = "tf.Reshape"(%arg1, %shape2) : (tensor<*xf32>, tensor) -> tensor<*xf32> + %r3 = "tf.Reshape"(%arg2, %shape1) : (tensor<10000xf32>, tensor<2xi32>) -> tensor<100x100xf32> %shape3 = constant dense<[-1, 100]> : tensor<2xi32> - %r4 = "tf.Reshape"(%arg2, %shape3) {device = "", name = "Reshape_1", T = "tfdtype$DT_FLOAT", Tshape = "tfdtype$DT_INT32"} : (tensor<10000xf32>, tensor<2xi32>) -> (tensor<100x100xf32>) - %r5 = "tf.Reshape"(%arg0, %arg3) {T = "tfdtype$DT_FLOAT", Tshape = "tfdtype$DT_INT32"} : (tensor<*xf32>, tensor<*xi32>) -> (tensor<*xf32>) - %r6 = "tf.Reshape"(%arg2, %arg3) {T = "tfdtype$DT_FLOAT", Tshape = "tfdtype$DT_INT32"} : (tensor<10000xf32>, tensor<*xi32>) -> (tensor<*xf32>) - return %r1, %r2, %r3, %r4, %r5, %r6: tensor<100x100xf32>, tensor<*xf32>, tensor<10000xf32>, tensor<100x100xf32>, tensor<*xf32>, tensor<*xf32> + %r4 = "tf.Reshape"(%arg2, %shape3) : (tensor<10000xf32>, tensor<2xi32>) -> tensor<100x100xf32> + %r5 = "tf.Reshape"(%arg0, %arg3) : (tensor<*xf32>, tensor<*xi32>) -> tensor<*xf32> + %r6 = "tf.Reshape"(%arg2, %arg3) : (tensor<10000xf32>, tensor<*xi32>) -> tensor<*xf32> + return %r1, %r2, %r3, %r4, %r5, %r6: tensor<100x100xf32>, tensor<*xf32>, tensor<100x100xf32>, tensor<100x100xf32>, tensor<*xf32>, tensor<*xf32> } // ----- @@ -228,26 +242,42 @@ func @testReshape(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<1000 func @testReshape(tensor<*xf32>, tensor<*xf32>) -> (tensor<100x100xf32>) { ^bb0(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>): %shape1 = constant dense<100.> : tensor<2xf32> - // expected-error @+1 {{must be tensor of 32/64-bit signless integer values}} - %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<*xf32>, tensor<2xf32>) -> (tensor<100x100xf32>) + // expected-error @+1 {{must be tensor of 32/64-bit signed integer values}} + %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<*xf32>, tensor<2xf32>) -> tensor<100x100xf32> return %r1 : tensor<100x100xf32> } // ----- // tf.Reshape with incorrect element number. -func @testReshape(%arg0: tensor<10x10x10xf32>) -> tensor<100x100xf32> { - %shape1 = constant dense<100> : tensor<2xi32> - // expected-error @+1 {{number of output elements (10000) does not match expected number of elements (1000)}} - %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<10x10x10xf32>, tensor<2xi32>) -> (tensor<100x100xf32>) +func @testReshape(%arg0: tensor<10x10x10xf32>, %shape1: tensor<2xi32>) -> tensor<100x100xf32> { + // expected-error @+1 {{requires 'output' number of elements to match 'tensor' number of elements, but got 10000 and 1000}} + %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<10x10x10xf32>, tensor<2xi32>) -> tensor<100x100xf32> return %r1 : tensor<100x100xf32> } +// ----- +// tf.Reshape with incorrect shape operand rank. +func @testReshape(%arg0: tensor<10x10x10xf32>, %shape1: tensor<2x2xi32>) -> tensor<*xf32> { + // expected-error @+1 {{requires 'shape' to be rank 1, but got 2}} + %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<10x10x10xf32>, tensor<2x2xi32>) -> tensor<*xf32> + return %r1 : tensor<*xf32> +} + // ----- // tf.Reshape with more than one -1 in the shape. 
func @testReshape(%arg0: tensor<10x10x10x10xf32>) -> tensor<100x100xf32> { %shape1 = constant dense<-1> : tensor<2xi32> - // expected-error @+1 {{more than one component of shape are -1}} - %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<10x10x10x10xf32>, tensor<2xi32>) -> (tensor<100x100xf32>) + // expected-error @+1 {{requires 'shape' to have at most one dynamic dimension, but got multiple dynamic dimensions at indices 0 and 1}} + %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<10x10x10x10xf32>, tensor<2xi32>) -> tensor<100x100xf32> + return %r1 : tensor<100x100xf32> +} + +// ----- +// tf.Reshape with shape operand element < -1. +func @testReshape(%arg0: tensor<10x10x10x10xf32>) -> tensor<100x100xf32> { + %shape1 = constant dense<[100, -2]> : tensor<2xi32> + // expected-error @+1 {{requires 'shape' to have dimensions greater than -1, but got -2 at index 1}} + %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<10x10x10x10xf32>, tensor<2xi32>) -> tensor<100x100xf32> return %r1 : tensor<100x100xf32> } @@ -255,19 +285,68 @@ func @testReshape(%arg0: tensor<10x10x10x10xf32>) -> tensor<100x100xf32> { // tf.Reshape with -1 in the shape can't infer the dimension. func @testReshape(%arg0: tensor<10x10x10x10xf32>) -> tensor<100x100xf32> { %shape1 = constant dense<[101, -1]> : tensor<2xi32> - // expected-error @+1 {{one component of shape is -1 but couldn't infer the dimension}} - %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<10x10x10x10xf32>, tensor<2xi32>) -> (tensor<100x100xf32>) + // expected-error @+1 {{requires 'tensor' number of elements be a multiple of 101, but got 10000}} + %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<10x10x10x10xf32>, tensor<2xi32>) -> tensor<100x100xf32> return %r1 : tensor<100x100xf32> } // ----- -// tf.Reshape with a first operand that has non-static shape. +// tf.Reshape with incorrect output rank. +func @testReshape(%arg0: tensor<10x10xf32>) -> tensor { + %shape1 = constant dense<[100]> : tensor<1xi32> + // expected-error @+1 {{requires 'output' type 'tensor' to be cast compatible with expected type 'tensor<100xf32>'}} + %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<10x10xf32>, tensor<1xi32>) -> tensor + return %r1 : tensor +} + +// ----- +// tf.Reshape with incorrect output dimension. +func @testReshape(%arg0: tensor<1000xf32>) -> tensor { + %shape1 = constant dense<[10, 10, 10]> : tensor<3xi32> + // expected-error @+1 {{requires 'output' type 'tensor' to be cast compatible with expected type 'tensor<10x10x10xf32>'}} + %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<1000xf32>, tensor<3xi32>) -> tensor + return %r1 : tensor +} + +// ----- +// tf.Reshape with a shape operand that has 0 for one of its elements. +func @testReshape(%arg0: tensor<10x10x10xf32>) -> tensor { + %shape1 = constant dense<[-1, 0]> : tensor<2xi32> + %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<10x10x10xf32>, tensor<2xi32>) -> tensor + return %r1 : tensor +} + +// ----- +// tf.Reshape with a tensor operand that has 0 for one of its elements. +func @testReshape(%arg0: tensor<10x10x0xf32>) -> tensor { + %shape1 = constant dense<[-1, 0]> : tensor<2xi32> + %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<10x10x0xf32>, tensor<2xi32>) -> tensor + return %r1 : tensor +} + +// ----- +// tf.Reshape with a tensor operand that has non-static shape. 
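When the operand has a dynamic dimension, or the shape value is not a constant, the element-count comparison cannot be performed, so the verifier conservatively accepts the op; the remaining cases only check that nothing is rejected. A sketch of an accepted pattern (the operand %arg0 and the chosen target shape are illustrative):

// With 10 x 10 x ? elements the total count is unknown, so it cannot be
// checked against the 100 elements implied by the target shape [4, 25].
%shape = constant dense<[4, 25]> : tensor<2xi32>
%r = "tf.Reshape"(%arg0, %shape) : (tensor<10x10x?xf32>, tensor<2xi32>) -> tensor<4x25xf32>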
func @testReshape(%arg0: tensor<10x10x?xf32>) -> tensor<10x10xf32> { %shape1 = constant dense<[10, 10]> : tensor<2xi32> - %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<10x10x?xf32>, tensor<2xi32>) -> (tensor<10x10xf32>) + %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<10x10x?xf32>, tensor<2xi32>) -> tensor<10x10xf32> return %r1 : tensor<10x10xf32> } +// ----- +// tf.Reshape with tensor operand that has non-static shape and shape operand +// with static shape. +func @testReshape(%arg0: tensor<10x10x?xf32>, %shape1: tensor<2xi32>) -> tensor<100x100xf32> { + %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<10x10x?xf32>, tensor<2xi32>) -> tensor<100x100xf32> + return %r1 : tensor<100x100xf32> +} + +// ----- +// tf.Reshape with tensor and shape operands with static shape. +func @testReshape(%arg0: tensor<10x10x10x10xf32>, %shape1: tensor<2xi32>) -> tensor<100x100xf32> { + %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<10x10x10x10xf32>, tensor<2xi32>) -> tensor<100x100xf32> + return %r1 : tensor<100x100xf32> +} + // ----- // CHECK-LABEL: func @testValidAvgPool @@ -780,7 +859,7 @@ func @testIfElse(tensor<2xf32>) -> tensor<2xf32> // Test invalid tf.If operation func @testInvalidIfOp(tensor, tensor<2xf32>) -> tensor<2xf32> { ^bb0(%arg0: tensor, %arg1: tensor<2xf32>): - // expected-error @+1 {{expects all branches to have 1 input(s), but 'then_branch' has 2 input(s)}} + // expected-error @+1 {{'tf.If' op 'then_branch' inputs (size = 2) should have the same number of values as inputs (size = 1)}} %1 = "tf.If"(%arg0, %arg1) { then_branch = @testIfThen, else_branch = @testIfElse, @@ -798,7 +877,7 @@ func @testIfElse(tensor<2xf32>) -> tensor<2xf32> // Test invalid tf.If operation func @testInvalidIfOp(tensor, tensor<2xf32>) -> tensor<2xf32> { ^bb0(%arg0: tensor, %arg1: tensor<2xf32>): - // expected-error @+1 {{expects all branches to have 1 result(s), but 'then_branch' has 2 result(s)}} + // expected-error @+1 {{'tf.If' op 'then_branch' results (size = 2) should have the same number of values as results (size = 1)}} %1 = "tf.If"(%arg0, %arg1) { then_branch = @testIfThen, else_branch = @testIfElse, @@ -816,7 +895,7 @@ func @testIfElse(tensor<*xf32>) -> tensor<*xf32> // Test invalid tf.If operation func @testInvalidIfOp(tensor, tensor<2xf32>) -> tensor<2xf32> { ^bb0(%arg0: tensor, %arg1: tensor<2xf32>): - // expected-error @+1 {{expects operand type 'tensor<2xf32>' to be cast compatible with 'then_branch' input type 'tensor<*xf16>' at index 0}} + // expected-error @+1 {{'tf.If' op 'then_branch' input type tensor<*xf16> is incompatible with input type tensor<2xf32> at index 0}} %1 = "tf.If"(%arg0, %arg1) { then_branch = @testIfThen, else_branch = @testIfElse, @@ -852,7 +931,7 @@ func @testIfElse(tensor<*xf32>) -> tensor<3xf32> // Test invalid tf.If operation func @testInvalidIfOp(tensor, tensor<*xf32>) -> tensor<2xf32> { ^bb0(%arg0: tensor, %arg1: tensor<*xf32>): - // expected-error @+1 {{expects result type 'tensor<2xf32>' to be cast compatible with 'else_branch' result type 'tensor<3xf32>' at index 0}} + // expected-error @+1 {{'tf.If' op 'else_branch' result type tensor<3xf32> is incompatible with result type tensor<2xf32> at index 0}} %1 = "tf.If"(%arg0, %arg1) { then_branch = @testIfThen, else_branch = @testIfElse, @@ -1000,7 +1079,7 @@ func @testIfRegionElseTerminator(%arg0: tensor, %arg1: tensor<2xf32>) -> ten // tf.Region yield number of results should match op number of results func @testIfRegionThenResultCount(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { - // expected-error @+1 
{{'tf.IfRegion' op then should have same number (1) of results as tf.IfRegion but has 2 results}} + // expected-error @+1 {{'tf.IfRegion' op then results (size = 2) should have the same number of values as results (size = 1)}} %0 = "tf.IfRegion"(%arg0) ({ %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> "tf.Yield"(%t, %t) : (tensor<2xf32>, tensor<2xf32>) -> () @@ -1015,7 +1094,7 @@ func @testIfRegionThenResultCount(%arg0: tensor, %arg1: tensor<2xf32>) -> te // ----- func @testIfRegionElseResultCount(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { - // expected-error @+1 {{tf.IfRegion' op else should have same number (1) of results as tf.IfRegion but has 2 results}} + // expected-error @+1 {{'tf.IfRegion' op else results (size = 2) should have the same number of values as results (size = 1)}} %0 = "tf.IfRegion"(%arg0) ({ %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> "tf.Yield"(%t) : (tensor<2xf32>) -> () @@ -1031,7 +1110,7 @@ func @testIfRegionElseResultCount(%arg0: tensor, %arg1: tensor<2xf32>) -> te // tf.IfRegion yield types should match op result types func @testIfRegionOpYieldMismatchThen(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { - // expected-error @+1 {{then result type tensor is incompatible with tf.IfRegion result type tensor<2xf32> at index 0}} + // expected-error @+1 {{'tf.IfRegion' op then result type tensor is incompatible with result type tensor<2xf32> at index 0}} %0 = "tf.IfRegion"(%arg0) ({ "tf.Yield"(%arg0) : (tensor) -> () }, { @@ -1045,7 +1124,7 @@ func @testIfRegionOpYieldMismatchThen(%arg0: tensor, %arg1: tensor<2xf32>) - // ----- func @testIfRegionOpYieldMismatchElse(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { - // expected-error @+1 {{else result type tensor is incompatible with tf.IfRegion result type tensor<2xf32> at index 0}} + // expected-error @+1 {{'tf.IfRegion' op else result type tensor is incompatible with result type tensor<2xf32> at index 0}} %0 = "tf.IfRegion"(%arg0) ({ %t = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> "tf.Yield"(%t) : (tensor<2xf32>) -> () @@ -1434,6 +1513,110 @@ func @testSoftmaxCrossEntropyWithLogits(%arg0: tensor<3xf32>, %arg1: tensor<3xf3 // ----- +//===--------------------------------------------------------------------===// +// tf.SpaceToBatchND +//===--------------------------------------------------------------------===// + +// Test valid tf.SpaceToBatchND +// CHECK-LABEL: func @testSpaceToBatchND +func @testSpaceToBatchND(%input: tensor<3x5x7x10xf32>, %block_shape: tensor<2xi64>, %paddings: tensor<2x2xi64>) -> tensor { + %0 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5x7x10xf32>, tensor<2xi64>, tensor<2x2xi64>) -> tensor + return %0 : tensor +} + +// ----- + +// Test valid tf.SpaceToBatchND +// CHECK-LABEL: func @testSpaceToBatchND +func @testSpaceToBatchND(%input: tensor<3x5x7x10xf32>) -> tensor<36x2x3x10xf32> { + %block_shape = "tf.Const"() {value = dense<[4, 3]> : tensor<2xi64>} : () -> tensor<2xi64> + %paddings = "tf.Const"() {value = dense<[[1, 2], [1, 1]]> : tensor<2x2xi64>} : () -> tensor<2x2xi64> + %0 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5x7x10xf32>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<36x2x3x10xf32> + return %0 : tensor<36x2x3x10xf32> +} + +// ----- + +// Test invalid tf.SpaceToBatchND +func @testSpaceToBatchND(%input: tensor<3x5x7x10xf32>, %block_shape: tensor<2x2xi64>, %paddings: tensor<2x2xi64>) -> tensor { + // expected-error @+1 {{requires rank of block_shape = 1; got 2}} + %0 = 
"tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5x7x10xf32>, tensor<2x2xi64>, tensor<2x2xi64>) -> tensor + return %0 : tensor +} + +// ----- + +// Test invalid tf.SpaceToBatchND +func @testSpaceToBatchND(%input: tensor<3x5x7x10xf32>, %block_shape: tensor<2xi64>, %paddings: tensor<2xi64>) -> tensor { + // expected-error @+1 {{requires rank of paddings = 2; got 1}} + %0 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5x7x10xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor + return %0 : tensor +} + +// ----- + +// Test invalid tf.SpaceToBatchND +func @testSpaceToBatchND(%input: tensor<3x5x7x10xf32>, %block_shape: tensor<2xi64>, %paddings: tensor<2x10xi64>) -> tensor { + // expected-error @+1 {{requires paddings.shape[1] to be 2; got 10}} + %0 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5x7x10xf32>, tensor<2xi64>, tensor<2x10xi64>) -> tensor + return %0 : tensor +} + +// ----- + +// Test invalid tf.SpaceToBatchND +func @testSpaceToBatchND(%input: tensor<3x5x7x10xf32>, %block_shape: tensor<4xi64>, %paddings: tensor<2x2xi64>) -> tensor { + // expected-error @+1 {{requires block_shape.shape[0] must equal paddings.shape[0]}} + %0 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5x7x10xf32>, tensor<4xi64>, tensor<2x2xi64>) -> tensor + return %0 : tensor +} + +// ----- + +// Test invalid tf.SpaceToBatchND +func @testSpaceToBatchND(%input: tensor<3x5xf32>, %block_shape: tensor<2xi64>, %paddings: tensor<2x2xi64>) -> tensor { + // expected-error @+1 {{requires rank of input >= 1 + rank of block}} + %0 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5xf32>, tensor<2xi64>, tensor<2x2xi64>) -> tensor + return %0 : tensor +} + +// ----- + +// Test invalid tf.SpaceToBatchND +func @testSpaceToBatchND(%input: tensor<3x5x7x10xf32>, %paddings: tensor<2x2xi64>) -> tensor { + %block_shape = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> + // expected-error @+1 {{requires all values of block_shape to be >= 1; failed for dimension 1}} + %1 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5x7x10xf32>, tensor<2xi64>, tensor<2x2xi64>) -> tensor + return %1 : tensor +} + +// ----- + +// Test invalid tf.SpaceToBatchND +func @testSpaceToBatchND(%input: tensor<3x5x7x10xf32>, %block_shape: tensor<2xi64>) -> tensor { + %paddings = "tf.Const"() {value = dense<[[1, 0], [-1, 0]]> : tensor<2x2xi64>} : () -> tensor<2x2xi64> + // expected-error @+1 {{requires all values of paddings to be >= 0; failed for dimension 1}} + %1 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5x7x10xf32>, tensor<2xi64>, tensor<2x2xi64>) -> tensor + return %1 : tensor +} + +// ----- + +// Test invalid tf.SpaceToBatchND +func @testSpaceToBatchND(%input: tensor<3x5x7x10xf32>) -> tensor<36x2x3x10xf32> { + %block_shape = "tf.Const"() {value = dense<[4, 3]> : tensor<2xi64>} : () -> tensor<2xi64> + %paddings = "tf.Const"() {value = dense<[[1, 2], [1, 2]]> : tensor<2x2xi64>} : () -> tensor<2x2xi64> + // expected-error @+1 {{requires block_shape[i] divides input_shape[i + 1] + paddings[i, 0] + paddings[i, 1]; failed for i=1}} + %1 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5x7x10xf32>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<36x2x3x10xf32> + return %1 : tensor<36x2x3x10xf32> +} + +// ----- + +//===--------------------------------------------------------------------===// +// tf.SparseSoftmaxCrossEntropyWithLogits 
+//===--------------------------------------------------------------------===// + // Test valid tf.SparseSoftmaxCrossEntropyWithLogits // CHECK-LABEL: func @testSparseSoftmaxCrossEntropyWithLogits func @testSparseSoftmaxCrossEntropyWithLogits(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>) -> (tensor<3xf32>, tensor<2x3xf32>) { @@ -1527,7 +1710,7 @@ func @testWhileBody(tensor<*xf32>) -> (tensor<*xf32>) // Test invalid 'While' operation func @testWhileResult(tensor<*xf32>) -> (tensor<*xi32>) { ^bb0(%arg0: tensor<*xf32>): - // expected-error @+1 {{operand type tensor<*xf32> is incompatible with result type}} + // expected-error @+1 {{'tf.While' op input type tensor<*xf32> is incompatible with result type tensor<*xi32> at index 0}} %1 = "tf.While"(%arg0) { cond = @testWhileCond, body = @testWhileBody, @@ -1545,7 +1728,7 @@ func @testWhileBody(tensor<*xf32>) -> (tensor<*xf32>) // Test invalid 'While' operation func @testWhileResult(tensor<*xf32>) -> (tensor<*xf32>) { ^bb0(%arg0: tensor<*xf32>): - // expected-error @+1 {{operand type tensor<*xf32> is incompatible with cond function input type}} + // expected-error @+1 {{'tf.While' op input type tensor<*xf32> is incompatible with condition input type tensor<*xi32> at index 0}} %1 = "tf.While"(%arg0) { cond = @testWhileCond, body = @testWhileBody, @@ -1563,7 +1746,7 @@ func @testWhileBody(tensor<*xf32>, tensor<*xf32>) -> (tensor<*xf32>) // Test invalid 'While' operation func @testWhileResult(tensor<*xf32>) -> (tensor<*xf32>) { ^bb0(%arg0: tensor<*xf32>): - // expected-error @+1 {{requires the number of operands to be equal to the number of body function inputs. Found 1 and 2, respectively}} + // expected-error @+1 {{'tf.While' op inputs (size = 1) should have the same number of values as body inputs (size = 2)}} %1 = "tf.While"(%arg0) { cond = @testWhileCond, body = @testWhileBody, @@ -1581,7 +1764,7 @@ func @testWhileBody(tensor<*xf32>) -> (tensor<*xi32>) // Test invalid 'While' operation func @testWhileResult(tensor<*xf32>) -> (tensor<*xf32>) { ^bb0(%arg0: tensor<*xf32>): - // expected-error @+1 {{body function result type tensor<*xi32> is incompatible with result type}} + // expected-error @+1 {{'tf.While' op body result type tensor<*xi32> is incompatible with result type tensor<*xf32> at index 0}} %1 = "tf.While"(%arg0) { cond = @testWhileCond, body = @testWhileBody, @@ -1599,7 +1782,7 @@ func @testWhileBody(tensor<4xf32>) -> (tensor<*xf32>) // Test invalid 'While' operation func @testWhileResult(tensor<*xf32>) -> (tensor<*xf32>) { ^bb0(%arg0: tensor<*xf32>): - // expected-error @+1 {{cond function input type tensor<3xf32> is incompatible with body function input type}} + // expected-error @+1 {{'tf.While' op condition input type tensor<3xf32> is incompatible with body input type tensor<4xf32> at index 0}} %1 = "tf.While"(%arg0) { cond = @testWhileCond, body = @testWhileBody, @@ -1618,7 +1801,7 @@ func @testWhileBody(tensor<*x!tf.resource>>) -> (tensor>>) -> (tensor>>) { ^bb0(%arg0: tensor<*x!tf.resource>>): - // expected-error @+1 {{operand type tensor<*x!tf.resource>> is incompatible with result type}} + // expected-error @+1 {{'tf.While' op input type tensor<*x!tf.resource>> is incompatible with result type tensor>> at index 0}} %1 = "tf.While"(%arg0) { cond = @testWhileCond, body = @testWhileBody, @@ -1714,48 +1897,71 @@ func @testValidWhileRegionNoInputs() -> () { } // ----- +// Invalid while tests. There are 5 sets of type matching that is required +// I = input, O = output, BI, BO = body input/output, CI = cond input. 
+// [I, O], [I, CI], [I, BI], [BO, BI], [BO, O]. +// Each check can fail due to number or type mismatch. However, these +// conditions are not all independent. So we just check I->{CI, BI}, O->BO, and +// in addition I->O. BO->BI mismatch cannot be independently created without +// breaking one of these mismatches. That gives us 4x2 tests. In addition +// condition result needs to be tensor, for which we have 3 +// additional validation tests. All these tests are based on the following +// valid while -func @testInvalidWhileRegionMismatchCondInputCount(%arg : tensor) -> (tensor) { - // expected-error @+1 {{'tf.WhileRegion' op condition should have same number of inputs (1) as tf.WhileRegion but has 0 inputs}} - %0 = "tf.WhileRegion"(%arg) ( - { - // ^bb0(%carg: tensor): - %true = constant dense<1> : tensor - "tf.Yield"(%true) : (tensor) -> () - }, - { - ^bb0(%barg: tensor): - "tf.Yield"(%arg) : (tensor) -> () - } - ) : (tensor) -> (tensor) +func @testInvalidTestValidBase(%arg0 : tensor) -> (tensor) { + %0 = "tf.WhileRegion"(%arg0) ( + { + ^bb0(%carg: tensor): + %false = constant dense : tensor + "tf.Yield"(%false) : (tensor) -> () + }, + { + ^bb0(%barg: tensor): + "tf.Yield"(%barg) : (tensor) -> () + } + ) { is_stateless = true } : (tensor) -> (tensor) + return %0 : tensor +} +func @testInvalidWhileRegion_I_CI_CountMismatch(%arg0 : tensor) -> (tensor) { + // expected-error @+1 {{'tf.WhileRegion' op inputs (size = 1) should have the same number of values as condition inputs (size = 0)}} + %0 = "tf.WhileRegion"(%arg0) ( + { + //^bb0(%carg: tensor): + %false = constant dense : tensor + "tf.Yield"(%false) : (tensor) -> () + }, + { + ^bb0(%barg: tensor): + "tf.Yield"(%barg) : (tensor) -> () + } + ) { is_stateless = true } : (tensor) -> (tensor) return %0 : tensor } // ----- -func @testInvalidWhileRegionMismatchCondInputType(%arg : tensor) -> (tensor) { - // expected-error @+1 {{'tf.WhileRegion' op condition input type tensor is incompatible with tf.WhileRegion input type tensor at index 0}} - %0 = "tf.WhileRegion"(%arg) ( - { - ^bb0(%carg: tensor): - %true = constant dense<1> : tensor - "tf.Yield"(%true) : (tensor) -> () - }, - { - ^bb0(%barg: tensor): - "tf.Yield"(%barg) : (tensor) -> () - } - ) : (tensor) -> (tensor) - +func @testInvalidWhileRegion_I_CI_TypeMismatch(%arg0 : tensor) -> (tensor) { + // expected-error @+1 {{'tf.WhileRegion' op input type tensor is incompatible with condition input type tensor at index 0}} + %0 = "tf.WhileRegion"(%arg0) ( + { + ^bb0(%carg: tensor): + %false = constant dense : tensor + "tf.Yield"(%false) : (tensor) -> () + }, + { + ^bb0(%barg: tensor): + "tf.Yield"(%barg) : (tensor) -> () + } + ) { is_stateless = true } : (tensor) -> (tensor) return %0 : tensor } // ----- -func @testInvalidWhileRegionMismatchBodyInputCount(%arg : tensor) -> (tensor) { - // expected-error @+1 {{'tf.WhileRegion' op body should have same number of inputs (1) as tf.WhileRegion but has 2 inputs}} - %0 = "tf.WhileRegion"(%arg) ( +func @testInvalidWhileRegion_I_BI_CountMismatch(%arg0 : tensor) -> (tensor) { + // expected-error @+1 {{'tf.WhileRegion' op inputs (size = 1) should have the same number of values as body inputs (size = 2)}} + %0 = "tf.WhileRegion"(%arg0) ( { ^bb0(%carg: tensor): %true = constant dense<1> : tensor @@ -1772,9 +1978,9 @@ func @testInvalidWhileRegionMismatchBodyInputCount(%arg : tensor) -> (tenso // ----- -func @testInvalidWhileRegionMismatchBodyInputType(%arg : tensor) -> (tensor) { - // expected-error @+1 {{body input type tensor is incompatible with 
tf.WhileRegion input type tensor at index 0}} - %0 = "tf.WhileRegion"(%arg) ( +func @testInvalidWhileRegion_I_BI_TypeMismatch(%arg0 : tensor) -> (tensor) { + // expected-error @+1 {{'tf.WhileRegion' op input type tensor is incompatible with body input type tensor at index 0}} + %0 = "tf.WhileRegion"(%arg0) ( { ^bb0(%carg: tensor): %true = constant dense<1> : tensor @@ -1792,6 +1998,77 @@ func @testInvalidWhileRegionMismatchBodyInputType(%arg : tensor) -> (tensor // ----- +func @testInvalidWhileRegion_O_BO_CountMismatch(%arg0 : tensor) -> (tensor) { + // expected-error @+1 {{'tf.WhileRegion' op body results (size = 2) should have the same number of values as results (size = 1)}} + %0 = "tf.WhileRegion"(%arg0) ( + { + ^bb0(%carg: tensor): + %false = constant dense : tensor + "tf.Yield"(%false) : (tensor) -> () + }, + { + ^bb0(%barg: tensor): + "tf.Yield"(%barg, %barg) : (tensor, tensor) -> () + } + ) { is_stateless = true } : (tensor) -> (tensor) + return %0#0 : tensor +} + +// ----- + +func @testInvalidWhileRegionMismatch_O_BO_TypeMismatch(%arg0 : tensor, %arg1: tensor) -> (tensor) { + // expected-error @+1 {{'tf.WhileRegion' op body result type tensor is incompatible with result type tensor at index 0}} + %0 = "tf.WhileRegion"(%arg0) ( + { + ^bb0(%carg: tensor): + %false = constant dense : tensor + "tf.Yield"(%false) : (tensor) -> () + }, + { + ^bb0(%barg: tensor): + "tf.Yield"(%arg1) : (tensor) -> () + } + ) { is_stateless = true } : (tensor) -> (tensor) + return %0 : tensor +} + +// ----- + +func @testInvalidWhileRegion_I_O_CountMismatch(%arg0 : tensor) -> (tensor) { + // expected-error@+1 {{'tf.WhileRegion' op inputs (size = 1) should have the same number of values as results (size = 2)}} + %0:2 = "tf.WhileRegion"(%arg0) ( + { + ^bb0(%carg: tensor): + %false = constant dense : tensor + "tf.Yield"(%false) : (tensor) -> () + }, + { + ^bb0(%barg: tensor): + "tf.Yield"(%barg, %barg) : (tensor, tensor) -> () + } + ) { is_stateless = true } : (tensor) -> (tensor, tensor) + return %0#0 : tensor +} + +// ----- + +func @testInvalidWhileRegion_I_O_TypeMismatch(%arg0: tensor, %arg1 : tensor) -> (tensor) { + // expected-error@+1 {{'tf.WhileRegion' op input type tensor is incompatible with result type tensor at index 0}} + %0 = "tf.WhileRegion"(%arg0) ( + { + ^bb0(%carg: tensor): + %false = constant dense : tensor + "tf.Yield"(%false) : (tensor) -> () + }, + { + ^bb0(%barg: tensor): + "tf.Yield"(%arg1) : (tensor) -> () + } + ) { is_stateless = true } : (tensor) -> (tensor) + return %0 : tensor +} +// ----- + func @testInvalidWhileRegionConditionOutputCount2(%arg : tensor) -> (tensor) { // expected-error @+1 {{'tf.WhileRegion' op condition should have a single tensor result}} %0 = "tf.WhileRegion"(%arg) ( @@ -1845,45 +2122,6 @@ func @testInvalidWhileRegionConditionOutputType(%arg : tensor) -> (tensor } -// ----- - -func @testInvalidWhileRegionMismatchBodyOutputCount(%arg : tensor) -> (tensor) { - // expected-error @+1 {{'tf.WhileRegion' op body should have same number (1) of results as tf.WhileRegion but has 2 results}} - %0 = "tf.WhileRegion"(%arg) ( - { - ^bb0(%carg: tensor): - %true = constant dense<1> : tensor - "tf.Yield"(%true) : (tensor) -> () - }, - { - ^bb0(%barg: tensor): - %false = constant dense<1> : tensor - "tf.Yield"(%barg, %false) : (tensor, tensor) -> () - } - ) : (tensor) -> (tensor) - - return %0 : tensor -} - -// ----- - -func @testInvalidWhileRegionMismatchBodyOutputType(%arg : tensor) -> (tensor) { - // expected-error @+1 {{body result type tensor is incompatible with 
tf.WhileRegion result type tensor at index 0}} - %0 = "tf.WhileRegion"(%arg) ( - { - ^bb0(%carg: tensor): - %true = constant dense<1> : tensor - "tf.Yield"(%true) : (tensor) -> () - }, - { - ^bb0(%barg: tensor): - %c = "tf.Cast"(%barg) : (tensor) -> tensor - "tf.Yield"(%c) : (tensor) -> () - } - ) : (tensor) -> (tensor) - - return %0 : tensor -} // ----- @@ -1898,7 +2136,7 @@ func @testValidShape(tensor<1x32x32x16xf32>, tensor<*xf32>) -> (tensor<4xi32>, t // ----- func @testShapeWrongResultElemType(%arg0: tensor<1x32x32x16xf32>) -> tensor<4xf32> { - // expected-error @+1 {{result #0 must be tensor of 32/64-bit signless integer values}} + // expected-error @+1 {{result #0 must be tensor of 32/64-bit signed integer values}} %0 = "tf.Shape"(%arg0) : (tensor<1x32x32x16xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -1942,7 +2180,7 @@ func @testValidShapeN(%arg0 : tensor<1x32x32x16xf32>, %arg1 : tensor<*xf32>) -> // ----- func @testShapeNWrongResultElemType(%arg0: tensor<1x32x32x16xf32>) -> tensor<4xf32> { - // expected-error @+1 {{result #1 must be tensor of 32/64-bit signless integer values}} + // expected-error @+1 {{result #1 must be tensor of 32/64-bit signed integer values}} %0:2 = "tf.ShapeN"(%arg0, %arg0) : (tensor<1x32x32x16xf32>, tensor<1x32x32x16xf32>) -> (tensor<4xi32>, tensor<4xf32>) return %0#1 : tensor<4xf32> } @@ -2003,7 +2241,7 @@ func @testVariableShapeMultipleSubtypes(%arg0: tensor<*x!tf.resource>>) -> tensor { - // expected-error @+1 {{result #0 must be tensor of 32/64-bit signless integer values}} + // expected-error @+1 {{result #0 must be tensor of 32/64-bit signed integer values}} %0 = "tf.VariableShape"(%arg0) : (tensor<*x!tf.resource>>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -2139,7 +2377,7 @@ func @testTranspose(tensor<2x3x4xf32>) -> tensor<3x2x4xf32> { // Test invalid tf.Less func @testLess(tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> { ^bb0(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>): - // expected-error @+1 {{op result #0 must be tensor of 1-bit signless integer values}} + // expected-error @+1 {{op result #0 must be tensor of bool values}} %0 = "tf.Less"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> return %0 : tensor<4xi32> } @@ -2156,7 +2394,7 @@ func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor) -> tensor // tf.ConcatV2 with wrong 'axis' element type func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor) -> tensor { - // expected-error @+1 {{operand #2 must be tensor of 32/64-bit signless integer values}} + // expected-error @+1 {{operand #2 must be tensor of 32/64-bit signed integer values}} %0 = "tf.ConcatV2"(%arg, %arg, %axis) : (tensor<8x16xf32>, tensor<8x16xf32>, tensor) -> tensor return %0 : tensor } @@ -2189,7 +2427,7 @@ func @testAll64(%arg0: tensor<2x2xi1>, %arg1: tensor) -> tensor { // ----- func @testAllFloat(%arg0: tensor<2x2xi1>, %arg1: tensor) -> tensor { - // expected-error @+1 {{'tf.All' op operand #1 must be tensor of 32/64-bit signless integer values}} + // expected-error @+1 {{'tf.All' op operand #1 must be tensor of 32/64-bit signed integer values}} %0 = "tf.All"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor) -> tensor return %0 : tensor } @@ -2197,7 +2435,7 @@ func @testAllFloat(%arg0: tensor<2x2xi1>, %arg1: tensor) -> tensor { // ----- func @testAllI32(%arg0: tensor<2x2xi32>, %arg1: tensor) -> tensor { - // expected-error @+1 {{'tf.All' op operand #0 must be tensor of 1-bit signless integer values}} + // expected-error @+1 {{'tf.All' op operand #0 must be tensor of bool values}} 
%0 = "tf.All"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi32>, tensor) -> tensor return %0 : tensor } @@ -2381,6 +2619,25 @@ func @testSlice_unknown_begin_in_bounds(%arg0: tensor<4xi32>, %begins: tensor<1x // ----- +func @testSlice_unequal_output_input_rank(%arg0: tensor<4xi32>, %begins: tensor<1xi64>) -> tensor { + %sizes = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>) + // expected-error @+1 {{requires output to have the same rank as input, but got input rank 1 and output rank 0}} + %0 = "tf.Slice"(%arg0, %begins, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor + return %0 : tensor +} + +// ----- + +func @testSlice_wrong_output_size(%arg0: tensor<4xi32>) -> tensor<1xi32> { + %begins = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>) + %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi64>} : () -> (tensor<1xi64>) + // expected-error @+1 {{requires output size to have the same size of slice, got slice size 2 and output size 1}} + %0 = "tf.Slice"(%arg0, %begins, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<1xi32> + return %0 : tensor<1xi32> +} + +// ----- + // Valid StridedSlice operation. func @testStridedSlice(%input: tensor<4x8xf32>, %begin: tensor<2xi64>, %end: tensor<2xi64>, %strides: tensor<2xi64>) -> tensor { %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor @@ -2660,6 +2917,13 @@ func @testSplitV2(%input: tensor<4x4xf32>) { // ----- +func @testSplitVDynamic(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor, tensor) { + %0:2 = "tf.SplitV"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> (tensor, tensor) + return %0#0, %0#1 : tensor, tensor +} + +// ----- + //===--------------------------------------------------------------------===// // tf.All //===--------------------------------------------------------------------===// @@ -3165,6 +3429,125 @@ func @testBatchMatMulV2(%lhs: tensor<10x10xf32>, %rhs: tensor) { // ----- +// CHECK-LABEL: func @testBatchMatMulV2NoBatchDimension +func @testBatchMatMulV2NoBatchDimension(%lhs: tensor<5x10xf32>, %rhs: tensor<10x10xf32>) -> (tensor<5x10xf32>) { + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<5x10xf32>, tensor<10x10xf32>) -> tensor<5x10xf32> + return %0 : tensor<5x10xf32> +} + +// ----- + +// CHECK-LABEL: func @testBatchMatMulV2ValidBroadcastingBatchDimension +func @testBatchMatMulV2ValidBroadcastingBatchDimension(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<10x10xf32>) -> (tensor<10x2x5x10xf32>) { + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<10x10xf32>) -> tensor<10x2x5x10xf32> + return %0 : tensor<10x2x5x10xf32> +} + +// ----- + +// CHECK-LABEL: func @testBatchMatMulV2ValidMultiBatchDimension +func @testBatchMatMulV2ValidMultiBatchDimension(%lhs: tensor<4x5x1x3x2xf32>, %rhs: tensor<1x1x3x5xf32>) -> (tensor<4x5x1x2x5xf32>) { + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) { adj_x = true } : (tensor<4x5x1x3x2xf32>, tensor<1x1x3x5xf32>) -> tensor<4x5x1x2x5xf32> + return %0 : tensor<4x5x1x2x5xf32> +} + +// ----- + +func @testBatchMatMulV2InvalidBroadcastingBatchDimensionWithHigherXRank(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<10x10x10xf32>) { + // expected-error @+1 {{found incompatible broadcast batch dimensions for lhs shape 'tensor<10x2x5x10xf32>' and rhs shape 'tensor<10x10x10xf32>'}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<10x10x10xf32>) -> tensor<10x10xf32> +} + +// ----- + +func 
@testBatchMatMulV2InvalidBroadcastingBatchDimensionWithSameRank(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<10x10x10x10xf32>) { + // expected-error @+1 {{found incompatible broadcast batch dimensions for lhs shape 'tensor<10x2x5x10xf32>' and rhs shape 'tensor<10x10x10x10xf32>'}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<10x10x10x10xf32>) -> tensor<10x10xf32> +} + +// ----- + +func @testBatchMatMulV2InvalidBroadcastingBatchDimensionWithHigherYRank(%lhs: tensor<2x5x10xf32>, %rhs: tensor<10x10x10x10xf32>) { + // expected-error @+1 {{found incompatible broadcast batch dimensions for lhs shape 'tensor<2x5x10xf32>' and rhs shape 'tensor<10x10x10x10xf32>'}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<2x5x10xf32>, tensor<10x10x10x10xf32>) -> tensor<10x10xf32> +} + +// ----- + +func @testBatchMatMulV2InvalidOutputBatchDimension(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<2x10x10xf32>) { + // expected-error @+1 {{has mismatching input batch dimension 2 and output batch dimension 3}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<2x10x10xf32>) -> tensor<10x3x10x10xf32> +} + +// ----- + +func @testBatchMatMulV2InvalidOutputRank(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<10x1x10x10xf32>) { + // expected-error @+1 {{found invalid output rank, expected 4 but got 3}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<10x1x10x10xf32>) -> tensor<10x5x10xf32> +} + +// ----- + +func @testBatchMatMulV2InvalidOutputRowDim(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<10x10xf32>) { + // expected-error @+1 {{found invalid output dimension on row, expected 5 but got 10}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<10x10xf32>) -> tensor<10x2x10x10xf32> +} + +// ----- + +func @testBatchMatMulV2AdjXInvalidOutputRowDim(%lhs: tensor<10x2x10x5xf32>, %rhs: tensor<10x10xf32>) { + // expected-error @+1 {{found invalid output dimension on row, expected 5 but got 10}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) { adj_x = true } : (tensor<10x2x10x5xf32>, tensor<10x10xf32>) -> tensor<10x2x10x10xf32> +} + +// ----- + +func @testBatchMatMulV2InvalidOutputColDim(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<10x10xf32>) { + // expected-error @+1 {{found invalid output dimension on col, expected 10 but got 5}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<10x10xf32>) -> tensor<10x2x5x5xf32> +} + +// ----- + +func @testBatchMatMulV2AdjYInvalidOutputColDim(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<4x10xf32>) { + // expected-error @+1 {{found invalid output dimension on col, expected 4 but got 10}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) { adj_y = true } : (tensor<10x2x5x10xf32>, tensor<4x10xf32>) -> tensor<10x2x5x10xf32> +} + +// ----- + +// CHECK-LABEL: func @testBatchMatMulV2PartiallyKnownInputBatchDim +func @testBatchMatMulV2PartiallyKnownInputBatchDim(%lhs: tensor<4x5x?x3x2xf32>, %rhs: tensor<1x1x3x5xf32>) -> (tensor<4x5x?x2x5xf32>) { + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) { adj_x = true } : (tensor<4x5x?x3x2xf32>, tensor<1x1x3x5xf32>) -> tensor<4x5x?x2x5xf32> + return %0 : tensor<4x5x?x2x5xf32> +} + +// ----- + +// CHECK-LABEL: func @testBatchMatMulV2PartiallyKnownMatmulDim +func @testBatchMatMulV2PartiallyKnownMatmulDim(%lhs: tensor<4x5x1x?x3xf32>, %rhs: tensor<1x1x3x5xf32>) -> (tensor<4x5x1x?x5xf32>) { + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<4x5x1x?x3xf32>, tensor<1x1x3x5xf32>) -> tensor<4x5x1x?x5xf32> + return %0 : tensor<4x5x1x?x5xf32> +} + +// ----- + +func @testBatchMatMulV2InvalidPartiallyKnownMatmulDim(%lhs: 
tensor<4x5x1x?x3xf32>, %rhs: tensor<1x1x3x5xf32>) -> (tensor<4x5x1x?x3xf32>) { + // expected-error @+1 {{found invalid output dimension on col, expected 5 but got 3}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<4x5x1x?x3xf32>, tensor<1x1x3x5xf32>) -> tensor<4x5x1x?x3xf32> + return %0 : tensor<4x5x1x?x3xf32> +} + +// ----- + +func @testBatchMatMulV2AdjXInvalidPartiallyKnownMatmulDim(%lhs: tensor<4x5x1x3x?xf32>, %rhs: tensor<1x1x3x5xf32>) -> (tensor<4x5x1x?x3xf32>) { + // expected-error @+1 {{found invalid output dimension on col, expected 5 but got 3}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) { adj_x = true } : (tensor<4x5x1x3x?xf32>, tensor<1x1x3x5xf32>) -> tensor<4x5x1x?x3xf32> + return %0 : tensor<4x5x1x?x3xf32> +} + +// ----- + func @testDataFormatVecPermuteInvalid1dInput(%x: tensor<5xi32>) { // expected-error @+1 {{requires 1D input of size 4}} %0 = "tf.DataFormatVecPermute"(%x): (tensor<5xi32>) -> tensor<5xi32> @@ -3357,7 +3740,7 @@ func @branch0(tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> func @branch1(tensor<2xf32>) -> tensor<2xf32> func @testCaseMismatchedNumOperands(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { - // expected-error @+1 {{expects all branches to have 1 input(s), but branch #0 has 2 input(s)}} + // expected-error @+1 {{'tf.Case' op branch #0 inputs (size = 2) should have the same number of values as inputs (size = 1)}} %0 = "tf.Case"(%arg0, %arg1) {branches = [@branch0, @branch1], is_stateless = false} : (tensor, tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } @@ -3368,7 +3751,7 @@ func @branch0(tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) func @branch1(tensor<2xf32>) -> tensor<2xf32> func @testCaseMismatchedNumResults(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { - // expected-error @+1 {{expects all branches to have 1 result(s), but branch #0 has 2 result(s)}} + // expected-error @+1 {{'tf.Case' op branch #0 results (size = 2) should have the same number of values as results (size = 1)}} %0 = "tf.Case"(%arg0, %arg1) {branches = [@branch0, @branch1], is_stateless = false} : (tensor, tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } @@ -3379,7 +3762,7 @@ func @branch0(tensor<*xf16>) -> tensor<*xf32> func @branch1(tensor<*xf32>) -> tensor<*xf32> func @testCaseOperandNotCastCompatible(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { - // expected-error @+1 {{expects operand type 'tensor<2xf32>' to be cast compatible with branch #0 input type 'tensor<*xf16>' at index 0}} + // expected-error @+1 {{'tf.Case' op branch #0 input type tensor<*xf16> is incompatible with input type tensor<2xf32> at index 0}} %0 = "tf.Case"(%arg0, %arg1) {branches = [@branch0, @branch1], is_stateless = false} : (tensor, tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } @@ -3401,7 +3784,7 @@ func @branch0(tensor<*xf32>) -> tensor<*xf32> func @branch1(tensor<*xf32>) -> tensor<3xf32> func @testCaseResultNotCastCompatible(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<2xf32> { - // expected-error @+1 {{expects result type 'tensor<2xf32>' to be cast compatible with branch #1 result type 'tensor<3xf32>' at index 0}} + // expected-error @+1 {{'tf.Case' op branch #1 result type tensor<3xf32> is incompatible with result type tensor<2xf32> at index 0}} %0 = "tf.Case"(%arg0, %arg1) {branches = [@branch0, @branch1], is_stateless = false} : (tensor, tensor<*xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } @@ -3427,7 +3810,7 @@ func @testCaseRegionBadBranchIndicesShape(%arg0: tensor<8xi32>) { // ----- func 
@testCaseRegionMismatchedNumResults(%arg0: tensor) { - // expected-error @+1 {{region #0 should have same number (1) of results as tf.CaseRegion but has 0 results}} + // expected-error @+1 {{'tf.CaseRegion' op branch #0 results (size = 0) should have the same number of values as results (size = 1)}} %1 = "tf.CaseRegion"(%arg0) ( { "tf.Yield"() : () -> () }) {is_stateless = false} : (tensor) -> tensor @@ -3437,7 +3820,7 @@ func @testCaseRegionMismatchedNumResults(%arg0: tensor) { // ----- func @testCaseRegionMismatchedResultTypes(%arg0: tensor, %arg1: tensor) { - // expected-error @+1 {{region #0 result type tensor is incompatible with tf.CaseRegion result type tensor at index 0}} + // expected-error @+1 {{'tf.CaseRegion' op branch #0 result type tensor is incompatible with result type tensor at index 0}} %1 = "tf.CaseRegion"(%arg0) ( { "tf.Yield"(%arg1) : (tensor) -> () }) {is_stateless = false} : (tensor) -> tensor @@ -3468,3 +3851,92 @@ func @testCumprod(%arg: tensor<8x16xf32>) -> tensor<8x16xf32> { %0 = "tf.Cumprod"(%arg, %axis) : (tensor<8x16xf32>, tensor) -> tensor<8x16xf32> return %0 : tensor<8x16xf32> } + +// ----- + +func @testTile(%arg0: tensor<2x3x?xf32>) { + %cst = constant dense <[2, 3, 4]> : tensor<3xi32> + %0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3x?xf32>, tensor<3xi32>) -> tensor<4x9x?xf32> + return +} + +// ----- + +func @testTileMultipleNotRank1(%arg0: tensor<2x3xf32>, %arg1: tensor<1x1xi32>) { + // expected-error @+1 {{expected multiples to be rank 1, got rank = 2}} + %0 = "tf.Tile"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<1x1xi32>) -> tensor<2x3xf32> + return +} + +// ----- + +func @testTileInputRankNotEqualToMultiplesSize(%arg0: tensor<2x3xf32>, %arg1: tensor<3xi32>) { + // expected-error @+1 {{expected size of multiples equal to rank of input, got multiples of size 3, and input of rank 2}} + %0 = "tf.Tile"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<3xi32>) -> tensor<2x3xf32> + return +} + +// ----- + +func @testTileInputRankNotEqualToOutputRank(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>) { + // expected-error @+1 {{expected rank of input to equal to rank of output, got input of rank 2, and output of rank 3}} + %0 = "tf.Tile"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x3x1xf32> + return +} + +// ----- + +func @testTileNegativeMultiples(%arg0: tensor<2x3xf32>) { + %cst = constant dense <[-1, 1]> : tensor<2xi32> + // expected-error @+1 {{expected multiples to be non-negative, got multiples[0] = -1}} + %0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x3xf32> + return +} + +// ----- + +func @testTileInvalidOutputShape(%arg0: tensor<2x3xf32>) { + %cst = constant dense <[2, 3]> : tensor<2xi32> + // expected-error @+1 {{requires input.shape[1] (3) * 3 to be equal to output.shape[1] (6)}} + %0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<4x6xf32> + return +} + +// ----- + +// Test reference variable support for some ops (no errors expected) + +// CHECK-LABEL: @testMaximumWithRef +func @testMaximumWithRef(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: tf.Maximum + %0 = "tf.Maximum"(%arg0, %arg1) : (tensor, tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: @testAddV2WithRef +func @testAddV2WithRef(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: tf.AddV2 + %0 = "tf.AddV2"(%arg0, %arg1) : (tensor, tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: @testRealDivWithRef +func @testRealDivWithRef(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: tf.RealDivOp + %0 = 
"tf.RealDivOp"(%arg0, %arg1) : (tensor, tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: @testDivNoNanWithRef +func @testDivNoNanWithRef(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: tf.DivNoNanOp + %0 = "tf.DivNoNanOp"(%arg0, %arg1) : (tensor, tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: @testAddWithRef +func @testAddWithRef(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: tf.Add + %0 = "tf.Add"(%arg0, %arg1) : (tensor, tensor) -> tensor + return %0 : tensor +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_device_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_device_ops.mlir index 745cf72f959..f6f14c5be61 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_device_ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_device_ops.mlir @@ -40,6 +40,19 @@ func @empty_replicate() { // CHECK-NEXT: tf_device.return } +// CHECK-LABEL: func @no_operand_replicate +func @no_operand_replicate() { + tf_device.replicate {n = 2 : i32} { + %0 = "tf.Const"() { value = dense<0> : tensor } : () -> tensor + %1 = "tf.Const"() { value = dense<1> : tensor } : () -> tensor + tf_device.return %0, %1 : tensor, tensor + } + return + // CHECK: tf_device.replicate + // CHECK-SAME: n = 2 + // CHECK: tf_device.return +} + // CHECK-LABEL: func @replicate_with_multiple_operands func @replicate_with_multiple_operands() { %0 = "tf.opA"() : () -> tensor<*xi1> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir index 1e537880620..23a8e904ad9 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir @@ -433,7 +433,7 @@ func @nextiteration(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> { %1:3 = tf_executor.NextIteration.Source : tensor<*xf32> tf_executor.NextIteration.Sink[%1#1] %1#0 : tensor<*xf32> // CHECK: tf_executor.NextIteration.Source : tensor<*xf32> -// CHECK: tf_executor.NextIteration.Sink [%{{.*}}] %{{.*}} : tensor<*xf32> +// CHECK: tf_executor.NextIteration.Sink[%{{.*}}] %{{.*}} : tensor<*xf32> tf_executor.fetch %1#0 : tensor<*xf32> } return %0 : tensor<*xf32> @@ -445,7 +445,7 @@ func @nextiteration_with_attributes(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<* %1:3 = tf_executor.NextIteration.Source : tensor<*xf32> {attr3 = 32 : i64, tf_executor.attr_fetch = "some_value"} tf_executor.NextIteration.Sink[%1#1] %1#0 : tensor<*xf32> {attr4 = 42 : i64, tf_executor.attr_push = "other_value"} // CHECK: tf_executor.NextIteration.Source : tensor<*xf32> {attr3 = 32 : i64, tf_executor.attr_fetch = "some_value"} -// CHECK: tf_executor.NextIteration.Sink [%{{.*}}] %{{.*}} : tensor<*xf32> {attr4 = 42 : i64, tf_executor.attr_push = "other_value"} +// CHECK: tf_executor.NextIteration.Sink[%{{.*}}] %{{.*}} : tensor<*xf32> {attr4 = 42 : i64, tf_executor.attr_push = "other_value"} tf_executor.fetch %1#0 : tensor<*xf32> } return %0 : tensor<*xf32> @@ -457,9 +457,9 @@ func @nextiteration_control(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<* %1:3 = tf_executor.Switch %arg0, %arg1 : tensor<*xf32> %2:2 = tf_executor.Enter %arg0, %1#2, %1#2 frame "some/frame" : tensor<*xf32> %3:3 = tf_executor.NextIteration.Source : tensor<*xf32> - tf_executor.NextIteration.Sink [%3#1] %3#0, %1#2 : tensor<*xf32> + tf_executor.NextIteration.Sink[%3#1] %3#0, %1#2 : tensor<*xf32> // CHECK: tf_executor.NextIteration.Source : tensor<*xf32> -// CHECK: tf_executor.NextIteration.Sink [%{{.*}}] %{{.*}}, %{{.*}} : 
tensor<*xf32> +// CHECK: tf_executor.NextIteration.Sink[%{{.*}}] %{{.*}}, %{{.*}} : tensor<*xf32> tf_executor.fetch %3#0 : tensor<*xf32> } return %0 : tensor<*xf32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/BUILD index 318f0422231..8ba18215ab5 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow:tensorflow.bzl", "filegroup") load("//tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model:build_defs.bzl", "tf_saved_model_test") package( diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/structured_output.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/structured_output.py new file mode 100644 index 00000000000..a6d78d4693b --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/structured_output.py @@ -0,0 +1,125 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# RUN: %p/structured_output | FileCheck %s + +# pylint: disable=missing-docstring,line-too-long +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow.compat.v2 as tf +from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common + + +class TestModule(tf.Module): + # The fNNNN name prefixes in this file are such that the sorted order of the + # functions in the resulting MLIR output match the order in the source file, + # allowing us to conveniently co-locate the CHECK's with the code they are + # checking. + # + # Note: CHECK-DAG doesn't work with CHECK-SAME/CHECK-NEXT. + + # Check index paths for results. + # + # CHECK: func {{@[a-zA-Z_0-9]+}}() -> ( + # CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = []}) + # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0000_single_return"] + @tf.function(input_signature=[]) + def f0000_single_return(self): + return tf.constant(1.0, shape=[1]) + + # Check index paths for results with multiple return values. + # Note that semantically in Python, multiple return values are equivalent + # to returning a tuple/list. + # + # CHECK: func {{@[a-zA-Z_0-9]+}}() -> ( + # CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = [0]}, + # CHECK-SAME: tensor<2xf32> {tf_saved_model.index_path = [1]}) + # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0001_multiple_results_no_punctuation"] + @tf.function(input_signature=[]) + def f0001_multiple_results_no_punctuation(self): + return tf.constant(1.0, shape=[1]), tf.constant(1.0, shape=[2]) + + # Check index paths for results written explicitly with parentheses. 
+ # This is semantically equivalent to the earlier test without parentheses, + # but this test serves as documentation of this behavior for the purposes + # of tf_saved_model users. + # + # CHECK: func {{@[a-zA-Z_0-9]+}}() -> ( + # CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = [0]}, + # CHECK-SAME: tensor<2xf32> {tf_saved_model.index_path = [1]}) + # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0002_multiple_results_parentheses"] + @tf.function(input_signature=[]) + def f0002_multiple_results_parentheses(self): + return (tf.constant(1.0, shape=[1]), tf.constant(1.0, shape=[2])) + + # Check index paths for results written explicitly with brackets. + # This is semantically equivalent to the earlier test without parentheses, + # but this test serves as documentation of this behavior for the purposes + # of tf_saved_model users. + # + # CHECK: func {{@[a-zA-Z_0-9]+}}() -> ( + # CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = [0]}, + # CHECK-SAME: tensor<2xf32> {tf_saved_model.index_path = [1]}) + # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0003_multiple_results_brackets"] + @tf.function(input_signature=[]) + def f0003_multiple_results_brackets(self): + return [tf.constant(1.0, shape=[1]), tf.constant(1.0, shape=[2])] + + # Check index paths for lists. + # + # CHECK: func {{@[a-zA-Z_0-9]+}}() -> ( + # CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = [0, 0]}, + # CHECK-SAME: tensor<2xf32> {tf_saved_model.index_path = [0, 1]}) + # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0004_list_2_elements"] + @tf.function(input_signature=[]) + def f0004_list_2_elements(self): + return [[tf.constant(1.0, shape=[1]), tf.constant(1.0, shape=[2])]] + + # Check index paths for dicts. + # Keys are linearized in sorted order, matching `tf.nest.flatten`. + # More thorough testing of this is in structured_input.py. The underlying code + # path for linearization is shared, so no need to replicate that testing here. + # + # CHECK: func {{@[a-zA-Z_0-9]+}}() -> ( + # CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = ["x"]}, + # CHECK-SAME: tensor<2xf32> {tf_saved_model.index_path = ["y"]}) + # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0005_dict_2_keys"] + @tf.function(input_signature=[]) + def f0005_dict_2_keys(self): + return { + 'x': tf.constant(1.0, shape=[1]), + 'y': tf.constant(1.0, shape=[2]), + } + + # Check index paths for outputs are correctly handled in the presence of + # multiple return statements. 
+ # + # CHECK: func {{@[a-zA-Z_0-9]+}}( + # CHECK-SAME: %arg0: tensor {tf._user_specified_name = "x", tf_saved_model.index_path = [0]} + # CHECK-SAME: ) -> ( + # CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = ["x"]}) + # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0006_multiple_return_statements"] + @tf.function(input_signature=[tf.TensorSpec([], tf.float32)]) + def f0006_multiple_return_statements(self, x): + if x > 3.: + return {'x': tf.constant(1.0, shape=[1])} + else: + return {'x': tf.constant(1.0, shape=[1])} + + +if __name__ == '__main__': + common.do_test(TestModule) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_lift_variables.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_lift_variables.mlir index 84b4f97d4eb..ea2ebc64a29 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_lift_variables.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_lift_variables.mlir @@ -59,3 +59,25 @@ module attributes {tf_saved_model.semantics, tf_saved_model.under_construction} // CHECK: %arg0: tensor>> {tf_saved_model.bound_input = @"dense/kernel"}, // CHECK: %arg1: tensor>> {tf_saved_model.bound_input = @"dense/bias"}) } + +// ----- + +module attributes {tf_saved_model.semantics, tf_saved_model.under_construction} { + + // Test case: Fix bound_inputs' types. + + func @serving_default(%arg0: tensor>> {tf.resource_name = "dense/kernel"}, %arg1: tensor>> {tf.resource_name = "dense/bias"}) -> (tensor<*xf32> {tf_saved_model.index_path = ["dense_2"]}) + attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = "tf.ReadVariableOp"(%arg0) {device = ""} : (tensor>>) -> tensor<*xf32> + %1 = "tf.ReadVariableOp"(%arg1) {device = ""} : (tensor>>) -> tensor<*xf32> + %2 = "tf.Add"(%0, %1) {device = ""} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + return %2 : tensor<*xf32> + } + // CHECK: "tf_saved_model.global_tensor"() + // CHECK: sym_name = "dense/kernel" + // CHECK: "tf_saved_model.global_tensor"() + // CHECK: sym_name = "dense/bias" + // CHECK: func @serving_default( + // CHECK: %arg0: tensor>> {tf_saved_model.bound_input = @"dense/kernel"}, + // CHECK: %arg1: tensor>> {tf_saved_model.bound_input = @"dense/bias"}) +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu-cluster-cleanup-attributes.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu-cluster-cleanup-attributes.mlir new file mode 100644 index 00000000000..6399d7d6fb0 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu-cluster-cleanup-attributes.mlir @@ -0,0 +1,24 @@ +// RUN: tf-opt %s -tf-tpu-cleanup-cluster-attributes | FileCheck %s + +func @test(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: "tf_device.cluster" + // CHECK-NOT: _tpu_replicate = + // CHECK-NOT: device = + %1 = "tf_device.cluster"() ( { + %2 = "tf.Add"(%arg1, %arg1) : (tensor, tensor) -> tensor + %3 = "tf.IfRegion"(%arg0) ({ + %4 = "tf.Mul" (%arg1, %2) {device = "y"}: (tensor, tensor) -> tensor + "tf.Yield"(%4) : (tensor) -> () + }, { + %5 = "tf.Div" (%arg1, %2) : (tensor, tensor) -> tensor + "tf.Yield"(%5) : (tensor) -> () + }) {is_stateless = true, _tpu_replicate = "x" } : (tensor) -> (tensor) + tf_device.return %3 : tensor + // CHECK: {_tpu_replicate = "x", cluster_attr = "cluster_attr", device = "y"} + }) {cluster_attr = "cluster_attr", _tpu_replicate = "x", device = "y"} : () -> tensor + // CHECK: "tf.Add" + // CHECK-SAME: 
{_tpu_replicate = "x", device = "y"} + %2 = "tf.Add"(%arg2, %1) {_tpu_replicate = "x", device = "y"} : (tensor, tensor) -> tensor + // CHECK: return + return %2 : tensor +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu-resource-read-for-write.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu-resource-read-for-write.mlir new file mode 100644 index 00000000000..a505a4e3269 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu-resource-read-for-write.mlir @@ -0,0 +1,64 @@ +// RUN: tf-opt -tf-tpu-resource-read-for-write %s | FileCheck %s --dump-input=always + +// CHECK-LABEL: func @write_only_resource +// CHECK-SAME: ([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tensor, [[ARG2:%.*]]: tensor<*x!tf.resource>>) +func @write_only_resource(%arg0: tensor, %arg1: tensor, %arg2: tensor<*x!tf.resource>>) { + // CHECK-NEXT: [[READ:%.*]] = "tf.ReadVariableOp"([[ARG2]]) + // CHECK-NEXT: [[CLUSTER:%.*]]:2 = "tf_device.cluster_func"([[ARG0]], [[ARG1]], [[READ]]) + // CHECK-SAME: _tpu_replicate = "write" + %0:2 = "tf_device.cluster_func"(%arg0, %arg1) {_tpu_replicate = "write", func = @write_func} : (tensor, tensor) -> (tensor, tensor) + // CHECK-NEXT: "tf.AssignVariableOp"([[ARG2]], [[CLUSTER]]#1) + "tf.AssignVariableOp"(%arg2, %0#1) : (tensor<*x!tf.resource>>, tensor) -> () + // CHECK-NEXT: return + return +} + +// CHECK-LABEL: func @write_func +// CHECK-SAME: ({{%.*}}: tensor, {{%.*}}: tensor, {{%.*}}: tensor) -> (tensor, tensor) +func @write_func(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + return %arg1, %arg0 : tensor, tensor +} + +// CHECK-LABEL: func @read_write_resource +func @read_write_resource(%arg0: tensor, %arg1: tensor, %arg2: tensor<*x!tf.resource>>) { + // CHECK-COUNT-1: tf.ReadVariableOp + %0 = "tf.ReadVariableOp"(%arg2) : (tensor<*x!tf.resource>>) -> tensor + %1:2 = "tf_device.cluster_func"(%arg0, %arg1, %0) {_tpu_replicate = "read_write", func = @read_write_func} : (tensor, tensor, tensor) -> (tensor, tensor) + "tf.AssignVariableOp"(%arg2, %1#1) : (tensor<*x!tf.resource>>, tensor) -> () + return +} + +// CHECK-LABEL: func @read_write_func +// CHECK-SAME: ({{%.*}}: tensor, {{%.*}}: tensor) -> (tensor, tensor) +func @read_write_func(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + return %arg1, %arg0 : tensor, tensor +} + +// CHECK-LABEL: func @multiple_write_resource +func @multiple_write_resource(%arg0: tensor, %arg1: tensor<*x!tf.resource>>) { + // CHECK-NOT: tf.ReadVariableOp + %0:2 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "multiple_write", func = @multiple_write_func} : (tensor) -> (tensor, tensor) + "tf.AssignVariableOp"(%arg1, %0#0) : (tensor<*x!tf.resource>>, tensor) -> () + "tf.AssignVariableOp"(%arg1, %0#1) : (tensor<*x!tf.resource>>, tensor) -> () + return +} + +// CHECK-LABEL: func @multiple_write_func +// CHECK-SAME: ({{%.*}}: tensor) -> (tensor, tensor) +func @multiple_write_func(%arg0: tensor) -> (tensor, tensor) { + return %arg0, %arg0 : tensor, tensor +} + +// CHECK-LABEL: func @multiple_result_user +func @multiple_result_user(%arg0: tensor, %arg1: tensor<*x!tf.resource>>) -> tensor { + // CHECK-NOT: tf.ReadVariableOp + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "multiple_uses", func = @multiple_result_user_func} : (tensor) -> tensor + "tf.AssignVariableOp"(%arg1, %0) : (tensor<*x!tf.resource>>, tensor) -> () + return %0 : tensor +} + +// CHECK-LABEL: func @multiple_result_user_func +// CHECK-SAME: ({{%.*}}: tensor) -> tensor +func @multiple_result_user_func(%arg0: tensor) -> tensor { + return %arg0 : tensor +} diff 
--git a/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir index 277e4a8415e..e87b83b0cdf 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir @@ -2,29 +2,81 @@ // Tests that the pass can correctly transform a training loop with 2 replicas. +!tf_res_f32 = type tensor<*x!tf.resource>> +!tf_res_md_f32 = type tensor<*x!tf.resource>> // Multi-dim f32 + module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { // CHECK-LABEL: func @main - func @main(%arg0: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, - %arg1: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}, - %arg2: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, - %arg3: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}) { + // CHECK-SAME: %[[ARG0:.*]]: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, + // CHECK-SAME: %[[ARG1:.*]]: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}, + // CHECK-SAME: %[[ARG2:.*]]: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, + // CHECK-SAME: %[[ARG3:.*]]: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}) + func @main(%arg0: !tf_res_f32 {tf.device = "/device:TPU:0"}, + %arg1: !tf_res_f32 {tf.device = "/device:TPU:1"}, + %arg2: !tf_res_md_f32 {tf.device = "/device:TPU:0"}, + %arg3: !tf_res_md_f32 {tf.device = "/device:TPU:1"}) { %0 = "tf.Const"() {value = dense<100> : tensor} : () -> tensor // CHECK: %[[STATE0:.*]] = "tf.VarHandleOp"() // CHECK-SAME: device = "/device:TPU:0" // CHECK: %[[STATE1:.*]] = "tf.VarHandleOp"() // CHECK-SAME: device = "/device:TPU:1" - // CHECK: %[[WHILE:.*]]:7 = "tf.While"( - // CHECK-SAME: %[[STATE0]], %[[STATE1]]) - %1:5 = "tf.While"(%0, %arg0, %arg1, %arg2, %arg3) - {T = ["tfdtype$DT_INT32", "tfdtype$DT_RESOURCE", - "tfdtype$DT_RESOURCE", "tfdtype$DT_RESOURCE", - "tfdtype$DT_RESOURCE"], body = @while_body_7560, - cond = @while_cond_7550, device = "", is_stateless = false} - : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, - tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) - -> (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, - tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) + // CHECK: %[[WHILE:.*]] = "tf.WhileRegion"( + %1 = "tf.WhileRegion"(%0) ( { + // Condition region + // CHECK: ^bb + // CHECK: "tf.Yield" + ^bb0(%carg0: tensor): + %c0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %c1 = "tf.GreaterEqual"(%carg0, %0) {T = i32, device = ""} : (tensor, tensor) -> tensor + "tf.Yield"(%c1) : (tensor) -> () + }, { + // Body region + // CHECK: ^bb0 + ^bb0(%barg0: tensor): + %b0 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %b1 = "tf.AddV2"(%barg0, %0) {T = i32, device = ""} : (tensor, tensor) -> tensor + // CHECK: %[[COMPILE:.*]]:2 = "tf_device.launch" + // CHECK-NEXT: "tf._TPUCompileMlir"() + %compile:2 = "tf_device.launch"() ( { + %b2:2 = "tf._TPUCompileMlir"() { + NumDynamicShapes = 0 : i64, + // The metadata encodes 2 parameter and two return values. 
+ metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", + mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>) + tf_device.return %b2#0, %b2#1 : tensor, tensor<2x!tf.string> + }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>) + "tf_device.launch"() ( { + "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor) -> () + tf_device.return + }) {device = "/device:CPU:0"} : () -> () + // CHECK: tf_device.replicate + // CHECK-SAME: [%[[ARG0]], %[[ARG1]]] as %[[R0:.*]]: tensor<*x!tf.resource>>, + // CHECK-SAME: [%[[ARG2]], %[[ARG3]]] as %[[R1:.*]]: tensor<*x!tf.resource>>, + // CHECK-SAME: [%[[STATE0]], %[[STATE1]]] as %[[R_STATE:.*]]: tensor>> + // CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"] + %rep:2 = tf_device.replicate([%arg0, %arg1] as %arg30: tensor<*x!tf.resource>>, + [%arg2, %arg3] as %arg31: tensor<*x!tf.resource>>) + {_mirrored_variable_indices = [0, 1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} { + // CHECK: %[[ID:.*]] = "tf.Identity"(%[[R0]]) + %id = "tf.Identity"(%arg30) : (tensor<*x!tf.resource>>) -> tensor<*x!tf.resource>> + // CHECK: "tf_device.launch" + // CHECK-NEXT: "tf.TPUReshardVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1, %[[R_STATE]]) + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "TPU_REPLICATED_CORE_0" + // CHECK: "tf.TPUExecuteAndUpdateVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1) + "tf_device.launch"() ( { + "tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1) + {device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]} + : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor<2x!tf.string>) -> () + tf_device.return + }) {device = "TPU_REPLICATED_CORE_0"} : () -> () + %ret = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + tf_device.return %ret : tensor + } + // CHECK: "tf.Yield" + "tf.Yield"(%b1) : (tensor) -> () + }) {device = "", is_stateless = false} : (tensor) -> (tensor) // CHECK: %[[DEFAULT:.*]] = "tf.Const"() // CHECK: tf_device.replicate // CHECK-SAME: as %[[V0:.*]]: tensor<*x!tf.resource>>, @@ -37,165 +89,72 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr // CHECK-NEXT: device = "TPU_REPLICATED_CORE_0" return } - // CHECK-LABEL: func @while_body_7560 - func @while_body_7560(%arg0: tensor, - %arg1: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, - %arg2: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}, - %arg3: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, - %arg4: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}) - -> (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, - tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) { - // CHECK-SAME: (%[[ITER:.*]]: tensor, - // CHECK-SAME: %[[BODY_ARG1:.*]]: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, - // CHECK-SAME: %[[BODY_ARG2:.*]]: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}, - // CHECK-SAME: %[[BODY_ARG3:.*]]: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, - // CHECK-SAME: %[[BODY_ARG4:.*]]: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}, - // CHECK-SAME: %[[STATE_ARG0:.*]]: tensor>> {tf.device = "/device:TPU:0"}, - // CHECK-SAME: %[[STATE_ARG1:.*]]: tensor>> {tf.device = "/device:TPU:1"}) - %0 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor - %1 = "tf.AddV2"(%arg0, %0) 
{T = i32, device = ""} : (tensor, tensor) -> tensor - // CHECK: %[[COMPILE:.*]]:2 = "tf_device.launch" - // CHECK-NEXT: "tf._TPUCompileMlir"() - %compile:2 = "tf_device.launch"() ( { - %2:2 = "tf._TPUCompileMlir"() { - NumDynamicShapes = 0 : i64, - // The metadata encodes 2 parameter and two return values. - metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", - mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>) - tf_device.return %2#0, %2#1 : tensor, tensor<2x!tf.string> - }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>) - "tf_device.launch"() ( { - "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor) -> () - tf_device.return - }) {device = "/device:CPU:0"} : () -> () - // CHECK: tf_device.replicate - // CHECK-SAME: [%[[BODY_ARG1]], %[[BODY_ARG2]]] as %[[R0:.*]]: tensor<*x!tf.resource>>, - // CHECK-SAME: [%[[BODY_ARG3]], %[[BODY_ARG4]]] as %[[R1:.*]]: tensor<*x!tf.resource>>, - // CHECK-SAME: [%[[STATE_ARG0]], %[[STATE_ARG1]]] as %[[R_STATE:.*]]: tensor>> - // CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"] - %rep:2 = tf_device.replicate([%arg1, %arg2] as %arg30: tensor<*x!tf.resource>>, - [%arg3, %arg4] as %arg31: tensor<*x!tf.resource>>) - {_mirrored_variable_indices = [0, 1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} { - // CHECK: %[[ID:.*]] = "tf.Identity"(%[[R0]]) - %id = "tf.Identity"(%arg30) : (tensor<*x!tf.resource>>) -> tensor<*x!tf.resource>> - // CHECK: "tf_device.launch" - // CHECK-NEXT: "tf.TPUReshardVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1, %[[R_STATE]]) - // CHECK-NEXT: tf_device.return - // CHECK-NEXT: device = "TPU_REPLICATED_CORE_0" - // CHECK: "tf.TPUExecuteAndUpdateVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1) - "tf_device.launch"() ( { - "tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1) - {device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]} - : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor<2x!tf.string>) -> () - tf_device.return - }) {device = "TPU_REPLICATED_CORE_0"} : () -> () - %ret = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - tf_device.return %ret : tensor - } - return %1, %arg1, %arg2, %arg3, %arg4 : tensor, tensor<*x!tf.resource>>, - tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, - tensor<*x!tf.resource>> - } - // CHECK-LABEL: func @while_cond_7550 - func @while_cond_7550(%arg0: tensor, - %arg1: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, - %arg2: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}, - %arg3: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, - %arg4: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}) - -> tensor { - %0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %1 = "tf.GreaterEqual"(%arg0, %0) {T = i32, device = ""} : (tensor, tensor) -> tensor - return %1 : tensor - } } + // ----- // Tests that the pass does not format variables with other uses. 
+!tf_res_f32 = type tensor<*x!tf.resource>> +!tf_res_md_f32 = type tensor<*x!tf.resource>> // Multi-dim f32 + module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { // CHECK-LABEL: func @main // CHECK-NOT: TPUReshardVariables - func @main(%arg0: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, - %arg1: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}, - %arg2: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, - %arg3: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}, - %arg4: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, - %arg5: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}) { + func @main(%arg0: !tf_res_f32 {tf.device = "/device:TPU:0"}, + %arg1: !tf_res_f32 {tf.device = "/device:TPU:1"}, + %arg2: !tf_res_md_f32 {tf.device = "/device:TPU:0"}, + %arg3: !tf_res_md_f32 {tf.device = "/device:TPU:1"}, + %arg4: !tf_res_f32 {tf.device = "/device:TPU:1"}) { + %0 = "tf.Const"() {value = dense<100> : tensor} : () -> tensor - %1:7 = "tf.While"(%0, %arg0, %arg1, %arg2, %arg3, %arg4, %arg5) - {body = @while_body_7560, - cond = @while_cond_7550, device = "", is_stateless = false} - : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, - tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, - tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) - -> (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, - tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, - tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) + %1 = "tf.WhileRegion"(%0) ( { + // Condition region + ^bb0(%carg0: tensor): + %c0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %c1 = "tf.GreaterEqual"(%carg0, %0) {T = i32, device = ""} : (tensor, tensor) -> tensor + "tf._UnknownOp1_"(%arg1) : (!tf_res_f32) -> () + "tf.Yield"(%c1) : (tensor) -> () + }, { + // Body region + ^bb0(%barg0: tensor): + %b0 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %b1 = "tf.AddV2"(%barg0, %0) {T = i32, device = ""} : (tensor, tensor) -> tensor + %compile:2 = "tf_device.launch"() ( { + %b2:2 = "tf._TPUCompileMlir"() { + NumDynamicShapes = 0 : i64, + // The metadata encodes 2 parameter and two return values. + metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", + mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>) + tf_device.return %b2#0, %b2#1 : tensor, tensor<2x!tf.string> + }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>) + "tf_device.launch"() ( { + "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor) -> () + tf_device.return + }) {device = "/device:CPU:0"} : () -> () + %id0 = "tf.Identity"(%arg3) : (!tf_res_md_f32) -> !tf_res_md_f32 + "tf._Unknown_"(%id0) : (!tf_res_md_f32) -> () + %newvar = "tf._SomeOp"() : () -> !tf_res_f32 + %rep:2 = tf_device.replicate([%arg0, %arg1] as %arg30: !tf_res_f32, + [%arg2, %arg3] as %arg31: !tf_res_md_f32, + [%newvar, %arg4] as %arg32 : !tf_res_f32) + {_mirrored_variable_indices = [0, 1, 2], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} { + // %arg30 is used in the cond function, %arg31 has other uses (%id0), and + // %arg32 is not a pass-through. 
+ "tf_device.launch"() ( { + "tf.TPUExecuteAndUpdateVariables"(%arg30, %arg31, %arg32, %compile#1) + {device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]} + : (!tf_res_f32, !tf_res_md_f32, !tf_res_f32, tensor<2x!tf.string>) -> () + tf_device.return + }) {device = "TPU_REPLICATED_CORE_0"} : () -> () + %ret = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + tf_device.return %ret : tensor + } + "tf.Yield"(%b1) : (tensor) -> () + }) {device = "", is_stateless = false} : (tensor) -> (tensor) return } - // CHECK-LABEL: func @while_body_7560 - // CHECK-NOT: TPUReshardVariables - func @while_body_7560(%arg0: tensor, - %arg1: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, - %arg2: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}, - %arg3: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, - %arg4: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}, - %arg5: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, - %arg6: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}) - -> (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, - tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, - tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) { - %0 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor - %1 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor, tensor) -> tensor - %compile:2 = "tf_device.launch"() ( { - %2:2 = "tf._TPUCompileMlir"() { - NumDynamicShapes = 0 : i64, - // The metadata encodes 2 parameter and two return values. - metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", - mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>) - tf_device.return %2#0, %2#1 : tensor, tensor<2x!tf.string> - }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>) - "tf_device.launch"() ( { - "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor) -> () - tf_device.return - }) {device = "/device:CPU:0"} : () -> () - %id0 = "tf.Identity"(%arg3) : (tensor<*x!tf.resource>>) -> tensor<*x!tf.resource>> - "tf._Unknown_"(%id0) : (tensor<*x!tf.resource>>) -> () - %newvar = "tf._SomeOp"() : () -> tensor<*x!tf.resource>> - tf_device.replicate([%arg1, %arg2] as %arg30: tensor<*x!tf.resource>>, - [%arg3, %arg4] as %arg31: tensor<*x!tf.resource>>, - [%newvar, %arg6] as %arg32: tensor<*x!tf.resource>>) - {_mirrored_variable_indices = [0, 1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} { - // %arg30 is used in the cond function, %arg31 has other uses (%id0), and - // %arg32 is not a pass-through. 
- "tf_device.launch"() ( { - "tf.TPUExecuteAndUpdateVariables"(%arg30, %arg31, %arg32, %compile#1) - {device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]} - : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, - tensor<*x!tf.resource>>, tensor<2x!tf.string>) -> () - tf_device.return - }) {device = "TPU_REPLICATED_CORE_0"} : () -> () - tf_device.return - } - return %1, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6 : tensor, tensor<*x!tf.resource>>, - tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, - tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor<*x!tf.resource>> - } - // CHECK-LABEL: func @while_cond_7550 - func @while_cond_7550(%arg0: tensor, - %arg1: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, - %arg2: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}, - %arg3: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, - %arg4: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}, - %arg5: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, - %arg6: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}) - -> tensor { - %0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %1 = "tf.GreaterEqual"(%arg0, %0) {T = i32, device = ""} : (tensor, tensor) -> tensor - "tf._UnknownOp1_"(%arg1) : (tensor<*x!tf.resource>>) -> () - return %1 : tensor - } } // ----- @@ -203,81 +162,62 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr // Tests that the pass does not format variables when model parallelism is // present. +!tf_res_f32 = type tensor<*x!tf.resource>> +!tf_res_md_f32 = type tensor<*x!tf.resource>> // Multi-dim f32 + module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { // CHECK-LABEL: func @main // CHECK-NOT: TPUReshardVariables - func @main(%arg0: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, - %arg1: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}, - %arg2: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, - %arg3: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}) { + func @main(%arg0: !tf_res_f32 {tf.device = "/device:TPU:0"}, + %arg1: !tf_res_f32 {tf.device = "/device:TPU:1"}, + %arg2: !tf_res_md_f32 {tf.device = "/device:TPU:0"}, + %arg3: !tf_res_md_f32 {tf.device = "/device:TPU:1"}) { %0 = "tf.Const"() {value = dense<100> : tensor} : () -> tensor - %1:5 = "tf.While"(%0, %arg0, %arg1, %arg2, %arg3) - {T = ["tfdtype$DT_INT32", "tfdtype$DT_RESOURCE", - "tfdtype$DT_RESOURCE", "tfdtype$DT_RESOURCE", - "tfdtype$DT_RESOURCE"], body = @while_body_7560, - cond = @while_cond_7550, device = "", is_stateless = false} - : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, - tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) - -> (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, - tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) - return - } - // CHECK-LABEL: func @while_body_7560 - // CHECK-NOT: TPUReshardVariables - func @while_body_7560(%arg0: tensor, - %arg1: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, - %arg2: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}, - %arg3: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, - %arg4: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}) - -> (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, - tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) { - %0 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor - %1 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor, tensor) -> tensor - %compile:2 = "tf_device.launch"() ( { - %2:2 = 
"tf._TPUCompileMlir"() { - NumDynamicShapes = 0 : i64, - // The metadata encodes 2 parameter and two return values. - metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", - mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>) - tf_device.return %2#0, %2#1 : tensor, tensor<2x!tf.string> - }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>) - "tf_device.launch"() ( { - "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor) -> () - tf_device.return - }) {device = "/device:CPU:0"} : () -> () - %rep:2 = tf_device.replicate([%arg1, %arg2] as %arg30: tensor<*x!tf.resource>>, - [%arg3, %arg4] as %arg31: tensor<*x!tf.resource>>) - {_mirrored_variable_indices = [0, 1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} { - %id = "tf.Identity"(%arg30) : (tensor<*x!tf.resource>>) -> tensor<*x!tf.resource>> - "tf_device.parallel_execute"() ({ - "tf_device.launch"() ( { - "tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1) - {device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]} - : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor<2x!tf.string>) -> () - tf_device.return - }) {device = "TPU_REPLICATED_CORE_0"} : () -> () - tf_device.return + %1 = "tf.WhileRegion"(%0) ( { + // Condition region + ^bb0(%carg0: tensor): + %c0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %c1 = "tf.GreaterEqual"(%carg0, %0) {T = i32, device = ""} : (tensor, tensor) -> tensor + "tf.Yield"(%c1) : (tensor) -> () }, { - tf_device.return - }) {} : () -> () - %ret = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - tf_device.return %ret : tensor - } - return %1, %arg1, %arg2, %arg3, %arg4 : tensor, tensor<*x!tf.resource>>, - tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, - tensor<*x!tf.resource>> - } - // CHECK-LABEL: func @while_cond_7550 - func @while_cond_7550(%arg0: tensor, - %arg1: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, - %arg2: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}, - %arg3: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, - %arg4: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}) - -> tensor { - %0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %1 = "tf.GreaterEqual"(%arg0, %0) {T = i32, device = ""} : (tensor, tensor) -> tensor - return %1 : tensor + // Body region + ^bb0(%barg0: tensor): + %b0 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %b1 = "tf.AddV2"(%barg0, %0) {T = i32, device = ""} : (tensor, tensor) -> tensor + %compile:2 = "tf_device.launch"() ( { + %b2:2 = "tf._TPUCompileMlir"() { + NumDynamicShapes = 0 : i64, + // The metadata encodes 2 parameter and two return values. 
+ metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", + mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>) + tf_device.return %b2#0, %b2#1 : tensor, tensor<2x!tf.string> + }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>) + "tf_device.launch"() ( { + "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor) -> () + tf_device.return + }) {device = "/device:CPU:0"} : () -> () + %rep:2 = tf_device.replicate([%arg0, %arg1] as %arg30: tensor<*x!tf.resource>>, + [%arg2, %arg3] as %arg31: tensor<*x!tf.resource>>) + {_mirrored_variable_indices = [0, 1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} { + %id = "tf.Identity"(%arg30) : (tensor<*x!tf.resource>>) -> tensor<*x!tf.resource>> + "tf_device.parallel_execute"() ({ + "tf_device.launch"() ( { + "tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1) + {device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]} + : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor<2x!tf.string>) -> () + tf_device.return + }) {device = "TPU_REPLICATED_CORE_0"} : () -> () + tf_device.return + }, { + tf_device.return + }) {} : () -> () + %ret = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + tf_device.return %ret : tensor + } + "tf.Yield"(%b1) : (tensor) -> () + }) {device = "", is_stateless = false} : (tensor) -> (tensor) + return } } @@ -285,34 +225,83 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr // Tests that the pass can correctly transform a training loop with a packed // variable. +!tf_res_f32 = type tensor<*x!tf.resource>> +!tf_res_md_f32 = type tensor<*x!tf.resource>> // Multi-dim f32 module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { // CHECK-LABEL: func @main - func @main(%arg0: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, - %arg1: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}, - %arg2: tensor<*x!tf.resource>> {tf.device = "/device:COMPOSITE:0"}) { - + // CHECK-SAME: %[[ARG0:.*]]: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, + // CHECK-SAME: %[[ARG1:.*]]: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}, + // CHECK-SAME: %[[ARG2:.*]]: tensor<*x!tf.resource>> {tf.device = "/device:COMPOSITE:0"}) + func @main(%arg0: !tf_res_f32 {tf.device = "/device:TPU:0"}, + %arg1: !tf_res_f32 {tf.device = "/device:TPU:1"}, + %arg2: !tf_res_md_f32 {tf.device = "/device:COMPOSITE:0"}) { %0 = "tf.Const"() {value = dense<100> : tensor} : () -> tensor // CHECK: %[[STATE0:.*]] = "tf.VarHandleOp"() // CHECK-SAME: device = "/device:TPU:0" // CHECK: %[[STATE1:.*]] = "tf.VarHandleOp"() // CHECK-SAME: device = "/device:TPU:1" - // CHECK: %[[WHILE:.*]]:6 = "tf.While"( - // CHECK-SAME: %[[STATE0]], %[[STATE1]]) - %1:4 = "tf.While"(%0, %arg0, %arg1, %arg2) - {T = ["tfdtype$DT_INT32", "tfdtype$DT_RESOURCE", - "tfdtype$DT_RESOURCE", "tfdtype$DT_RESOURCE"], - body = @while_body_7560, - cond = @while_cond_7550, device = "", is_stateless = false} - : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, - tensor<*x!tf.resource>>) - -> (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, - tensor<*x!tf.resource>>) + // CHECK: %[[WHILE:.*]] = "tf.WhileRegion"( + %1 = "tf.WhileRegion"(%0) ( { + // Condition region + // CHECK: ^bb + // CHECK: "tf.Yield" + ^bb0(%carg0: tensor): + 
%c0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %c1 = "tf.GreaterEqual"(%carg0, %0) {T = i32, device = ""} : (tensor, tensor) -> tensor + "tf.Yield"(%c1) : (tensor) -> () + }, { + // Body region + // CHECK: ^bb0 + ^bb0(%barg0: tensor): + %b0 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %b1 = "tf.AddV2"(%barg0, %0) {T = i32, device = ""} : (tensor, tensor) -> tensor + // CHECK: %[[COMPILE:.*]]:2 = "tf_device.launch" + // CHECK-NEXT: "tf._TPUCompileMlir"() + %compile:2 = "tf_device.launch"() ( { + %b2:2 = "tf._TPUCompileMlir"() { + NumDynamicShapes = 0 : i64, + // The metadata encodes 2 parameter and two return values. + metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", + mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>) + tf_device.return %b2#0, %b2#1 : tensor, tensor<2x!tf.string> + }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>) + "tf_device.launch"() ( { + "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor) -> () + tf_device.return + }) {device = "/device:CPU:0"} : () -> () + // CHECK: tf_device.replicate + // CHECK-SAME: [%[[ARG0]], %[[ARG1]]] as %[[R0:.*]]: tensor<*x!tf.resource>>, + // CHECK-SAME: [%[[STATE0]], %[[STATE1]]] as %[[R_STATE:.*]]: tensor>> + // CHECK-SAME: %[[ARG2]] as %[[R1:.*]]: tensor<*x!tf.resource>> + // CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"] + %rep:2 = tf_device.replicate([%arg0, %arg1] as %arg30: tensor<*x!tf.resource>>, + %arg2 as %arg31: tensor<*x!tf.resource>>) + {_mirrored_variable_indices = [0, 1], _packed_input_indices = [1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} { + // CHECK: %[[ID:.*]] = "tf.Identity"(%[[R0]]) + %id = "tf.Identity"(%arg30) : (tensor<*x!tf.resource>>) -> tensor<*x!tf.resource>> + // CHECK: "tf_device.launch" + // CHECK-NEXT: "tf.TPUReshardVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1, %[[R_STATE]]) + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "TPU_REPLICATED_CORE_0" + // CHECK: "tf.TPUExecuteAndUpdateVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1) + "tf_device.launch"() ( { + "tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1) + {device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]} + : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor<2x!tf.string>) -> () + tf_device.return + }) {device = "TPU_REPLICATED_CORE_0"} : () -> () + %ret = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + tf_device.return %ret : tensor + } + // CHECK: "tf.Yield" + "tf.Yield"(%b1) : (tensor) -> () + }) {device = "", is_stateless = false} : (tensor) -> (tensor) // CHECK: %[[DEFAULT:.*]] = "tf.Const"() // CHECK: tf_device.replicate - // CHECK-SAME: as %[[V0:.*]]: tensor<*x!tf.resource>>, - // CHECK-SAME: [%[[STATE0]], %[[STATE1]]] as %[[STATE:.*]]: tensor>>, - // CHECK-SAME: as %[[V1:.*]]: tensor<*x!tf.resource>> + // CHECK-SAME: [%[[ARG0]], %[[ARG1]]] as %[[V0:.*]]: tensor<*x!tf.resource>>, + // CHECK-SAME: [%[[STATE0]], %[[STATE1]]] as %[[STATE:.*]]: tensor>> + // CHECK-SAME: %[[ARG2]] as %[[V1:.*]]: tensor<*x!tf.resource>> // CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"] // CHECK: "tf_device.launch" // CHECK-NEXT: "tf.TPUReshardVariables"(%[[V0]], %[[V1]], %[[DEFAULT]], %[[STATE]]) @@ -320,70 +309,4 @@ module attributes {tf.versions = 
{bad_consumers = [], min_consumer = 0 : i32, pr // CHECK-NEXT: device = "TPU_REPLICATED_CORE_0" return } - // CHECK-LABEL: func @while_body_7560 - func @while_body_7560(%arg0: tensor, - %arg1: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, - %arg2: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}, - %arg3: tensor<*x!tf.resource>> {tf.device = "/device:COMPOSITE:0"}) - -> (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, - tensor<*x!tf.resource>>) { - // CHECK-SAME: (%[[ITER:.*]]: tensor, - // CHECK-SAME: %[[BODY_ARG1:.*]]: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, - // CHECK-SAME: %[[BODY_ARG2:.*]]: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}, - // CHECK-SAME: %[[BODY_ARG3:.*]]: tensor<*x!tf.resource>> {tf.device = "/device:COMPOSITE:0"}, - // CHECK-SAME: %[[STATE_ARG0:.*]]: tensor>> {tf.device = "/device:TPU:0"}, - // CHECK-SAME: %[[STATE_ARG1:.*]]: tensor>> {tf.device = "/device:TPU:1"}) - %0 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor - %1 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor, tensor) -> tensor - // CHECK: %[[COMPILE:.*]]:2 = "tf_device.launch" - // CHECK-NEXT: "tf._TPUCompileMlir"() - %compile:2 = "tf_device.launch"() ( { - %2:2 = "tf._TPUCompileMlir"() { - NumDynamicShapes = 0 : i64, - // The metadata encodes 2 parameter and two return values. - metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", - mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>) - tf_device.return %2#0, %2#1 : tensor, tensor<2x!tf.string> - }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>) - "tf_device.launch"() ( { - "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor) -> () - tf_device.return - }) {device = "/device:CPU:0"} : () -> () - // CHECK: tf_device.replicate - // CHECK-SAME: [%[[BODY_ARG1]], %[[BODY_ARG2]]] as %[[R0:.*]]: tensor<*x!tf.resource>>, - // CHECK-SAME: [%[[STATE_ARG0]], %[[STATE_ARG1]]] as %[[R_STATE:.*]]: tensor>>, - // CHECK-SAME: %[[BODY_ARG3]] as %[[R1:.*]]: tensor<*x!tf.resource>> - // CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"] - %rep:2 = tf_device.replicate([%arg1, %arg2] as %arg30: tensor<*x!tf.resource>>, - %arg3 as %arg31: tensor<*x!tf.resource>>) - {_mirrored_variable_indices = [0, 1], _packed_input_indices = [1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} { - // CHECK: %[[ID:.*]] = "tf.Identity"(%[[R0]]) - %id = "tf.Identity"(%arg30) : (tensor<*x!tf.resource>>) -> tensor<*x!tf.resource>> - // CHECK: "tf_device.launch" - // CHECK-NEXT: "tf.TPUReshardVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1, %[[R_STATE]]) - // CHECK-NEXT: tf_device.return - // CHECK-NEXT: device = "TPU_REPLICATED_CORE_0" - // CHECK: "tf.TPUExecuteAndUpdateVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1) - "tf_device.launch"() ( { - "tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1) - {device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]} - : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor<2x!tf.string>) -> () - tf_device.return - }) {device = "TPU_REPLICATED_CORE_0"} : () -> () - %ret = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - tf_device.return %ret : tensor - } - return %1, %arg1, %arg2, %arg3 : tensor, tensor<*x!tf.resource>>, - tensor<*x!tf.resource>>, tensor<*x!tf.resource>> - } - // 
CHECK-LABEL: func @while_cond_7550 - func @while_cond_7550(%arg0: tensor, - %arg1: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, - %arg2: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}, - %arg3: tensor<*x!tf.resource>> {tf.device = "/device:COMPOSITE:0"}) - -> tensor { - %0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %1 = "tf.GreaterEqual"(%arg0, %0) {T = i32, device = ""} : (tensor, tensor) -> tensor - return %1 : tensor - } } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir index 978f6e74aa8..3c2344be1e4 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir @@ -380,6 +380,153 @@ func @resource_before_cluster() { } +// Test cluster formation with ops with attached regions within a cluster. +// Nested op's that are moved should get their _tpu_replicate and device +// attributes cleared. +// CHECK-LABEL: func @cluster_ops_with_regions +func @cluster_ops_with_regions() { + %0 = "tf.opA"() ({ + %1 = "tf.opB"() {_tpu_replicate = "replicate", device = "device", name = "nameB"} : () -> (tensor) + }) {_tpu_replicate = "replicate", device = "device", name = "nameA"} : () -> tensor + "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 1, topology = "topology"} : () -> () + return +} + +// CHECK: "tf.opA"() ( { +// CHECK-NEXT: "tf.opB" +// CHECK-NOT: _tpu_replicate = "replicate" +// CHECK-NOT: device = "device" +// CHECK-SAME: name = "nameB" +// CHECK: }) +// CHECK-NOT: _tpu_replicate = "replicate" +// CHECK-NOT: device = "device" +// CHECK: name = "nameA" +// CHECK: tf_device.return + +// A nested cluster op using result of another cluster op. In the below, opA and +// opB go in a cluster, and opD stays outside. +// CHECK-LABEL: func @cluster_nested_op_using_other_op +func @cluster_nested_op_using_other_op() { + %0 = "tf.opA"() { _tpu_replicate = "foo" } : () -> tensor + "tf.opB"() ({ + "tf.opC"(%0) : (tensor) -> () + }) { _tpu_replicate = "foo" } : () -> () + "tf.opD"(%0) : (tensor) -> () + "tf.TPUReplicateMetadata"() {_tpu_replicate = "foo", device = "CPU", num_replicas = 1, topology = "topology"} : () -> () + return +} + +// CHECK: [[CLUSTER:%.*]] = "tf_device.cluster"() ( { +// CHECK: [[OPA:%.*]] = "tf.opA"() : () -> tensor +// CHECK: "tf.opB"() ( { +// CHECK: "tf.opC"([[OPA]]) +// CHECK: tf_device.return [[OPA]] +// CHECK: "tf.opD"([[CLUSTER]]) + +// Preceding user is using resource updated by a nested op. 
+!tf_res = type tensor<*x!tf.resource>> +// CHECK-LABEL: func @cluster_nested_op_updating_resource +func @cluster_nested_op_updating_resource() { + %0 = "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor + %1 = "tf.VarHandleOp"() {container = "", shape = #tf.shape<>, shared_name = "x"} : () -> !tf_res + + "tf.opA"() ({ + "tf.AssignAddVariableOp"(%1, %0) : (!tf_res, tensor) -> () + "tf.terminator"() : () -> () + }) { _tpu_replicate = "foo" } : () -> () + "tf.AssignAddVariableOp"(%1, %0) : (!tf_res, tensor) -> () + "tf.opB"() { _tpu_replicate = "foo" } : () -> () + "tf.TPUReplicateMetadata"() {_tpu_replicate = "foo", device = "CPU", num_replicas = 1, topology = "topology"} : () -> () + return +} + +// CHECK: [[CONST:%.*]] = "tf.Const" +// CHECK: [[VAR:%.*]] = "tf.VarHandleOp" +// CHECK: "tf_device.cluster"() ( { +// CHECK: "tf.opA"() ( { +// CHECK: "tf.AssignAddVariableOp"([[VAR]], [[CONST]]) +// CHECK: }) +// CHECK: "tf.opB"() +// CHECK: tf_device.return +// CHECK: }) +// CHECK-SAME: _tpu_replicate = "foo" +// CHECK: "tf.AssignAddVariableOp"([[VAR]], [[CONST]]) + +// Preceding user is using resource updated by the cluster within a nested op. +// Resource is updated by a cluster op, and opA (not in cluster) is using the +// resource in a nested op. We expect opA to be after the cluster. +// CHECK-LABEL: func @cluster_nested_op_using_resource +func @cluster_nested_op_using_resource() { + %0 = "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor + %1 = "tf.VarHandleOp"() {container = "", shape = #tf.shape<>, shared_name = "x"} : () -> !tf_res + "tf.AssignAddVariableOp"(%1, %0) { _tpu_replicate = "foo" } : (!tf_res, tensor) -> () + "tf.opA"() ({ + "tf.AssignAddVariableOp"(%1, %0) : (!tf_res, tensor) -> () + "tf.terminator"() : () -> () + }) : () -> () + "tf.opB"() { _tpu_replicate = "foo" } : () -> () + "tf.TPUReplicateMetadata"() {_tpu_replicate = "foo", device = "CPU", num_replicas = 1, topology = "topology"} : () -> () + return +} + +// CHECK: [[CONST:%.*]] = "tf.Const" +// CHECK: [[VAR:%.*]] = "tf.VarHandleOp" +// CHECK: "tf_device.cluster"() ( { +// CHECK: "tf.AssignAddVariableOp"([[VAR]], [[CONST]]) +// CHECK: "tf.opB"() +// CHECK: tf_device.return +// CHECK: }) +// CHECK-SAME: _tpu_replicate = "foo" +// CHECK: "tf.opA"() ( { +// CHECK: "tf.AssignAddVariableOp"([[VAR]], [[CONST]]) + + +// ----- + + +!tf_res = type tensor<*x!tf.resource>> + +// Test multiple replicated clusters interleaved and uses resource variables. 
+// CHECK-LABEL: func @multiple_replicated_interleaved +func @multiple_replicated_interleaved(%arg0: !tf_res) { + "tf.TPUReplicateMetadata"() {_tpu_replicate = "a", num_replicas = 2, topology = "topology"} : () -> () + "tf.TPUReplicateMetadata"() {_tpu_replicate = "b", num_replicas = 2, topology = "topology"} : () -> () + "tf.TPUReplicateMetadata"() {_tpu_replicate = "c", num_replicas = 2, topology = "topology"} : () -> () + %0 = "tf.TPUReplicatedInput"(%arg0, %arg0) : (!tf_res, !tf_res) -> !tf_res + %1 = "tf.TPUReplicatedInput"(%arg0, %arg0) : (!tf_res, !tf_res) -> !tf_res + %2 = "tf.TPUReplicatedInput"(%arg0, %arg0) : (!tf_res, !tf_res) -> !tf_res + %3 = "tf.ReadVariableOp"(%0) {_tpu_replicate = "a"} : (!tf_res) -> tensor + %4 = "tf.ReadVariableOp"(%1) {_tpu_replicate = "b"} : (!tf_res) -> tensor + %5 = "tf.ReadVariableOp"(%2) {_tpu_replicate = "c"} : (!tf_res) -> tensor + %6 = "tf.Identity"(%3) {_tpu_replicate = "a"} : (tensor) -> tensor + %7 = "tf.Identity"(%4) {_tpu_replicate = "b"} : (tensor) -> tensor + %8 = "tf.Identity"(%5) {_tpu_replicate = "c"} : (tensor) -> tensor + %9:2 = "tf.TPUReplicatedOutput"(%6) : (tensor) -> (tensor, tensor) + %10:2 = "tf.TPUReplicatedOutput"(%7) : (tensor) -> (tensor, tensor) + %11:2 = "tf.TPUReplicatedOutput"(%8) : (tensor) -> (tensor, tensor) + return +} + +// CHECK: tf_device.replicate +// CHECK: tf_device.replicate +// CHECK: tf_device.replicate + + +// ----- + + +// Test cluster that is replicated but has a non TPUReplicatedOutput consumer. +// CHECK-LABEL: func @replicated_non_replicated_output +func @replicated_non_replicated_output() { + %0 = "tf.opA"() {_tpu_replicate = "replicate", device = "device", name = "name"} : () -> tensor + %1 = "tf.opB"(%0) : (tensor) -> tensor + "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> () + return +} + +// CHECK: [[REPLICATE:%.+]]:2 = tf_device.replicate +// CHECK: "tf.opB"([[REPLICATE]]#0) + // ----- @@ -407,18 +554,6 @@ func @bad_num_replicas() { // ----- -// Test that functions without TPUReplicateMetadata op are skipped without -// error -// CHECK-LABEL: func @missing_metadata_op -func @missing_metadata_op() { - // expected-warning@+1 {{TPUReplicateMetadata for associated '_tpu_replicate' attribute 'replicate' is missing}} - %0 = "tf.opA"() {_tpu_replicate = "replicate"} : () -> tensor - return -} - -// ----- - - // Test cluster with TPUReplicatedInput where the number of operands does not // match associated `num_replicas` attribute. func @mismatched_replicated_input(%arg0: tensor) { @@ -447,20 +582,6 @@ func @mismatched_replicated_output() { // ----- -// Test cluster that should be replicated where its outputs do not lead to a -// TPUReplicatedOutput. -func @missing_replicated_output() { - // expected-error@+1 {{requires output of tf_device.cluster to lead to a 'tf.TPUReplicatedOutput' op}} - %0 = "tf.opA"() {_tpu_replicate = "replicate", device = "device", name = "name"} : () -> tensor - %1 = "tf.opB"(%0) : (tensor) -> tensor - "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> () - return -} - - -// ----- - - // Test unused TPUReplicatedInput that has more than one operand. 
func @leftover_replicated_input(%arg0: tensor) { %0 = "tf.TPUReplicatedInput"(%arg0, %arg0) : (tensor, tensor) -> tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_colocate_composite_resource_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_colocate_composite_resource_ops.mlir new file mode 100644 index 00000000000..88af4535d81 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_colocate_composite_resource_ops.mlir @@ -0,0 +1,118 @@ +// RUN: tf-opt %s -tf-tpu-colocate-composite-resource-ops | FileCheck %s + +// Tests ReadVariable op using composite device resource is wrapped inside +// tf_device.Cluster. + +// CHECK-LABEL: func @testReadVariableOpColocated +// CHECK-SAME: (%[[ARG0:.*]]: tensor<*x!tf.resource>>) +func @testReadVariableOpColocated(%arg0: tensor<*x!tf.resource>>) { + // CHECK: tf_device.replicate + // CHECK-SAME: (%[[ARG0]] as %[[RI_0:[a-z0-9]*]]: tensor<*x!tf.resource>>) + tf_device.replicate(%arg0 as %arg1: tensor<*x!tf.resource>>) { + _mirrored_variable_indices = [0], _replicated_input_indices = [-1], + devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]}, + n = 2 : i32} { + // CHECK: %[[RESOURCE_OUT:.*]] = "tf_device.launch"() + // CHECK-NEXT: %[[READ_OUT:.*]] = "tf.ReadVariableOp"(%[[RI_0]]) + // CHECK-NEXT: tf_device.return %[[READ_OUT]] + // CHECK-NEXT: TPU_REPLICATED_CORE_0 + %0 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource>>) -> tensor<4xf32> + %1 = "tf.A"() : () -> (tensor<2x!tf.string>) + "tf_device.launch"() ( { + "tf.TPUExecuteAndUpdateVariables"(%arg1, %1) {device_var_reads_indices = [0], device_var_updates_indices = [-1]} : (tensor<*x!tf.resource>>, tensor<2x!tf.string>) -> () + tf_device.return + }) {device = "TPU_REPLICATED_CORE_0"} : () -> () + "tf_device.launch"() ( { + // CHECK: "tf.B"(%[[RESOURCE_OUT]]) + "tf.B"(%0) : (tensor<4xf32>) -> () + tf_device.return + }) {device = "TPU_REPLICATED_CORE_0"} : () -> () + tf_device.return + } + return +} + +// CHECK-LABEL: func @testReadVariableOpAfterIdentityColocated +// CHECK-SAME: (%[[ARG0:.*]]: tensor<*x!tf.resource>>) +func @testReadVariableOpAfterIdentityColocated(%arg0: tensor<*x!tf.resource>>) { + // CHECK: tf_device.replicate + // CHECK-SAME: (%[[ARG0]] as %[[RI_0:[a-z0-9]*]]: tensor<*x!tf.resource>>) + tf_device.replicate(%arg0 as %arg1: tensor<*x!tf.resource>>) { + _mirrored_variable_indices = [0], _replicated_input_indices = [-1], + devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]}, + n = 2 : i32} { + // CHECK: %[[IDENTITY_OUT:.*]] = "tf.Identity"(%[[RI_0]]) + // CHECK: %[[RESOURCE_OUT:.*]] = "tf_device.launch"() + // CHECK-NEXT: %[[READ_OUT:.*]] = "tf.ReadVariableOp"(%[[IDENTITY_OUT]]) + // CHECK-NEXT: tf_device.return %[[READ_OUT]] + // CHECK-NEXT: TPU_REPLICATED_CORE_0 + %0 = "tf.Identity"(%arg1) : (tensor<*x!tf.resource>>) -> tensor<*x!tf.resource>> + %1 = "tf.ReadVariableOp"(%0) : (tensor<*x!tf.resource>>) -> tensor<4xf32> + %2 = "tf.A"() : () -> (tensor<2x!tf.string>) + "tf_device.launch"() ( { + "tf.TPUExecuteAndUpdateVariables"(%arg1, %2) {device_var_reads_indices = [0], device_var_updates_indices = [-1]} : (tensor<*x!tf.resource>>, tensor<2x!tf.string>) -> () + tf_device.return + }) {device = "TPU_REPLICATED_CORE_0"} : () -> () + "tf_device.launch"() ( { + // CHECK: "tf.B"(%[[RESOURCE_OUT]]) + "tf.B"(%1) : (tensor<4xf32>) -> () + tf_device.return + }) {device = "TPU_REPLICATED_CORE_0"} : () -> () + 
tf_device.return + } + return +} + +// Tests AssignVariable op using composite device resource is wrapped inside +// tf_device.Cluster. + +// CHECK-LABEL: func @testAssignVariableOpColocated +// CHECK-SAME: (%[[ARG0:.*]]: tensor<*x!tf.resource>>) +func @testAssignVariableOpColocated(%arg0: tensor<*x!tf.resource>>) { + // CHECK: tf_device.replicate + // CHECK-SAME: (%[[ARG0]] as %[[RI_0:[a-z0-9]*]]: tensor<*x!tf.resource>>) + tf_device.replicate(%arg0 as %arg1: tensor<*x!tf.resource>>) { + _mirrored_variable_indices = [0], _replicated_input_indices = [-1], + devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]}, + n = 2 : i32} { + // CHECK: %[[VAL_OUT:.*]] = "tf.A"() : () -> tensor<4xf32> + // CHECK: "tf_device.launch"() + // CHECK-NEXT: "tf.AssignVariableOp"(%[[RI_0]], %[[VAL_OUT]]) + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: TPU_REPLICATED_CORE_0 + %1 = "tf.A"() : () -> (tensor<4xf32>) + "tf.AssignVariableOp"(%arg1, %1) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + %2 = "tf.B"() : () -> (tensor<2x!tf.string>) + "tf_device.launch"() ( { + "tf.TPUExecuteAndUpdateVariables"(%arg1, %2) {device_var_reads_indices = [0], device_var_updates_indices = [-1]} : (tensor<*x!tf.resource>>, tensor<2x!tf.string>) -> () + tf_device.return + }) {device = "TPU_REPLICATED_CORE_0"} : () -> () + tf_device.return + } + return +} + +// Tests tf_device.replicate op not running on TPU devices ignored. + +// CHECK-LABEL: func @testNonTPUDeviceReplicationIgnored +// CHECK-SAME: (%[[ARG0:.*]]: tensor<*x!tf.resource>>) +func @testNonTPUDeviceReplicationIgnored(%arg0: tensor<*x!tf.resource>>) { + // CHECK: tf_device.replicate + // CHECK-SAME: (%[[ARG0]] as %[[RI_0:[a-z0-9]*]]: tensor<*x!tf.resource>>) + tf_device.replicate(%arg0 as %arg1: tensor<*x!tf.resource>>) { + _mirrored_variable_indices = [0], _replicated_input_indices = [-1], + devices = {TPU_REPLICATED_HOST = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:CPU:1"]}, + n = 2 : i32} { + // CHECK: %[[VAL_OUT:.*]] = "tf.A"() : () -> tensor<4xf32> + // CHECK-NEXT: "tf.AssignVariableOp"(%[[RI_0]], %[[VAL_OUT]]) + %1 = "tf.A"() : () -> (tensor<4xf32>) + "tf.AssignVariableOp"(%arg1, %1) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + %2 = "tf.B"() : () -> (tensor<2x!tf.string>) + "tf_device.launch"() ( { + "tf.TPUExecuteAndUpdateVariables"(%arg1, %2) {device_var_reads_indices = [0], device_var_updates_indices = [-1]} : (tensor<*x!tf.resource>>, tensor<2x!tf.string>) -> () + tf_device.return + }) {device = "TPU_REPLICATED_HOST"} : () -> () + tf_device.return + } + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir index 32a8000ea82..8ae6fa958a6 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir @@ -173,7 +173,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor func @tail_single_outside_compiled_op() { // CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster" // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A" - // CHECK-NEXT: "tf.C" + // CHECK-NEXT: "tf.NoOp" // CHECK-NEXT: tf_device.return %[[A_OUT]] // CHECK-NEXT: { // CHECK-DAG: num_cores_per_replica = 1 @@ -190,7 +190,7 @@ module attributes {tf.versions = {producer = 888 : i32}, 
tf.devices = ["/job:wor "tf_device.cluster"() ( { %a = "tf.A"() : () -> tensor "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor) -> () - "tf.C"() : () -> () + "tf.NoOp"() : () -> () tf_device.return }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () return @@ -200,7 +200,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor func @tail_single_outside_compiled_op_user() -> tensor { // CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster" // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A" - // CHECK-NEXT: "tf.C" + // CHECK-NEXT: "tf.NoOp" // CHECK-NEXT: tf_device.return %[[A_OUT]] // CHECK-NEXT: { // CHECK-DAG: num_cores_per_replica = 1 @@ -217,7 +217,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %cluster = "tf_device.cluster"() ( { %a = "tf.A"() : () -> tensor %b = "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor) -> tensor - "tf.C"() : () -> () + "tf.NoOp"() : () -> () tf_device.return %b : tensor }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> tensor // CHECK: return %[[LAUNCH_OUT]] @@ -262,7 +262,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %b = "tf.B"() : () -> tensor // CHECK: %[[CLUSTER_OUT:.*]]:2 = "tf_device.cluster" // CHECK-NEXT: %[[C_OUT:.*]] = "tf.C" - // CHECK-NEXT: %[[E_OUT:.*]] = "tf.E" + // CHECK-NEXT: %[[E_OUT:.*]] = "tf.Const" // CHECK-NEXT: tf_device.return %[[C_OUT]], %[[E_OUT]] // CHECK-NEXT: { // CHECK-DAG: num_cores_per_replica = 1 @@ -279,7 +279,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %cluster:5 = "tf_device.cluster"() ( { %c = "tf.C"() : () -> tensor %d = "tf.D"(%c, %a) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> tensor - %e = "tf.E"() : () -> tensor + %e = "tf.Const"() {value = dense<0> : tensor} : () -> tensor tf_device.return %a, %b, %c, %d, %e : tensor, tensor, tensor, tensor, tensor }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> (tensor, tensor, tensor, tensor, tensor) // CHECK: return %[[A_OUT]], %[[B_OUT]], %[[CLUSTER_OUT]]#0, %[[LAUNCH_OUT]], %[[CLUSTER_OUT]]#1 @@ -320,14 +320,14 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor func @head_tail_no_extraction_middle_outside_compiled_ops(%arg0: tensor) { // CHECK-NOT: "tf_device.launch" // CHECK: "tf_device.cluster" - // CHECK-NEXT: "tf.A" + // CHECK-NEXT: "tf.Identity" // CHECK-NEXT: "tf.B" - // CHECK-NEXT: "tf.C" + // CHECK-NEXT: "tf.Identity" // CHECK-NEXT: tf_device.return "tf_device.cluster"() ( { - %a = "tf.A"(%arg0) : (tensor) -> tensor + %a = "tf.Identity"(%arg0) : (tensor) -> tensor %b = "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor) -> tensor - "tf.C"(%b) : (tensor) -> () + %c = "tf.Identity"(%b) : (tensor) -> tensor tf_device.return }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () return @@ -379,7 +379,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster" // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B" // CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%[[RI]], %[[B_OUT]]) - // CHECK-NEXT: "tf.E"(%[[C_OUT]], %[[HEAD_LAUNCH_OUT]]) + // CHECK-NEXT: "tf.IdentityN"(%[[C_OUT]], %[[HEAD_LAUNCH_OUT]]) // CHECK-NEXT: 
tf_device.return %[[C_OUT]] // CHECK-NEXT: { // CHECK-DAG: num_cores_per_replica = 1 @@ -399,11 +399,139 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %b = "tf.B"() : () -> tensor %c = "tf.C"(%ri, %b) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> tensor %d = "tf.D"(%a, %c, %ri) {_xla_outside_compilation = "cluster1"} : (tensor, tensor, tensor) -> tensor - %e = "tf.E"(%c, %a) : (tensor, tensor) -> tensor + %e:2 = "tf.IdentityN"(%c, %a) : (tensor, tensor) -> (tensor, tensor) tf_device.return }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () tf_device.return } return } + + // CHECK-LABEL: func @side_effect_middle + func @side_effect_middle() { + // CHECK: "tf_device.cluster" + // CHECK-NEXT: "tf.A" + // CHECK-NEXT: "tf.B" + // CHECK-NEXT: "tf.C" + // CHECK-NEXT: tf_device.return + "tf_device.cluster"() ( { + "tf.A"() : () -> () + "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> () + "tf.C"() : () -> () + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } + + // CHECK-LABEL: func @side_effect_head_no_operand + func @side_effect_head_no_operand() { + // CHECK: %[[HEAD_LAUNCH_OUT:.*]] = "tf_device.launch"() + // CHECK-NEXT: "tf.B" + // CHECK-NEXT: %[[C_OUT:.*]] = "tf.C" + // CHECK-NEXT: tf_device.return %[[C_OUT]] + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + + // CHECK: "tf_device.cluster" + // CHECK-NEXT: "tf.Const" + // CHECK-NEXT: "tf.D"(%[[HEAD_LAUNCH_OUT]]) + // CHECK-NEXT: tf_device.return + + "tf_device.cluster"() ( { + %cst = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> () + %c = "tf.C"() {_xla_outside_compilation = "cluster1"} : () -> tensor + "tf.D"(%c) : (tensor) -> () + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } + + // CHECK-LABEL: func @side_effect_tail_no_operand + func @side_effect_tail_no_operand() { + // CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster" + // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A" + // CHECK-NEXT: "tf.Const" + // CHECK-NEXT: tf_device.return %[[A_OUT]] + + // CHECK: "tf_device.launch"() + // CHECK-NEXT: "tf.B"(%[[CLUSTER_OUT]]) + // CHECK-NEXT: "tf.C" + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + "tf_device.cluster"() ( { + %a = "tf.A"() : () -> tensor + "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor) -> () + "tf.C"() {_xla_outside_compilation = "cluster1"} : () -> () + %cst = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } + + // Test embedding ops can be head extracted and side effect analysis + // predecessors are ignored. 
+ + // CHECK-LABEL: func @embedding_head_extraction + func @embedding_head_extraction(%arg0: tensor) { + // CHECK: "tf_device.launch"() + // CHECK-NEXT: "tf.EnqueueTPUEmbeddingRaggedTensorBatch" + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + + // CHECK: "tf_device.cluster" + // CHECK-NEXT: "tf.UnknownOp" + // CHECK-NEXT: tf_device.return + "tf_device.cluster"() ( { + "tf.UnknownOp"() : () -> () + "tf.EnqueueTPUEmbeddingRaggedTensorBatch"(%arg0) {_xla_outside_compilation = "cluster1", table_ids = [1, 2]} : (tensor) -> () + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } + + // Test side effecting op after embedding op can be head extracted. + + // CHECK-LABEL: func @op_after_embedding_head_extraction + func @op_after_embedding_head_extraction() { + // CHECK: "tf_device.launch"() + // CHECK-NEXT: "tf.A" + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + + // CHECK: "tf_device.cluster" + // CHECK-NEXT: "tf.RecvTPUEmbeddingActivations" + // CHECK-NEXT: "tf.SendTPUEmbeddingGradients" + // CHECK-NEXT: tf_device.return + "tf_device.cluster"() ( { + %0 = "tf.RecvTPUEmbeddingActivations"() {config = "test_config_recv_embedding"} : () -> tensor<512x256xf32> + "tf.SendTPUEmbeddingGradients"(%0) {N = 1 : i64, NN = 0 : i64, config = "test_config_send_embedding", operand_segment_sizes = dense<[1, 0]> : vector<2xi32>} : (tensor<512x256xf32>) -> () + "tf.A"() {_xla_outside_compilation = "cluster1"} : () -> () + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } + + // Test side effecting op before embedding op can be tail extracted. 
+ + // CHECK-LABEL: func @op_before_embedding_tail_extraction + func @op_before_embedding_tail_extraction() { + // CHECK: "tf_device.cluster" + // CHECK-NEXT: "tf.UnknownOp" + // CHECK-NEXT: "tf.RecvTPUEmbeddingActivations" + // CHECK-NEXT: "tf.SendTPUEmbeddingGradients" + // CHECK-NEXT: tf_device.return + + // CHECK: "tf_device.launch"() + // CHECK-NEXT: "tf.A" + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + "tf_device.cluster"() ( { + "tf.UnknownOp"() : () -> () + "tf.A"() {_xla_outside_compilation = "cluster1"} : () -> () + %0 = "tf.RecvTPUEmbeddingActivations"() {config = "test_config_recv_embedding"} : () -> tensor<512x256xf32> + "tf.SendTPUEmbeddingGradients"(%0) {N = 1 : i64, NN = 0 : i64, config = "test_config_send_embedding", operand_segment_sizes = dense<[1, 0]> : vector<2xi32>} : (tensor<512x256xf32>) -> () + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir index 2271bca7382..e2cfd6c82b2 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir @@ -145,12 +145,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: "tf_device.launch" // CHECK: %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey" // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster1_args" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_args" // CHECK: "tf.B"(%[[RECV_OUTPUT]]) // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: "tf._XlaHostComputeMlir"(%[[A_OUTPUT]]) - // CHECK-SAME: send_key = "host_compute_channel_cluster1_args" + // CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args" %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { %2 = "tf_device.cluster"() ( { %3 = "tf.A"() : () -> (tensor) @@ -164,6 +164,32 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor return %1 : tensor } + // Tests value is added as operand to XlaHostCompute op only if defining op is + // in TPU cluster. + + // CHECK-LABEL: func @single_outside_compiled_input_from_outside_device_cluster + func @single_outside_compiled_input_from_outside_device_cluster(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK-NEXT: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK-NEXT: "tf.B"(%[[A_OUTPUT]]) + // CHECK: "tf_device.cluster" + // CHECK-NEXT: "tf.C"() + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %3 = "tf.A"() : () -> (tensor) + %2 = "tf_device.cluster"() ( { + "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor) -> () + %4 = "tf.C"() : () -> tensor + tf_device.return %4 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor + } + // Tests extraction of a single outside compiled cluster with single host->device output. 
// CHECK-LABEL: func @single_outside_compiled_output_single_outside_compilation @@ -174,15 +200,15 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: "tf_device.launch" // CHECK: %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey" // CHECK: "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster1_args" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_args" // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"() // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals" // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"() - // CHECK-SAME: recv_key = "host_compute_channel_cluster1_retvals" - // CHECK-SAME: send_key = "host_compute_channel_cluster1_args" + // CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals" + // CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args" // CHECK: "tf.C"(%[[HOST_OUTPUT]]) %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { %2 = "tf_device.cluster"() ( { @@ -209,11 +235,11 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]]) // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT]]) // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals" // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[A_OUTPUT]]) - // CHECK-SAME: recv_key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals" // CHECK: tf_device.return %[[HOST_OUTPUT]] %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { %2 = "tf_device.cluster"() ( { @@ -240,11 +266,11 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]]) // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT]]) // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals" // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[A_OUTPUT]]) - // CHECK-SAME: recv_key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals" // CHECK: "tf.C"(%[[HOST_OUTPUT]]) %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { %2 = "tf_device.cluster"() ( { @@ -259,6 +285,42 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor return %1 : tensor } + // Tests that host to device communication is added only if value is used for ops + // that are not outside compiled.
+ + // CHECK-LABEL: func @single_outside_compiled_output_used_for_another_host_op + func @single_outside_compiled_output_used_for_another_host_op(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK-NEXT: %[[B_OUTPUT:[0-9]*]] = "tf.B"() + // CHECK-NEXT: "tf.IfRegion"(%[[A_OUTPUT]]) + // CHECK-NEXT: "tf.D"(%[[B_OUTPUT]]) + // CHECK: "tf_device.cluster" + // CHECK-NEXT: "tf.C"() + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %3 = "tf.A"() : () -> (tensor) + %2 = "tf_device.cluster"() ( { + %4 = "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> (tensor) + "tf.IfRegion"(%3) ({ + "tf.D"(%4) : (tensor) -> () + "tf.Yield"() : () -> () + }, { + "tf.Yield"() : () -> () + }) { _xla_outside_compilation = "cluster1", is_stateless = false} : (tensor) -> () + + %5 = "tf.C"() : () -> (tensor) + tf_device.return %5 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor + } + + // Tests extraction of a single outside compiled cluster with multiple input/output. // CHECK-LABEL: func @multiple_outside_compiled_input_output_single_outside_compilation @@ -271,12 +333,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]]) // CHECK: %[[B_OUTPUT:[0-9]*]]:2 = "tf.C"(%[[RECV_OUTPUT]]#0, %[[RECV_OUTPUT]]#1) // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]]#0, %[[B_OUTPUT]]#1, %[[PROGRAM_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals" // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" // CHECK: %[[HOST_OUTPUT:[0-9]*]]:2 = "tf._XlaHostComputeMlir"(%[[A_OUTPUT]], %[[B_OUTPUT]]) - // CHECK-SAME: recv_key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals" // CHECK: "tf.D"(%[[HOST_OUTPUT]]#0) // CHECK: "tf.E"(%[[HOST_OUTPUT]]#1) %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { @@ -306,20 +368,20 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT2]]) // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[RECV_OUTPUT2]]) // CHECK: "tf._XlaSendFromHost"(%[[D_OUTPUT]], %[[PROGRAM_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster2_retvals" + // CHECK-SAME: key = "host_compute_channel_cluster2_0_retvals" // CHECK: "tf_device.launch" // CHECK: %[[PROGRAM_OUTPUT1:[a-z_0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey" // CHECK: %[[RECV_OUTPUT1:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT1]]) // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT1]]) // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals" // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: %[[HOST_OUTPUT1:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[A_OUTPUT]]) - // CHECK-SAME: recv_key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: recv_key = 
"host_compute_channel_cluster1_0_retvals" // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[HOST_OUTPUT1]]) // CHECK: %[[HOST_OUTPUT2:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[C_OUTPUT]]) - // CHECK-SAME: recv_key = "host_compute_channel_cluster2_retvals" + // CHECK-SAME: recv_key = "host_compute_channel_cluster2_0_retvals" // CHECK: "tf.E"(%[[HOST_OUTPUT2]]) %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { %2 = "tf_device.cluster"() ( { @@ -346,12 +408,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: "tf_device.launch" // CHECK: %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey" // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster1_args" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_args" // CHECK: "tf.B"(%arg0, %[[RECV_OUTPUT]]) // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: "tf._XlaHostComputeMlir"(%[[A_OUTPUT]]) - // CHECK-SAME: send_key = "host_compute_channel_cluster1_args" + // CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args" %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { %2 = "tf_device.cluster"() ( { %3 = "tf.A"() : () -> (tensor) @@ -375,20 +437,20 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: "tf_device.launch" // CHECK: %[[PROGRAM_OUTPUT_2:[a-z_0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey" // CHECK: %[[RECV_OUTPUT_2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT_2]]) - // CHECK-SAME: key = "host_compute_channel_cluster2_args" + // CHECK-SAME: key = "host_compute_channel_cluster2_0_args" // CHECK: "tf.D"(%[[RECV_OUTPUT_2]]) // CHECK: "tf_device.launch" // CHECK: %[[PROGRAM_OUTPUT_1:[a-z_0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey" // CHECK: %[[RECV_OUTPUT_1:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT_1]]) - // CHECK-SAME: key = "host_compute_channel_cluster1_args" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_args" // CHECK: "tf.B"(%[[RECV_OUTPUT_1]]) // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: "tf._XlaHostComputeMlir"(%[[A_OUTPUT]]) - // CHECK-SAME: send_key = "host_compute_channel_cluster1_args" + // CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args" // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C" // CHECK: "tf._XlaHostComputeMlir"(%[[C_OUTPUT]]) - // CHECK-SAME: send_key = "host_compute_channel_cluster2_args" + // CHECK-SAME: send_key = "host_compute_channel_cluster2_0_args" %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { %2 = "tf_device.cluster"() ( { %3 = "tf.A"() : () -> (tensor) @@ -413,14 +475,14 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: "tf_device.launch" // CHECK: %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey" // CHECK: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster1_args" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_args" // CHECK: "tf.C"(%[[RECV_OUTPUT]]#0) // CHECK: "tf.D"(%[[RECV_OUTPUT]]#1, %[[RECV_OUTPUT]]#0) // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" // CHECK: "tf._XlaHostComputeMlir"(%[[A_OUTPUT]], %[[B_OUTPUT]]) - // CHECK-SAME: send_key = "host_compute_channel_cluster1_args" + // CHECK-SAME: send_key = 
"host_compute_channel_cluster1_0_args" %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { %2 = "tf_device.cluster"() ( { %3 = "tf.A"() : () -> (tensor) @@ -469,25 +531,25 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() // CHECK-NEXT: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "if_predicate_channel_cluster1_0" + // CHECK-SAME: key = "if_predicate_channel_cluster1_0_0" // CHECK-NEXT: tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]]) // CHECK-NEXT: %[[ARG_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "host_compute_channel_cluster1_args" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_args" // CHECK: "tf.D"(%[[ARG_RECV_OUTPUT]]#0, %[[ARG_RECV_OUTPUT]]#1) // CHECK: "tf._XlaSendFromHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals" // CHECK-NEXT: "tf.Yield"() : () -> () // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" - // CHECK: "tf.XlaSendToHost"(%6) {key = "if_predicate_channel_cluster1_0"} + // CHECK: "tf.XlaSendToHost"(%6) {key = "if_predicate_channel_cluster1_0_0"} // CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]]) // CHECK: "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]]) - // CHECK-SAME: recv_key = "host_compute_channel_cluster1_retvals" - // CHECK-SAME: send_key = "host_compute_channel_cluster1_args" + // CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals" + // CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args" // CHECK-SAME: tpu_core = 0 // CHECK-NEXT: "tf.Yield"() : () -> () %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { @@ -525,20 +587,20 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() // CHECK-NEXT: %[[RECV_OUTPUT:[0-9]*]]:3 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "host_compute_channel_cluster1_args" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_args" // CHECK-SAME: (tensor<2x!tf.string>) -> (tensor, tensor, tensor) // CHECK-NEXT: tf.IfRegion"(%[[RECV_OUTPUT]]#2) // CHECK: "tf.D"(%[[RECV_OUTPUT]]#0, %[[RECV_OUTPUT]]#1, %[[F_OUT]]) // CHECK: "tf._XlaSendFromHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals" // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" // CHECK: "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]], %[[G_OUTPUT]]) - // CHECK-SAME: recv_key = "host_compute_channel_cluster1_retvals" - // CHECK-SAME: send_key = "host_compute_channel_cluster1_args" + // CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals" + // CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args" // CHECK-SAME: tpu_core = 0 %0 = "tf.A"(%arg0) : (tensor) -> tensor %7 = "tf.F"() : () -> tensor @@ -579,12 +641,12 @@ module attributes {tf.versions = {producer = 888 
: i32}, tf.devices = ["/job:wor // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() // CHECK-NEXT: %[[RECV_OUTPUT_PREDICATE:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "if_predicate_channel_cluster1_0" + // CHECK-SAME: key = "if_predicate_channel_cluster1_0_0" // CHECK-SAME: (tensor<2x!tf.string>) -> tensor // CHECK-NEXT: tf.IfRegion"(%[[RECV_OUTPUT_PREDICATE]]) // CHECK-NEXT: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "host_compute_channel_cluster1_args" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_args" // CHECK-SAME: (tensor<2x!tf.string>) -> (tensor, tensor) // CHECK-NEXT: tf.IfRegion"(%[[RECV_OUTPUT]]#1) // CHECK-NEXT: "tf.H"(%[[RECV_OUTPUT]]#0, %[[F_OUT]]) @@ -592,20 +654,20 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: "tf.Yield"() : () -> () // CHECK: "tf._XlaSendFromHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals" // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" // CHECK: "tf.XlaSendToHost"(%[[G_OUTPUT]]) - // CHECK-SAME: key = "if_predicate_channel_cluster1_0" + // CHECK-SAME: key = "if_predicate_channel_cluster1_0_0" // CHECK-SAME: (tensor) -> () // CHECK-NEXT: "tf.IfRegion"(%[[G_OUTPUT]]) // CHECK: %[[D_OUT:[0-9]*]] = "tf.D" // CHECK-NEXT: %[[F_OUT:[0-9]*]] = "tf.F" // CHECK: "tf._XlaHostComputeMlir"(%[[D_OUT]], %[[F_OUT]]) - // CHECK-SAME: recv_key = "host_compute_channel_cluster1_retvals" - // CHECK-SAME: send_key = "host_compute_channel_cluster1_args" + // CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals" + // CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args" // CHECK-SAME: tpu_core = 0 // CHECK: "tf.Yield"() : () -> () // CHECK: "tf.Yield"() : () -> () @@ -657,25 +719,25 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() // CHECK-NEXT: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "if_predicate_channel_cluster1_0" + // CHECK-SAME: key = "if_predicate_channel_cluster1_0_0" // CHECK-NEXT: tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]]) // CHECK-NEXT: %[[ARG_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "host_compute_channel_cluster1_args" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_args" // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[ARG_RECV_OUTPUT]]#0, %[[ARG_RECV_OUTPUT]]#1) // CHECK: "tf._XlaSendFromHost"(%[[D_OUTPUT]], %[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals" // CHECK-NEXT: "tf.Yield"() : () -> () // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" - // CHECK: "tf.XlaSendToHost"(%6) {key = "if_predicate_channel_cluster1_0"} + // CHECK: "tf.XlaSendToHost"(%6) {key = "if_predicate_channel_cluster1_0_0"} // CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]]) 
// CHECK: %[[HOST_COMPUTE_OUT:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]]) - // CHECK-SAME: recv_key = "host_compute_channel_cluster1_retvals" - // CHECK-SAME: send_key = "host_compute_channel_cluster1_args" + // CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals" + // CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args" // CHECK-SAME: tpu_core = 0 // CHECK-NEXT: "tf.Yield"(%[[HOST_COMPUTE_OUT]]) %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { @@ -714,7 +776,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() // CHECK-NEXT: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "if_predicate_channel_cluster1_0" + // CHECK-SAME: key = "if_predicate_channel_cluster1_0_0" // CHECK-NEXT: tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]]) // CHECK: "tf.D" // CHECK-NEXT: "tf.Yield"() : () -> () @@ -722,7 +784,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" - // CHECK: "tf.XlaSendToHost"(%6) {key = "if_predicate_channel_cluster1_0"} + // CHECK: "tf.XlaSendToHost"(%6) {key = "if_predicate_channel_cluster1_0_0"} // CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]]) // CHECK-NEXT: "tf.Yield"() : () -> () %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { @@ -759,30 +821,30 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() // CHECK-NEXT: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "if_predicate_channel_cluster1_0" + // CHECK-SAME: key = "if_predicate_channel_cluster1_0_0" // CHECK-NEXT: tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]]) // CHECK-NEXT: %[[PREDICATE2_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "if_predicate_channel_cluster1_1" + // CHECK-SAME: key = "if_predicate_channel_cluster1_0_1" // CHECK-NEXT: tf.IfRegion"(%[[PREDICATE2_RECV_OUTPUT]]) // CHECK-NEXT: "tf.Yield"() : () -> () // CHECK: %[[ARG_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "host_compute_channel_cluster1_args" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_args" // CHECK: "tf.D"(%[[ARG_RECV_OUTPUT]]) // CHECK: "tf._XlaSendFromHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals" // CHECK-NEXT: "tf.Yield"() : () -> () // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" - // CHECK: "tf.XlaSendToHost"(%[[G_OUTPUT]]) {key = "if_predicate_channel_cluster1_0"} + // CHECK: "tf.XlaSendToHost"(%[[G_OUTPUT]]) {key = "if_predicate_channel_cluster1_0_0"} // CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]]) // CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H"(%[[B_OUTPUT]]) - // CHECK: "tf.XlaSendToHost"(%[[H_OUTPUT]]) {key = "if_predicate_channel_cluster1_1"} + // CHECK: "tf.XlaSendToHost"(%[[H_OUTPUT]]) {key = 
"if_predicate_channel_cluster1_0_1"} // CHECK-NEXT: tf.IfRegion"(%[[H_OUTPUT]]) // CHECK-NEXT: "tf.Yield"() : () -> () // CHECK: %[[I_OUTPUT:[0-9]*]] = "tf.I"(%[[H_OUTPUT]]) @@ -819,4 +881,442 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor return %1 : tensor } + + // Tests extraction of a single outside compiled cluster inside a tf.WhileRegion op body. + + // CHECK-LABEL: func @outside_compiled_ops_inside_tf_while_body + func @outside_compiled_ops_inside_tf_while_body(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() + // CHECK-NEXT: tf.WhileRegion" + // CHECK-NEXT: %[[COND_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "while_condition_channel_cluster1_0_0" + // CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT]]) + // CHECK: %[[BODY_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D" + // CHECK: "tf._XlaSendFromHost"(%[[D_OUTPUT]], %[[PLACEHOLDER_KEY]]) + // CHECK_NEXT: "tf.Yield" + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" + // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" + // CHECK-NEXT: tf.WhileRegion"(%[[B_OUTPUT]], %[[A_OUTPUT]]) + // CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H" + // CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]]) + // CHECK-NEXT: "tf.Yield"(%[[H_OUTPUT]]) + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C" + // CHECK-NEXT: %[[HOST_COMPUTE_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir" + // CHECK-NEXT "tf.Yield"(%[[C_OUTPUT]], %[[HOST_COMPUTE_OUTPUT]]) + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor) + %4 = "tf.B"() : () -> (tensor) + %6 = "tf.G"() : () -> (tensor) + + "tf.WhileRegion"(%4, %3) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %7 = "tf.H"(%arg1) : (tensor) -> tensor + "tf.Yield"(%7) : (tensor) -> () + }, { + ^bb0(%arg1: tensor, %arg2: tensor): + %8 = "tf.C"(%arg1) : (tensor) -> tensor + %9 = "tf.D"(%arg1, %arg2) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> tensor + "tf.Yield"(%8, %9) : (tensor, tensor) -> () + }) { is_stateless = false} : (tensor, tensor) -> (tensor, tensor) + + %5 = "tf.E"() : () -> tensor + tf_device.return %5 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor + } + + // Tests extraction of a single outside compiled cluster inside a tf.WhileRegion op cond. 
+ + // CHECK-LABEL: func @outside_compiled_ops_inside_tf_while_cond + func @outside_compiled_ops_inside_tf_while_cond(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() + // CHECK-NEXT: tf.WhileRegion" + // CHECK-NEXT: %[[COND_RECV_OUTPUT1:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-NEXT: %[[I_OUTPUT:[0-9]*]] = "tf.I"(%[[COND_RECV_OUTPUT1]]#0, %[[COND_RECV_OUTPUT1]]#1) + // CHECK-NEXT: "tf._XlaSendFromHost"(%[[I_OUTPUT]], %[[PLACEHOLDER_KEY]]) + // CHECK-NEXT: %[[COND_RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "while_condition_channel_cluster1_0_0" + // CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT2]]) + // CHECK_NEXT: "tf.Yield" + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" + // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" + // CHECK-NEXT: tf.WhileRegion"(%[[B_OUTPUT]], %[[A_OUTPUT]]) + // CHECK "tf.XlaHostCompute" + // CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H" + // CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]]) + // CHECK-NEXT: "tf.Yield"(%[[H_OUTPUT]]) + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C" + // CHECK-NEXT: "tf.D" + // CHECK-NEXT "tf.Yield"(%[[C_OUTPUT]], %[[HOST_COMPUTE_OUTPUT]]) + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor) + %4 = "tf.B"() : () -> (tensor) + %6 = "tf.G"() : () -> (tensor) + + "tf.WhileRegion"(%4, %3) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %7 = "tf.I"(%arg1, %arg2) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> tensor + %8 = "tf.H"(%7) : (tensor) -> tensor + "tf.Yield"(%8) : (tensor) -> () + }, { + ^bb0(%arg1: tensor, %arg2: tensor): + %7 = "tf.C"(%arg1) : (tensor) -> tensor + %8 = "tf.D"(%arg1, %arg2) : (tensor, tensor) -> tensor + "tf.Yield"(%7, %8) : (tensor, tensor) -> () + }) { is_stateless = false} : (tensor, tensor) -> (tensor, tensor) + + %5 = "tf.E"() : () -> tensor + tf_device.return %5 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor + } + + // Tests extraction of a single outside compiled cluster inside a tf.WhileRegion op cond and body. 
+ + // CHECK-LABEL: func @outside_compiled_ops_inside_tf_while_cond_body + func @outside_compiled_ops_inside_tf_while_cond_body(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() + // CHECK-NEXT: tf.WhileRegion" + // CHECK-NEXT: %[[COND_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "while_condition_channel_cluster2_0_0" + // CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT]]) + // CHECK: %[[BODY_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D" + // CHECK: "tf._XlaSendFromHost"(%[[D_OUTPUT]], %[[PLACEHOLDER_KEY]]) + // CHECK_NEXT: "tf.Yield" + // CHECK: "tf_device.launch" + // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() + // CHECK-NEXT: tf.WhileRegion" + // CHECK-NEXT: %[[COND_RECV_OUTPUT1:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-NEXT: %[[I_OUTPUT:[0-9]*]] = "tf.I"(%[[COND_RECV_OUTPUT1]]#0, %[[COND_RECV_OUTPUT1]]#1) + // CHECK-NEXT: "tf._XlaSendFromHost"(%[[I_OUTPUT]], %[[PLACEHOLDER_KEY]]) + // CHECK-NEXT: %[[COND_RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "while_condition_channel_cluster1_0_0" + // CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT2]]) + // CHECK_NEXT: "tf.Yield" + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" + // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" + // CHECK-NEXT: tf.WhileRegion"(%[[B_OUTPUT]], %[[A_OUTPUT]]) + // CHECK "tf.XlaHostCompute" + // CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H" + // CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]]) + // CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]]) + // CHECK-NEXT: "tf.Yield"(%[[H_OUTPUT]]) + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C" + // CHECK-NEXT "tf.Yield"(%[[C_OUTPUT]], %[[HOST_COMPUTE_OUTPUT]]) + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor) + %4 = "tf.B"() : () -> (tensor) + %6 = "tf.G"() : () -> (tensor) + + "tf.WhileRegion"(%4, %3) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %7 = "tf.I"(%arg1, %arg2) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> tensor + %8 = "tf.H"(%7) : (tensor) -> tensor + "tf.Yield"(%8) : (tensor) -> () + }, { + ^bb0(%arg1: tensor, %arg2: tensor): + %7 = "tf.C"(%arg1) : (tensor) -> tensor + %8 = "tf.D"(%arg1, %arg2) {_xla_outside_compilation = "cluster2"} : (tensor, tensor) -> tensor + "tf.Yield"(%7, %8) : (tensor, tensor) -> () + }) { is_stateless = false} : (tensor, tensor) -> (tensor, tensor) + + %5 = "tf.E"() : () -> tensor + tf_device.return %5 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor + } + + // Tests extraction of a single outside compiled cluster inside a tf.IfRegion op + // nested in a tf.WhileRegion. 
+ + // CHECK-LABEL: func @outside_compiled_ops_inside_tf_while_if + func @outside_compiled_ops_inside_tf_while_if(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() + // CHECK-NEXT: tf.WhileRegion" + // CHECK-NEXT: %[[COND_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "while_condition_channel_cluster1_0_0" + // CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT]]) + // CHECK: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-NEXT: tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]]) + // CHECK: "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D" + // CHECK: "tf._XlaSendFromHost"(%[[D_OUTPUT]], %[[PLACEHOLDER_KEY]]) + // CHECK_NEXT: "tf.Yield" + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" + // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" + // CHECK-NEXT: tf.WhileRegion"(%[[B_OUTPUT]], %[[A_OUTPUT]]) + // CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H" + // CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]]) + // CHECK-NEXT: "tf.Yield"(%[[H_OUTPUT]]) + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C" + // CHECK-NEXT: "tf.XlaSendToHost"(%[[G_OUTPUT]]) + // CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]]) + // CHECK-NEXT: %[[HOST_COMPUTE_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir" + // CHECK-NEXT "tf.Yield"(%[[HOST_COMPUTE_OUTPUT]]) + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor) + %4 = "tf.B"() : () -> (tensor) + %6 = "tf.G"() : () -> (tensor) + + "tf.WhileRegion"(%4, %3) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %7 = "tf.H"(%arg1) : (tensor) -> tensor + "tf.Yield"(%7) : (tensor) -> () + }, { + ^bb0(%arg1: tensor, %arg2: tensor): + %8 = "tf.C"(%arg1) : (tensor) -> tensor + %10 = "tf.IfRegion"(%6) ({ + %9 = "tf.D"() {_xla_outside_compilation = "cluster1"} : () -> tensor + "tf.Yield"(%9) : (tensor) -> () + }, { + "tf.Yield"(%arg2) : (tensor) -> () + }) { is_stateless = false} : (tensor) -> tensor + "tf.Yield"(%8, %10) : (tensor, tensor) -> () + }) { is_stateless = false} : (tensor, tensor) -> (tensor, tensor) + + %5 = "tf.E"() : () -> tensor + tf_device.return %5 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor + } + + // Tests extraction of an outside compiled tf.IfRegion op where the entirety + // of tf.IfRegion op is outside compiled with a nested tf.WhileRegion op. 
+ + // CHECK-LABEL: func @outside_compiled_tf_if_nested_while + func @outside_compiled_tf_if_nested_while(%arg0: tensor) -> tensor { + // CHECK: %[[A_OUT:[0-9]*]] = "tf.A" + // CHECK: %[[F_OUT:[0-9]*]] = "tf.F" + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() + // CHECK-NEXT: %[[RECV_OUTPUT:[0-9]*]]:3 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "host_compute_channel_cluster1_0_args" + // CHECK-SAME: (tensor<2x!tf.string>) -> (tensor, tensor, tensor) + // CHECK-NEXT: tf.IfRegion"(%[[RECV_OUTPUT]]#2) + // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[RECV_OUTPUT]]#0, %[[RECV_OUTPUT]]#1, %[[F_OUT]]) + // CHECK-NEXT: %[[J_OUTPUT:[0-9]*]] = "tf.J" + // CHECK-NEXT: %[[K_OUTPUT:[0-9]*]] = "tf.K" + // CHECK-NEXT: tf.WhileRegion"(%[[J_OUTPUT]], %[[D_OUTPUT]]) + // CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H"(%[[K_OUTPUT]]) + // CHECK: "tf._XlaSendFromHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals" + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" + // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" + // CHECK: "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]], %[[G_OUTPUT]]) + // CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals" + // CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args" + // CHECK-SAME: tpu_core = 0 + %0 = "tf.A"(%arg0) : (tensor) -> tensor + %7 = "tf.F"() : () -> tensor + + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor) + %4 = "tf.B"() : () -> (tensor) + %6 = "tf.G"() : () -> (tensor) + + "tf.IfRegion"(%6) ({ + %8 = "tf.D"(%4, %3, %7) {} : (tensor, tensor, tensor) -> (tensor) + %9 = "tf.J"() : () -> (tensor) + %10 = "tf.K"() : () -> (tensor) + "tf.WhileRegion"(%9, %8) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %11 = "tf.I"(%arg1, %arg2) : (tensor, tensor) -> tensor + %12 = "tf.H"(%10) : (tensor) -> tensor + "tf.Yield"(%12) : (tensor) -> () + }, { + ^bb0(%arg1: tensor, %arg2: tensor): + %11 = "tf.C"(%arg1) : (tensor) -> tensor + %12 = "tf.D"(%arg1, %arg2) : (tensor, tensor) -> tensor + "tf.Yield"(%11, %12) : (tensor, tensor) -> () + }) { is_stateless = false} : (tensor, tensor) -> (tensor, tensor) + "tf.Yield"() : () -> () + }, { + "tf.Yield"() : () -> () + }) {_xla_outside_compilation = "cluster1", is_stateless = false} : (tensor) -> () + + %5 = "tf.E"() : () -> tensor + tf_device.return %5 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor + } + + // Tests extraction of an outside compiled tf.WhileRegion where the entire + // tf.WhileRegion op is outside compiled with a nested tf.IfRegion. 
+ + // CHECK-LABEL: func @outside_compiled_ops_tf_while_nested_if + func @outside_compiled_ops_tf_while_nested_if(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() + // CHECK-NEXT: %[[HOST_RECV_OUTPUT:[0-9]*]]:3 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK: "tf.WhileRegion"(%[[HOST_RECV_OUTPUT]]#1, %[[HOST_RECV_OUTPUT]]#2) + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C" + // CHECK: "tf.IfRegion"(%[[HOST_RECV_OUTPUT]]#0) + // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[C_OUTPUT]]) + // CHECK_NEXT: "tf.Yield" + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" + // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" + // CHECK: "tf._XlaHostComputeMlir"(%[[G_OUTPUT]], %[[B_OUTPUT]], %[[A_OUTPUT]]) + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor) + %4 = "tf.B"() : () -> (tensor) + %6 = "tf.G"() : () -> (tensor) + + "tf.WhileRegion"(%4, %3) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %7 = "tf.H"(%arg1) : (tensor) -> tensor + "tf.Yield"(%7) : (tensor) -> () + }, { + ^bb0(%arg1: tensor, %arg2: tensor): + %8 = "tf.C"(%arg1) : (tensor) -> tensor + %10 = "tf.IfRegion"(%6) ({ + %9 = "tf.D"(%8) : (tensor) -> tensor + "tf.Yield"(%9) : (tensor) -> () + }, { + "tf.Yield"(%arg2) : (tensor) -> () + }) { is_stateless = false} : (tensor) -> tensor + "tf.Yield"(%8, %10) : (tensor, tensor) -> () + }) {_xla_outside_compilation = "cluster1", is_stateless = false} : (tensor, tensor) -> (tensor, tensor) + + %5 = "tf.E"() : () -> tensor + tf_device.return %5 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor + } + + // Tests extraction of an outside compiled cluster that contains ops wrapped + // inside multiple regions of nested tf.IfRegion and tf.WhileRegion. 
+ + // CHECK-LABEL: func @outside_compiled_ops_with_multiple_region_single_cluster + func @outside_compiled_ops_with_multiple_region_single_cluster(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() + // CHECK-NEXT: "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B" + // CHECK-NEXT: "tf._XlaSendFromHost"(%[[B_OUT]], %[[PLACEHOLDER_KEY]]) + // CHECK-NEXT: "tf.WhileRegion"() + // CHECK-NEXT: %[[WHILE_COND:.*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-NEXT: "tf.Yield"(%[[WHILE_COND]]) + // CHECK: "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%[[B_OUT]]) + // CHECK-NEXT: "tf._XlaSendFromHost"(%[[C_OUT]], %[[PLACEHOLDER_KEY]]) + // CHECK-NEXT: %[[IF_COND:.*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-NEXT: "tf.IfRegion"(%[[IF_COND]]) + // CHECK-NEXT: "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-NEXT: %[[D_OUT:.*]] = "tf.D"(%[[C_OUT]]) + // CHECK: "tf_device.cluster" + // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A" + // CHECK-NEXT: %[[B_OUT_DEVICE:.*]] = "tf._XlaHostComputeMlir"() + // CHECK-NEXT: %[[G_OUT:.*]] = "tf.G" + // CHECK-NEXT: "tf.WhileRegion"(%[[B_OUT_DEVICE]], %[[A_OUT]]) + // CHECK: %[[H_OUT:.*]] = "tf.H" + // CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUT]]) + // CHECK-NEXT: "tf.Yield"(%[[H_OUT]]) + // CHECK: %[[C_OUT_DEVICE:.*]] = "tf._XlaHostComputeMlir"() + // CHECK-NEXT: "tf.XlaSendToHost"(%[[G_OUT]]) + // CHECK-NEXT: "tf.IfRegion"(%[[G_OUT]]) + // CHECK-NEXT: "tf._XlaHostComputeMlir"() + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor) + %4 = "tf.B"() {_xla_outside_compilation="cluster0"} : () -> (tensor) + %6 = "tf.G"() : () -> (tensor) + "tf.WhileRegion"(%4, %3) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %7 = "tf.H"(%arg1) : (tensor) -> tensor + "tf.Yield"(%7) : (tensor) -> () + }, { + ^bb0(%arg1: tensor, %arg2: tensor): + %8 = "tf.C"(%4) {_xla_outside_compilation="cluster0"} : (tensor) -> tensor + %10 = "tf.IfRegion"(%6) ({ + %9 = "tf.D"(%8) {_xla_outside_compilation="cluster0"} : (tensor) -> tensor + "tf.Yield"(%9) : (tensor) -> () + }, { + "tf.Yield"(%arg2) : (tensor) -> () + }) { is_stateless = false} : (tensor) -> tensor + "tf.Yield"(%8, %10) : (tensor, tensor) -> () + }) {is_stateless = false} : (tensor, tensor) -> (tensor, tensor) + + %5 = "tf.E"() : () -> tensor + tf_device.return %5 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor + } } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_outside_compilation_cluster.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_outside_compilation_cluster.mlir index 1394bd22dc8..183c7c34d41 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_outside_compilation_cluster.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_outside_compilation_cluster.mlir @@ -75,7 +75,7 @@ func @two_clusters_no_dependencies() { // CHECK: "tf.opB" // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER4:[a-zA-Z_0-9]+]]" // CHECK: "tf.opC" - // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER4]]" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER5:[a-zA-Z_0-9]+]]" // CHECK: 
"tf.opD" "tf_device.cluster"() ( { "tf.opA"() : () -> () @@ -95,7 +95,6 @@ func @two_clusters_with_one_op_each() { // CHECK-NEXT: "tf.opC" // CHECK-NEXT: "tf.opD" // CHECK-NOT: _xla_outside_compilation = "[[CLUSTER6]]" - // CHECK-SAME: _xla_outside_compilation = "{{[a-zA-Z_0-9]+}}" // CHECK-NEXT: "tf.opE" "tf_device.cluster"() ( { %a = "tf.opA"() : () -> tensor @@ -118,9 +117,8 @@ func @two_clusters_with_two_ops_each() { // CHECK-NEXT: "tf.opD" // CHECK-NEXT: "tf.opE" // CHECK-NOT: _xla_outside_compilation = "[[CLUSTER8]]" - // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER9:[a-zA-Z_0-9]+]]" // CHECK-NEXT: "tf.opF" - // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER9]]" + // CHECK-NOT: _xla_outside_compilation = "[[CLUSTER8]]" // CHECK-NEXT: "tf.opG" "tf_device.cluster"() ( { %a = "tf.opA"() : () -> tensor @@ -135,6 +133,27 @@ func @two_clusters_with_two_ops_each() { return } +// CHECK-LABEL: func @resource_side_effect_cycle +func @resource_side_effect_cycle(%arg0: tensor>>, %arg1: tensor>>) { + // CHECK: "tf.ReadVariableOp" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER1:[a-zA-Z_0-9]+]]" + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER1]]" + // CHECK-NEXT: "tf.AssignVariableOp" + // CHECK-NOT: {_xla_outside_compilation = "[[CLUSTER1]]" + "tf_device.cluster"() ( { + %read0 = "tf.ReadVariableOp"(%arg0) {_xla_outside_compilation = "0"} : (tensor>>) -> tensor + %idet0 = "tf.Identity"(%read0) {_xla_outside_compilation = "0"} : (tensor) -> tensor + "tf.AssignVariableOp"(%arg1, %idet0) : (tensor>>, tensor) -> () + %read1 = "tf.ReadVariableOp"(%arg1) {_xla_outside_compilation = "0"} : (tensor>>) -> tensor + %idet1 = "tf.Identity"(%read1) {_xla_outside_compilation = "0"} : (tensor) -> tensor + %add0 = "tf.AddV2"(%idet0, %idet1) {_xla_outside_compilation = "0"} : (tensor, tensor) -> tensor + "tf.AssignVariableOp"(%arg0, %add0) {_xla_outside_compilation = "0"} : (tensor>>, tensor) -> () + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + return +} + // CHECK-LABEL: func @two_clusters_with_same_parent func @two_clusters_with_same_parent() { // CHECK: "tf.opA" @@ -142,12 +161,11 @@ func @two_clusters_with_same_parent() { // CHECK-NEXT: "tf.opB" // CHECK-NEXT: "tf.opC" // CHECK-NOT: _xla_outside_compilation = "[[CLUSTER10]]" - // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER11:[a-zA-Z_0-9]+]]" // CHECK-NEXT: "tf.opD" - // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER10]]" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER12:[a-zA-Z_0-9]+]]" // CHECK-NEXT: "tf.opE" // CHECK-NEXT: "tf.opF" - // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER11]]" + // CHECK-NOT: _xla_outside_compilation = "[[CLUSTER12]]" // CHECK-NEXT: "tf.opG" "tf_device.cluster"() ( { %a = "tf.opA"() {_xla_outside_compilation = "0"} : () -> tensor @@ -168,11 +186,11 @@ func @two_clusters_with_same_outside_compiled_parent() { // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER12:[a-zA-Z_0-9]+]]" // CHECK-NEXT: "tf.opB" // CHECK-NEXT: "tf.opC" - // CHECK-NOT: _xla_outside_compilation = "[[CLUSTER12]]" // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER13:[a-zA-Z_0-9]+]]" // CHECK-NEXT: "tf.opD" - // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER12]]" - // CHECK-NEXT: "tf.opE" + // CHECK-NOT: _xla_outside_compilation = "[[CLUSTER12]]" + // CHECK-NOT: _xla_outside_compilation = "[[CLUSTER13]]" + // CHECK-NEXT: "tf.Identity" // CHECK-NEXT: "tf.opF" // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER13]]" // CHECK-NEXT: "tf.opG" @@ -182,7 +200,7 @@ func 
@two_clusters_with_same_outside_compiled_parent() { %b = "tf.opB"(%a) : (tensor) -> tensor %c = "tf.opC"(%b) {_xla_outside_compilation = "0"} : (tensor) -> tensor %d = "tf.opD"() {_xla_outside_compilation = "0"} : () -> tensor - %e = "tf.opE"(%d) : (tensor) -> tensor + %e = "tf.Identity"(%d) : (tensor) -> tensor %f = "tf.opF"(%e) {_xla_outside_compilation = "0"} : (tensor) -> tensor %g = "tf.opG"(%c, %f) {_xla_outside_compilation = "0"} : (tensor, tensor) -> tensor tf_device.return @@ -213,14 +231,14 @@ func @outside_compile_with_block() { // CHECK-NEXT: "tf.opB" // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER15]]" // CHECK: "tf.opC" - // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER15]]" + // CHECK-NOT: _xla_outside_compilation = "[[CLUSTER15]]" "tf_device.cluster"() ( { %a = "tf.opA"() {_xla_outside_compilation = "0"} : () -> tensor - %b = "tf.opB"() {_xla_outside_compilation = "0"} : () -> tensor + %b = "tf.opB"(%a) {_xla_outside_compilation = "0"} : (tensor) -> tensor "tf_device.cluster" () ( { tf_device.return }) {cluster_attr = "cluster_attr"} : () -> () - %c = "tf.opC"() {_xla_outside_compilation = "0"} : () -> tensor + %c = "tf.opC"(%b) {_xla_outside_compilation = "0"} : (tensor) -> tensor tf_device.return }) {cluster_attr = "cluster_attr"} : () -> () return @@ -235,7 +253,6 @@ func @two_clusters_with_one_op_each_with_indirect_dependency() { // CHECK-NEXT: "tf.opD" // CHECK-NEXT: "tf.opE" // CHECK-NOT: _xla_outside_compilation = "[[CLUSTER16]]" - // CHECK-SAME: _xla_outside_compilation = "{{[a-zA-Z_0-9]+}}" // CHECK-NEXT: "tf.opF" "tf_device.cluster"() ( { %a = "tf.opA"() : () -> tensor @@ -248,3 +265,277 @@ func @two_clusters_with_one_op_each_with_indirect_dependency() { }) {cluster_attr = "cluster_attr"} : () -> () return } + +// CHECK-LABEL: func @check_ops_with_data_dependency_added_as_host_cluster +func @check_ops_with_data_dependency_added_as_host_cluster() { + // CHECK: "tf.opA" + // CHECK-NEXT: "tf.opB" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER16:[a-zA-Z_0-9]+]]" + // CHECK-NEXT: "tf.Identity" + // CHECK-NEXT: "tf.Identity" + // CHECK-NEXT: "tf.opE" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER16]]" + // CHECK-NEXT: "tf.opF" + "tf_device.cluster"() ( { + %a = "tf.opA"() : () -> tensor + %b = "tf.opB"(%a) {_xla_outside_compilation = "0"} : (tensor) -> tensor + %c = "tf.Identity"(%b) : (tensor) -> tensor + %d = "tf.Identity"(%c) : (tensor) -> tensor + %e = "tf.opE"(%d, %b, %c) {_xla_outside_compilation = "0"} : (tensor, tensor, tensor) -> tensor + "tf.opF"(%e) : (tensor) -> () + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + return +} + +// CHECK-LABEL: func @check_op_inside_nested_region_clustered +func @check_op_inside_nested_region_clustered(%arg0 : tensor<*x!tf.resource>) { + // CHECK: tf_device.cluster + // CHECK: "tf.IfRegion" + // CHECK-NEXT: "tf.Const" + // CHECK-NEXT: "tf.B" + // CHECK-NEXT: "tf.C" + // CHECK-NEXT: "tf.Const" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER17:[a-zA-Z_0-9]+]]" + // CHECK-NEXT: "tf.Const" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER17]]" + // CHECK-NEXT: "tf.WriteSummary" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER17]]" + "tf_device.cluster"() ( { + %0 = "tf.Const"() {value = dense : tensor} : () -> tensor + "tf.IfRegion"(%0) ( { + %1 = "tf.Const"() {value = dense : tensor} : () -> tensor + %2 = "tf.B"() : () -> (tensor) + %3 = "tf.C"() : () -> (tensor) + %4 = "tf.Const"() {_xla_outside_compilation = "auto0", value = dense<"logits"> : tensor} : () -> tensor + 
%5 = "tf.Const"() {_xla_outside_compilation = "auto1", value = dense<"\0A\09\0A\07scalars"> : tensor} : () -> tensor + "tf.WriteSummary"(%arg0, %2, %3, %4, %5) {_xla_outside_compilation = "auto2", device = "/device:CPU:0"} : (tensor<*x!tf.resource>, tensor, tensor, tensor, tensor) -> () + "tf.Yield"(%1) : (tensor) -> () + }, { + %1 = "tf.Const"() {value = dense : tensor} : () -> tensor + "tf.Yield"(%1) : (tensor) -> () + }) { is_stateless = true } : (tensor) -> tensor + + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + return +} + +// CHECK-LABEL: func @check_ops_inside_different_block_clustered +func @check_ops_inside_different_block_clustered(%arg0 : tensor<*x!tf.resource>) { + // CHECK: tf_device.cluster + // CHECK-NEXT: "tf.Const" + // CHECK-NEXT: "tf.B" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER17:[a-zA-Z_0-9]+]]" + // CHECK-NEXT: "tf.C" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER18:[a-zA-Z_0-9]+]]" + // CHECK: "tf.IfRegion" + // CHECK-NEXT: "tf.Const" + // CHECK-NEXT: "tf.Const" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER17]]" + // CHECK-NEXT: "tf.Const" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER17]]" + // CHECK-NEXT: "tf.WriteSummary" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER17]]" + // CHECK: "tf.Const" + // CHECK-NEXT: "tf.Const" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER18]]" + // CHECK-NEXT: "tf.D" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER18]]" + "tf_device.cluster"() ( { + %0 = "tf.Const"() {value = dense : tensor} : () -> tensor + %2 = "tf.B"() {_xla_outside_compilation = "auto1"} : () -> (tensor) + %3 = "tf.C"() {_xla_outside_compilation = "auto2"} : () -> (tensor) + "tf.IfRegion"(%0) ( { + %1 = "tf.Const"() {value = dense : tensor} : () -> tensor + %4 = "tf.Const"() {_xla_outside_compilation = "auto3", value = dense<"logits"> : tensor} : () -> tensor + %5 = "tf.Const"() {_xla_outside_compilation = "auto4", value = dense<"\0A\09\0A\07scalars"> : tensor} : () -> tensor + "tf.WriteSummary"(%arg0, %2, %3, %4, %5) {_xla_outside_compilation = "auto2", device = "/device:CPU:0"} : (tensor<*x!tf.resource>, tensor, tensor, tensor, tensor) -> () + "tf.Yield"(%1) : (tensor) -> () + }, { + %1 = "tf.Const"() {value = dense : tensor} : () -> tensor + %4 = "tf.Const"() {_xla_outside_compilation = "auto5", value = dense<"a"> : tensor} : () -> tensor + "tf.D"(%3, %4, %1) {_xla_outside_compilation = "auto6"} : (tensor, tensor, tensor) -> () + "tf.Yield"(%1) : (tensor) -> () + }) { is_stateless = true } : (tensor) -> tensor + + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + return +} + +// CHECK-LABEL: func @check_clustering_ops_inside_nested_control_flow +func @check_clustering_ops_inside_nested_control_flow(%arg0 : tensor<*x!tf.resource>) { + // CHECK: tf_device.cluster + // CHECK-NEXT: "tf.Const" + // CHECK-NEXT: "tf.B" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER17:[a-zA-Z_0-9]+]]" + // CHECK-NEXT: "tf.C" + // CHECK: _xla_outside_compilation = "[[CLUSTER17]]" + // CHECK: "tf.IfRegion" + // CHECK: "tf.IfRegion" + // CHECK-NEXT: "tf.Const" + // CHECK-NEXT: "tf.Const" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER17]]" + // CHECK-NEXT: "tf.Const" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER17]]" + // CHECK-NEXT: "tf.WriteSummary" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER17]]" + "tf_device.cluster"() ( { + %0 = "tf.Const"() {value = dense : tensor} : () -> tensor + %2 = "tf.B"() {_xla_outside_compilation = "auto1"} : () -> (tensor) + 
%3 = "tf.C"() {_xla_outside_compilation = "auto2"} : () -> (tensor) + "tf.IfRegion"(%0) ( { + %6 = "tf.Const"() {value = dense : tensor} : () -> tensor + "tf.IfRegion"(%6) ( { + %1 = "tf.Const"() {value = dense : tensor} : () -> tensor + %4 = "tf.Const"() {_xla_outside_compilation = "auto3", value = dense<"logits"> : tensor} : () -> tensor + %5 = "tf.Const"() {_xla_outside_compilation = "auto4", value = dense<"\0A\09\0A\07scalars"> : tensor} : () -> tensor + "tf.WriteSummary"(%arg0, %2, %3, %4, %5) {_xla_outside_compilation = "auto2", device = "/device:CPU:0"} : (tensor<*x!tf.resource>, tensor, tensor, tensor, tensor) -> () + "tf.Yield"(%1) : (tensor) -> () + }, { + %1 = "tf.Const"() {value = dense : tensor} : () -> tensor + "tf.Yield"(%1) : (tensor) -> () + }) { is_stateless = true } : (tensor) -> tensor + "tf.Yield"(%6) : (tensor) -> () + }, { + %7 = "tf.Const"() {value = dense : tensor} : () -> tensor + "tf.Yield"(%7) : (tensor) -> () + }) { is_stateless = true } : (tensor) -> tensor + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + return +} + +// CHECK-LABEL: func @single_variant_input +func @single_variant_input() { + // CHECK: "tf.opA" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER1:[a-zA-Z_0-9]+]]" + // CHECK: "tf.opB" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER1]]" + // CHECK: "tf.opC" + "tf_device.cluster"() ( { + %1= "tf.opA"() : () -> tensor>> + "tf.opB"(%1) {_xla_outside_compilation = "0"} : (tensor>>) -> () + "tf.opC"() : () -> () + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + return +} + +// CHECK-LABEL: func @chained_variant_input +func @chained_variant_input() { + // CHECK: "tf.opA" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER1:[a-zA-Z_0-9]+]]" + // CHECK: "tf.opB" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER1]]" + // CHECK: "tf.opC" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER1]]" + "tf_device.cluster"() ( { + %1 = "tf.opA"() : () -> tensor>> + %2 = "tf.opB"(%1) : (tensor>>) -> (tensor>>) + "tf.opC"(%2) {_xla_outside_compilation = "0"} : (tensor>>) -> () + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + return +} + +// CHECK-LABEL: func @single_variant_output +func @single_variant_output() { + // CHECK: "tf.opA" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER1:[a-zA-Z_0-9]+]]" + // CHECK: "tf.opB" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER1]]" + // CHECK: "tf.opC" + "tf_device.cluster"() ( { + %1= "tf.opA"() {_xla_outside_compilation = "0"} : () -> tensor>> + "tf.opB"(%1) : (tensor>>) -> () + "tf.opC"() : () -> () + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + return +} + +// CHECK-LABEL: func @chained_variant_output +func @chained_variant_output() { + // CHECK: "tf.opA" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER1:[a-zA-Z_0-9]+]]" + // CHECK: "tf.opB" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER1]]" + // CHECK: "tf.opC" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER1]]" + "tf_device.cluster"() ( { + %1 = "tf.opA"() {_xla_outside_compilation = "0"} : () -> tensor>> + %2 = "tf.opB"(%1) : (tensor>>) -> (tensor>>) + "tf.opC"(%2) : (tensor>>) -> () + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + return +} + +// CHECK-LABEL: func @variant_input_output +func @variant_input_output() { + // CHECK: "tf.opA" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER1:[a-zA-Z_0-9]+]]" + // CHECK: "tf.opB" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER1]]" + // CHECK: "tf.opC" + // 
CHECK-SAME: _xla_outside_compilation = "[[CLUSTER1]]" + "tf_device.cluster"() ( { + %1 = "tf.opA"() : () -> tensor>> + %2 = "tf.opB"(%1) {_xla_outside_compilation = "0"} : (tensor>>) -> (tensor>>) + "tf.opC"(%2) : (tensor>>) -> () + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + return +} + +// CHECK-LABEL: func @variant_input_nested +func @variant_input_nested(%arg0 : tensor<*x!tf.resource>) { + // CHECK: tf_device.cluster + // CHECK-NEXT: "tf.Const" + // CHECK-NEXT: "tf.C" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER1:[a-zA-Z_0-9]+]]" + // CHECK: "tf.IfRegion" + // CHECK: "tf.opD" + // CHECK-NOT: _xla_outside_compilation = "[[CLUSTER1]]" + "tf_device.cluster"() ( { + %0 = "tf.Const"() {value = dense : tensor} : () -> tensor + %2 = "tf.C"() {_xla_outside_compilation = "auto0"} : () -> (tensor>>) + "tf.IfRegion"(%0) ( { + %1 = "tf.Const"() {value = dense : tensor} : () -> tensor + "tf.opD"(%2) : (tensor>>) -> () + "tf.Yield"(%1) : (tensor) -> () + }, { + %1 = "tf.Const"() {value = dense : tensor} : () -> tensor + "tf.Yield"(%1) : (tensor) -> () + }) { is_stateless = true, _xla_outside_compilation = "auto1" } : (tensor) -> tensor + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + return +} + +// CHECK-LABEL: func @variant_output_nested +func @variant_output_nested(%arg0 : tensor<*x!tf.resource>) { + // CHECK: tf_device.cluster + // CHECK: "tf.IfRegion" + // CHECK: "tf.C" + // CHECK-NOT: _xla_outside_compilation + // CHECK: "tf.D" + // CHECK-NOT: _xla_outside_compilation + // CHECK: "tf.Yield" + // CHECK: _xla_outside_compilation + "tf_device.cluster"() ( { + %0 = "tf.Const"() {value = dense : tensor} : () -> tensor + %1 = "tf.IfRegion"(%0) ( { + %2 = "tf.C"() : () -> (tensor>>) + "tf.Yield"(%2) : (tensor>>) -> () + }, { + %2 = "tf.D"() : () -> (tensor>>) + "tf.Yield"(%2) : (tensor>>) -> () + }) { is_stateless = true, _xla_outside_compilation = "auto1" } : (tensor) -> tensor>> + "tf.E"(%1) {_xla_outside_compilation = "auto0"} : (tensor>>) -> () + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_parallel_execute_sink_resource_write.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_parallel_execute_sink_resource_write.mlir new file mode 100644 index 00000000000..ad4433c1d20 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_parallel_execute_sink_resource_write.mlir @@ -0,0 +1,137 @@ +// RUN: tf-opt %s -tf-tpu-parallel-execute-sink-resource-write | FILECHECK_OPTS="" FileCheck %s + +// CHECK-LABEL: func @multiple_uses +// CHECK-SAME: ({{.+}}: tensor, [[ARG1:%.+]]: tensor) +func @multiple_uses(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: [[PARALLEL_EXECUTE:%.+]]:2 = "tf_device.parallel_execute" + %0:2 = "tf_device.parallel_execute"() ( { + tf_device.return %arg0 : tensor + }, { + tf_device.return %arg0 : tensor + // CHECK: }) : () -> (tensor, tensor) + }) : () -> (tensor, tensor) + // CHECK-NEXT: "tf.AssignVariableOp"([[ARG1]], [[PARALLEL_EXECUTE]]#0) + "tf.AssignVariableOp"(%arg1, %0#0) : (tensor, tensor) -> () + // CHECK-NEXT: return [[PARALLEL_EXECUTE]]#0 + return %0#0 : tensor +} + +// CHECK-LABEL: func @not_assign_var +// CHECK-SAME: ({{.+}}: tensor, [[ARG1:%.+]]: tensor) +func @not_assign_var(%arg0: tensor, %arg1: tensor) { + // CHECK: [[PARALLEL_EXECUTE:%.+]]:2 = "tf_device.parallel_execute" + %0:2 = "tf_device.parallel_execute"() ( { + tf_device.return %arg0 : tensor + }, { + tf_device.return %arg0 : tensor + // CHECK: }) : () -> 
(tensor, tensor) + }) : () -> (tensor, tensor) + // CHECK-NEXT: "tf.AssignAddVariableOp"([[ARG1]], [[PARALLEL_EXECUTE]]#0) + "tf.AssignAddVariableOp"(%arg1, %0#0) : (tensor, tensor) -> () + return +} + +// CHECK-LABEL: func @resource_handle_output +// CHECK-SAME: ([[ARG0:%.+]]: tensor, {{.+}}: tensor) +func @resource_handle_output(%arg0: tensor, %arg1: tensor) { + // CHECK: [[PARALLEL_EXECUTE:%.+]]:2 = "tf_device.parallel_execute" + %0:2 = "tf_device.parallel_execute"() ( { + tf_device.return %arg1 : tensor + }, { + tf_device.return %arg1 : tensor + // CHECK: }) : () -> (tensor, tensor) + }) : () -> (tensor, tensor) + // CHECK-NEXT: "tf.AssignVariableOp"([[PARALLEL_EXECUTE]]#0, [[ARG0]]) + "tf.AssignVariableOp"(%0#0, %arg0) : (tensor, tensor) -> () + return +} + +// CHECK-LABEL: func @resource_handle_and_value_output +func @resource_handle_and_value_output(%arg0: tensor, %arg1: tensor) { + // CHECK: [[PARALLEL_EXECUTE:%.+]]:2 = "tf_device.parallel_execute" + %0:2 = "tf_device.parallel_execute"() ( { + tf_device.return %arg0, %arg1 : tensor, tensor + }, { + tf_device.return + }) : () -> (tensor, tensor) + // CHECK: "tf.AssignVariableOp"([[PARALLEL_EXECUTE]]#1, [[PARALLEL_EXECUTE]]#0) + "tf.AssignVariableOp"(%0#1, %0#0) : (tensor, tensor) -> () + return +} + +// CHECK-LABEL: func @resource_handle_after_parallel_execute +func @resource_handle_after_parallel_execute(%arg0: tensor) { + // CHECK: [[PARALLEL_EXECUTE:%.+]]:2 = "tf_device.parallel_execute" + %0:2 = "tf_device.parallel_execute"() ( { + tf_device.return %arg0 : tensor + }, { + tf_device.return %arg0 : tensor + // CHECK: }) : () -> (tensor, tensor) + }) : () -> (tensor, tensor) + // CHECK-NEXT: [[VAR:%.+]] = "tf.VarHandleOp" + %1 = "tf.VarHandleOp"() {container = "", shape = #tf.shape<>, shared_name = "x"} : () -> tensor>> + // CHECK-NEXT: "tf.AssignVariableOp"([[VAR]], [[PARALLEL_EXECUTE]]#0) + "tf.AssignVariableOp"(%1, %0#0) : (tensor>>, tensor) -> () + return +} + +// CHECK-LABEL: func @replace_single_output +// CHECK-SAME: ([[ARG0:%.+]]: tensor, [[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor, [[ARG3:%.+]]: tensor) +func @replace_single_output(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) { + // CHECK: {{%.+}}:2 = "tf_device.parallel_execute" + %0:3 = "tf_device.parallel_execute"() ( { + // CHECK-NEXT: "tf.AssignVariableOp"([[ARG3]], [[ARG1]]) + // CHECK-NEXT: tf_device.return [[ARG0]], [[ARG2]] : tensor, tensor + tf_device.return %arg0, %arg1, %arg2 : tensor, tensor, tensor + // CHECK-NEXT: }, { + }, { + // CHECK-NEXT: tf_device.return + tf_device.return + // CHECK-NEXT: }) : () -> (tensor, tensor) + }) : () -> (tensor, tensor, tensor) + "tf.AssignVariableOp"(%arg3, %0#1) : (tensor, tensor) -> () + // CHECK-NEXT: return + return +} + +// CHECK-LABEL: func @replace_multiple_outputs +// CHECK-SAME: ([[ARG0:%.+]]: tensor, [[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor, [[ARG3:%.+]]: tensor, [[ARG4:%.+]]: tensor, [[ARG5:%.+]]: tensor, [[ARG6:%.+]]: tensor) +func @replace_multiple_outputs(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor) { + // CHECK: {{%.+}}:3 = "tf_device.parallel_execute" + %0:5 = "tf_device.parallel_execute"() ( { + // CHECK-NEXT: "tf.AssignVariableOp"([[ARG5]], [[ARG1]]) + // CHECK-NEXT: "tf.AssignVariableOp"([[ARG6]], [[ARG3]]) + // CHECK-NEXT: tf_device.return [[ARG0]], [[ARG2]], [[ARG4]] : tensor, tensor, tensor + tf_device.return %arg0, %arg1, %arg2, %arg3, %arg4 : tensor, tensor, tensor, tensor, tensor + // CHECK-NEXT: }, { + }, { + // CHECK-NEXT: 
tf_device.return + tf_device.return + // CHECK-NEXT: }) : () -> (tensor, tensor, tensor) + }) : () -> (tensor, tensor, tensor, tensor, tensor) + "tf.AssignVariableOp"(%arg5, %0#1) : (tensor, tensor) -> () + "tf.AssignVariableOp"(%arg6, %0#3) : (tensor, tensor) -> () + // CHECK-NEXT: return + return +} + +// CHECK-LABEL: func @replace_multiple_outputs_regions +// CHECK-SAME: ([[ARG0:%.+]]: tensor, [[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor, [[ARG3:%.+]]: tensor, [[ARG4:%.+]]: tensor, [[ARG5:%.+]]: tensor, [[ARG6:%.+]]: tensor, [[ARG7:%.+]]: tensor) +func @replace_multiple_outputs_regions(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor) { + // CHECK: {{%.+}}:4 = "tf_device.parallel_execute" + %0:6 = "tf_device.parallel_execute"() ( { + // CHECK-NEXT: "tf.AssignVariableOp"([[ARG6]], [[ARG1]]) + // CHECK-NEXT: tf_device.return [[ARG0]], [[ARG2]] : tensor, tensor + tf_device.return %arg0, %arg1, %arg2 : tensor, tensor, tensor + // CHECK-NEXT: }, { + }, { + // CHECK-NEXT: "tf.AssignVariableOp"([[ARG7]], [[ARG4]]) + // CHECK-NEXT: tf_device.return [[ARG3]], [[ARG5]] : tensor, tensor + tf_device.return %arg3, %arg4, %arg5 : tensor, tensor, tensor + // CHECK-NEXT: }) : () -> (tensor, tensor, tensor, tensor) + }) : () -> (tensor, tensor, tensor, tensor, tensor, tensor) + "tf.AssignVariableOp"(%arg6, %0#1) : (tensor, tensor) -> () + "tf.AssignVariableOp"(%arg7, %0#4) : (tensor, tensor) -> () + // CHECK-NEXT: return + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir index 2e3e38c7004..a3d5a43a214 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir @@ -227,3 +227,28 @@ func @pcall_func_body(%arg0: tensor<*xi1>) -> tensor { %2 = "tf.D"(%1) : (tensor<*xi1>) -> (tensor) return %2 : tensor } + +// ----- + +// Tests that output sharding inside a functional op is parsed correctly. + +// CHECK-LABEL: func @check_sharding_inside_functional_op +func @check_sharding_inside_functional_op(%arg0: tensor<*xi32>) { + "tf_device.cluster_func"(%arg0) {func = @cluster_func, step_marker_location = ""} : (tensor<*xi32>) -> tensor<*xi32> + // CHECK: input_sharding_configuration + // CHECK-SAME: ["\01\02\03"] + // CHECK: output_sharding_configuration + // CHECK-SAME: ["\01\02\03"] + return +} + +func @cluster_func(%arg0: tensor<*xi32>) -> tensor<*xi32> { + %0 = "tf.PartitionedCall"(%arg0) {f= @func_body, config="", config_proto="", executor_type=""} : (tensor<*xi32>) -> tensor<*xi32> + return %0 : tensor<*xi32> +} + +func @func_body(%arg0: tensor<*xi32>)-> tensor<*xi32> { + %0 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32> + %1 = "tf.Identity"(%0) : (tensor<*xi32>) -> (tensor<*xi32>) + return %1 : tensor<*xi32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc index de73dff8b0b..fe0c5bea44e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.h" - #include #include #include @@ -41,7 +39,48 @@ namespace mlir { namespace TF { namespace { -// Replace TF BatchMatMul by TF Einsum + +// Replace TF BatchMatMul by TF Einsum op +template +class ConvertTFBatchMatMulToEinsumOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(BatchMatMulOpType op, + PatternRewriter& rewriter) const override { + Value input_lhs = op.x(); + Value input_rhs = op.y(); + + // LHS and RHS must be a ranked tensor type + auto lhs_type = input_lhs.getType().dyn_cast(); + auto rhs_type = input_rhs.getType().dyn_cast(); + + if (!lhs_type || !rhs_type) return failure(); + + auto lhs_shape = lhs_type.getShape(); + auto rhs_shape = rhs_type.getShape(); + + // Ensure that input ranks are at least 2. + const int dims_a = lhs_shape.size(); + const int dims_b = rhs_shape.size(); + if (dims_a < 2 || dims_b < 2) { + return failure(); + } + + // einsum equation for batchmatmul + std::string equation("...mk,...kn->...mn"); + if (op.adj_x()) std::swap(equation[3], equation[4]); + if (op.adj_y()) std::swap(equation[6 + 3], equation[6 + 4]); + + rewriter.replaceOpWithNewOp( + op, op.getType(), + /*inputs=*/ValueRange({input_lhs, input_rhs}), + /*equation=*/equation); + + return success(); + } +}; + struct BatchMatMulToEinsumPass : public PassWrapper { void runOnFunction() override; @@ -57,65 +96,10 @@ void BatchMatMulToEinsumPass::runOnFunction() { applyPatternsAndFoldGreedily(func, patterns); } -} // namespace - -template -LogicalResult -ConvertTFBatchMatMulToEinsumOp::matchAndRewrite( - BatchMatMulOpType op, PatternRewriter& rewriter) const { - Value input_lhs = op.x(); - Value input_rhs = op.y(); - - if (!input_lhs.getType().isa()) { - // LHS must be a ranked tensor type - return failure(); - } - if (!input_rhs.getType().isa()) { - // RHS must be a ranked tensor type - return failure(); - } - - auto lhs_type = input_lhs.getType().dyn_cast(); - auto rhs_type = input_rhs.getType().dyn_cast(); - - if (!lhs_type || !rhs_type) { - return failure(); - } - - auto lhs_shape = lhs_type.getShape(); - auto rhs_shape = rhs_type.getShape(); - - Location loc = op.getLoc(); - - // Ensure that input ranks are at least 2. - const int dims_a = lhs_shape.size(); - const int dims_b = rhs_shape.size(); - if (dims_a < 2 || dims_b < 2) { - // Both inputs must have rank >= 2 - return failure(); - } - - // einsum equation for batchmatmul - std::string equation("...mk,...kn->...mn"); - - if (op.adj_x()) { - std::swap(equation[3], equation[4]); - } - if (op.adj_y()) { - std::swap(equation[6 + 3], equation[6 + 4]); - } - - llvm::SmallVector inputs = {input_lhs, input_rhs}; - rewriter.replaceOpWithNewOp(op, op.getType(), - /*inputs=*/ValueRange(inputs), - /*equation=*/equation); - - return success(); -} - -static PassRegistration pass( +PassRegistration pass( "tf-batch-matmul-to-tf-einsum", "Replace TF BatchMatMul op by TF Einsum op."); +} // namespace std::unique_ptr> CreateBatchMatMulToEinsumPass() { return std::make_unique(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.h b/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.h deleted file mode 100644 index d39f3575b4a..00000000000 --- a/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.h +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. 
All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_BATCHMATMUL_TO_EINSUM_H_ -#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_BATCHMATMUL_TO_EINSUM_H_ - -#include "llvm/ADT/ArrayRef.h" -#include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/TypeUtilities.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/core/util/matmul_bcast.h" - -namespace mlir { -namespace TF { - -// Replace TF BatchMatMul by TF Einsum op -template -class ConvertTFBatchMatMulToEinsumOp - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite( - BatchMatMulOpType op, - PatternRewriter& rewriter) const override; // NOLINT -}; - -} // namespace TF -} // namespace mlir - -#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_BATCHMATMUL_TO_EINSUM_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc index 0c21078b0ad..eccbe5feaec 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc @@ -22,6 +22,7 @@ limitations under the License. #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" namespace mlir { @@ -57,6 +58,7 @@ tensorflow::Status RunTPUBridge( ModuleOp module, bool enable_logging, llvm::function_ref pipeline_builder) { PassManager bridge(module.getContext()); + ::tensorflow::applyTensorflowAndCLOptions(bridge); if (enable_logging) EnableLogging(&bridge); // Populate a passmanager with the list of passes that implement the bridge. @@ -98,18 +100,20 @@ void CreateTPUBridgePipeline(OpPassManager &pm) { // Run another shape inference pass because resource decomposition might have // created new partial types. 
pm.addPass(TF::CreateTFShapeInferencePass()); - pm.addPass(TFDevice::CreateResourceOpLiftingPass()); pm.addPass(TF::CreateTFFunctionalControlFlowToRegions()); pm.addPass(mlir::createInlinerPass()); + pm.addPass(CreateTPUClusterCleanupAttributesPass()); + pm.addPass(TFDevice::CreateResourceOpLiftingPass()); pm.addPass(TFDevice::CreateMarkOpsForOutsideCompilationPass()); pm.addPass(CreateTPUExtractHeadTailOutsideCompilationPass()); + pm.addPass(CreateTPUOutsideCompilationClusterPass()); pm.addPass(CreateTPUExtractOutsideCompilationPass()); - pm.addPass(TF::CreateTFRegionControlFlowToFunctional()); pm.addNestedPass(tf_executor::CreateTFExecutorConstantSinkingPass()); pm.addPass(TF::CreateResourceDeviceInferencePass()); pm.addPass(TFDevice::CreateClusterOutliningPass()); pm.addPass(CreateTPUDynamicPaddingMapperPass()); + pm.addPass(CreateTPUResourceReadForWritePass()); pm.addPass(CreateTPUShardingIdentificationPass()); pm.addPass(TFDevice::CreateAnnotateParameterReplicationPass()); pm.addPass(CreateTPURewritePass()); @@ -117,7 +121,9 @@ void CreateTPUBridgePipeline(OpPassManager &pm) { pm.addNestedPass(TFDevice::CreateReplicateInvariantOpHoistingPass()); pm.addNestedPass(CreateTPUDynamicLayoutPass()); pm.addNestedPass(CreateTPUMergeVariablesWithExecutePass()); + pm.addNestedPass(CreateTPUColocateCompositeResourceOps()); pm.addPass(CreateTPUVariableReformattingPass()); + pm.addPass(TF::CreateTFRegionControlFlowToFunctional()); } void CreateTPUBridgePipelineV1(OpPassManager &pm) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc index 2b8ab85be38..e85058a1964 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc @@ -39,6 +39,10 @@ namespace { struct ClusterFormationPass : public PassWrapper { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + void runOnFunction() override; }; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc b/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc index 57a5cd888a1..cde07503e75 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc @@ -181,14 +181,14 @@ llvm::Optional GetElementTypeFromAccess( llvm::function_ref(Operation*)> infer_from_op) { for (auto& use : collection.getUses()) { if (auto while_op = llvm::dyn_cast(use.getOwner())) { - auto body = while_op.body_func(); + auto body = while_op.body_function(); assert(body); auto type_from_body = GetElementTypeFromAccess( body.getArgument(use.getOperandNumber()), module, infer_from_op); if (type_from_body.hasValue()) return type_from_body; } else if (auto if_op = llvm::dyn_cast(use.getOwner())) { - auto then_branch = if_op.then_func(); - auto else_branch = if_op.else_func(); + auto then_branch = if_op.then_function(); + auto else_branch = if_op.else_function(); assert(then_branch && else_branch); auto type_from_then = GetElementTypeFromAccess( then_branch.getArgument(use.getOperandNumber() - 1), module, diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc index 3005c78c54f..31cfc5ebf9c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc 
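For reference, the ConvertTFBatchMatMulToEinsumOp pattern moved into batchmatmul_to_einsum.cc earlier in this patch derives its einsum equation by swapping characters of "...mk,...kn->...mn" in place when adj_x / adj_y are set. A minimal standalone C++ sketch of that derivation (illustrative only, not part of the patch):

// Standalone sketch: how the einsum equation used by the BatchMatMul-to-Einsum
// rewrite changes when adj_x / adj_y are set.  Indices 3/4 cover "mk" on the
// LHS operand, indices 9/10 (i.e. 6+3/6+4) cover "kn" on the RHS operand,
// mirroring the std::swap calls in the pattern above.
#include <iostream>
#include <string>
#include <utility>

std::string BatchMatMulEquation(bool adj_x, bool adj_y) {
  std::string equation("...mk,...kn->...mn");
  if (adj_x) std::swap(equation[3], equation[4]);          // lhs transposed: "...km"
  if (adj_y) std::swap(equation[6 + 3], equation[6 + 4]);  // rhs transposed: "...nk"
  return equation;
}

int main() {
  std::cout << BatchMatMulEquation(false, false) << "\n";  // ...mk,...kn->...mn
  std::cout << BatchMatMulEquation(true, false) << "\n";   // ...km,...kn->...mn
  std::cout << BatchMatMulEquation(false, true) << "\n";   // ...mk,...nk->...mn
  std::cout << BatchMatMulEquation(true, true) << "\n";    // ...km,...nk->...mn
  return 0;
}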
@@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/eval_util.h" #include "tensorflow/core/platform/mutex.h" @@ -72,7 +73,8 @@ LogicalResult ConstantFoldFallbackHook( SmallVectorImpl& results) { // NOLINT // Instructions with side effects should not be constant folded to preserve // the original semantics. - if (inst->getNumRegions() != 0 || !MemoryEffectOpInterface::hasNoEffect(inst)) + if (inst->hasTrait() || + inst->getNumRegions() != 0 || !MemoryEffectOpInterface::hasNoEffect(inst)) return failure(); // If any of the result types are variants, don't try to constant fold them. @@ -87,7 +89,7 @@ LogicalResult ConstantFoldFallbackHook( } // Do not execute function calls. - if (llvm::isa(inst)) { + if (llvm::isa(inst)) { return failure(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/contraction_fusion.cc b/tensorflow/compiler/mlir/tensorflow/transforms/contraction_fusion.cc new file mode 100644 index 00000000000..b5d09f7a794 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/contraction_fusion.cc @@ -0,0 +1,162 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/UseDefLists.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace TF { +namespace { + +// -------------------------------------------------------------------------- // +// Fuse ContractionFusableInterface operations into contraction operation. +// -------------------------------------------------------------------------- // + +template +class FuseIntoContractionOp : public RewritePattern { + public: + FuseIntoContractionOp() + : RewritePattern(PatternBenefit(1), MatchAnyOpTypeTag()) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + auto fusable = dyn_cast(op); + if (!fusable) return failure(); + + auto failed = [&](Twine message) -> LogicalResult { + return rewriter.notifyMatchFailure(op, message); + }; + + // Check if the operation can be fused. + Optional fusion = fusable.GetContractionFusion(); + if (!fusion.hasValue()) { + return failed("returned empty contraction fusion specification"); + } + + // Check if preceeding operation is a BaseOp or FusedOp that we can use for + // fusion. 
+ Operation *fuse_into = nullptr; + Value operand = op->getOperand(0); + + if (BaseOp base_op = operand.getDefiningOp()) { + fuse_into = base_op.getOperation(); + } else if (FusedOp fused_op = operand.getDefiningOp()) { + fuse_into = fused_op.getOperation(); + } else { + return failed("input to the fusable op must be a " + + BaseOp::getOperationName() + " or a " + + FusedOp::getOperationName()); + } + + // Operand result must have one use, because we do not want to compute + // tensor contraction twice. + if (!fuse_into->getResult(0).hasOneUse()) { + return failed("fused into op result must have one use"); + } + + MLIRContext *ctx = op->getContext(); + + // Build a fused MatMul operation from a base MatMul and a fusion. + SmallVector locations = {fuse_into->getLoc(), op->getLoc()}; + Location loc = rewriter.getFusedLoc(locations); + + // Fusion can't change the type of a fused operation. + Type result_ty = fuse_into->getResult(0).getType(); + + // Copy all operands from a base op and add additional fusion arguments. + SmallVector operands(fuse_into->getOperands()); + for (int idx : fusion->additional_arguments) { + operands.push_back(op->getOperand(idx)); + } + + // Copy attributes from a base op that we fuse into (e.g. copy all + // MatMul or Conv attributes to the fused operation). + SmallVector attrs(fuse_into->getAttrs().begin(), + fuse_into->getAttrs().end()); + + // Add fusion specific additional attributes. + for (auto attr : fusion->additional_attributes) { + attrs.push_back(attr); + } + + // Add a fused output kernel name to the list of fusions. + Identifier fusion_id = Identifier::get("fusion", ctx); + StringAttr fusion_name = StringAttr::get(fusion->output_kernel, ctx); + + auto is_fusion = [&](const NamedAttribute &attr) -> bool { + return attr.first == fusion_id; + }; + + if (isa(fuse_into)) { + NamedAttribute fusion_attr(fusion_id, ArrayAttr::get({fusion_name}, ctx)); + attrs.push_back(fusion_attr); + + } else { + ArrayAttr arr = + llvm::find_if(attrs, is_fusion)->second.template cast(); + llvm::erase_if(attrs, is_fusion); + + auto rng = arr.getAsRange(); + SmallVector updated(rng.begin(), rng.end()); + updated.push_back(fusion_name); + + attrs.push_back(NamedAttribute(fusion_id, ArrayAttr::get(updated, ctx))); + } + + // Update all uses of a fusable op with a new fused operation. 
+ Value fused = rewriter.create(loc, result_ty, operands, attrs); + rewriter.replaceOp(op, {fused}); + + return failure(); + } +}; + +// -------------------------------------------------------------------------- // + +using FuseIntoMatMulOp = FuseIntoContractionOp; + +struct ContractionFusionPass + : public PassWrapper { + void runOnFunction() override; +}; + +void ContractionFusionPass::runOnFunction() { + FuncOp func = getFunction(); + + OwningRewritePatternList patterns; + patterns.insert(); + applyPatternsAndFoldGreedily(func, patterns); +} + +} // namespace + +std::unique_ptr> CreateContractionFusionPass() { + return std::make_unique(); +} + +static PassRegistration pass( + "tf-contraction-fusion", + "Fuses operations implementing ContractionFusionInterface into the " + "contraction operations"); + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc index 4737f44ae1e..28a5c583919 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc @@ -73,7 +73,7 @@ static Type GetResourceSubtype(Value resource) { void PopulateDecomposeResourceOpsPatterns(MLIRContext *context, OwningRewritePatternList *patterns) { - populateWithGenerated(context, patterns); + populateWithGenerated(context, *patterns); } } // namespace TF diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td index 40339cebd31..4ed0307e2ef 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td @@ -85,7 +85,7 @@ def DecomposeResourceApplyMomentumOpNonNesterov : $var_resource, $accum_resource, $lr, $grad, $momentum, BoolAttr:$_, ConstBoolAttrFalse:$use_nesterov ), - [(TF_AddOp:$accum_new + [(TF_AddV2Op:$accum_new (TF_MulOp (CreateTFReadVariableOp $src_op, $grad, $accum_resource), $momentum @@ -107,7 +107,7 @@ def DecomposeResourceApplyMomentumOpNesterov : $var_resource, $accum_resource, $lr, $grad, $momentum, BoolAttr:$_, ConstBoolAttrTrue:$use_nesterov ), - [(TF_AddOp:$accum_new + [(TF_AddV2Op:$accum_new (TF_MulOp (CreateTFReadVariableOp $src_op, $grad, $accum_resource), $momentum @@ -117,7 +117,7 @@ def DecomposeResourceApplyMomentumOpNesterov : (TF_AssignVariableOp $accum_resource, $accum_new), (TF_AssignSubVariableOp $var_resource, - (TF_AddOp + (TF_AddV2Op (TF_MulOp $grad, $lr), (TF_MulOp $accum_new, (TF_MulOp $momentum, $lr)) ) @@ -175,7 +175,7 @@ def DecomposeResourceApplyKerasMomentumOpNesterov : ] >; -// Pattern to Decompose ResourceApplyAdagrad. +// Pattern to Decompose ResourceApplyAdagradV2. // This decomposition is only correct inside XLA as it ignores use_locking // attribute. // accum <- accum + grad * grad @@ -201,6 +201,21 @@ def DecomposeResourceApplyAdagradV2 : ] >; +// ResourceApplyAdagrad op can be canonicalized to ResourceApplyAdagradV2 with +// zero epsilon and then decomposed using DecomposeResourceApplyAdagradV2 +// pattern. 
+def DecomposeResourceApplyAdagrad : + Pattern< + (TF_ResourceApplyAdagradOp $var_resource, $accum_resource, $lr, $grad, + $use_locking, $update_slots), + [ + (TF_ConstOp:$zero_epsilon (GetScalarOfType<0> $grad)), + (TF_ResourceApplyAdagradV2Op $var_resource, $accum_resource, $lr, + $zero_epsilon, $grad, $use_locking, $update_slots + ) + ]>; + + // Pattern to Decompose ResourceApplyAdam without Nesterov momentum. // This decomposition is only correct inside XLA as it ignores use_locking // attribute. @@ -342,7 +357,7 @@ def DecomposeResourceApplyCenteredRMSProp : ), [(TF_ConstOp:$one (GetScalarOfType<1> $grad)), (CreateTFReadVariableOp $src_op, $grad, $ms_resource), - (TF_AddOp:$ms_new + (TF_AddV2Op:$ms_new (TF_MulOp (TF_MulOp $grad, $grad), (TF_SubOp $one, $rho) @@ -354,7 +369,7 @@ def DecomposeResourceApplyCenteredRMSProp : ), (TF_AssignVariableOp $ms_resource, $ms_new), // mg = grad * (one - rho) + mg * rho; - (TF_AddOp:$mg_new + (TF_AddV2Op:$mg_new (TF_MulOp $grad, (TF_SubOp $one, $rho) @@ -366,7 +381,7 @@ def DecomposeResourceApplyCenteredRMSProp : ), (TF_AssignVariableOp $mg_resource, $mg_new), // mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon) - (TF_AddOp:$mom_new + (TF_AddV2Op:$mom_new (TF_MulOp $momentum, (CreateTFReadVariableOp $src_op, $grad, $mom_resource)), (TF_DivOp @@ -374,7 +389,7 @@ def DecomposeResourceApplyCenteredRMSProp : (TF_SqrtOp (TF_SubOp $ms_new, - (TF_AddOp + (TF_AddV2Op (TF_MulOp $mg_new, $mg_new @@ -390,3 +405,45 @@ def DecomposeResourceApplyCenteredRMSProp : (TF_AssignSubVariableOp $var_resource, $mom_new) ] >; + +// This decomposition is only correct inside XLA as it ignores use_locking +// attribute. +// ms <- rho * ms_{t-1} + (1-rho) * grad * grad +// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) +// var <- var - mom +def DecomposeResourceApplyRMSProp : + Pattern< + (TF_ResourceApplyRMSPropOp:$src_op + $var_resource, $ms_resource, $mom_resource, $lr, $rho, $momentum, $epsilon, + $grad, ConstBoolAttrFalse:$use_locking + ), + [(TF_ConstOp:$one (GetScalarOfType<1> $grad)), + (CreateTFReadVariableOp $src_op, $grad, $ms_resource), + // ms <- rho * ms_{t-1} + (1-rho) * grad * grad + (TF_AddV2Op:$ms_new + (TF_MulOp + (CreateTFReadVariableOp $src_op, $grad, $ms_resource), + $rho + ), + (TF_MulOp + (TF_SquareOp $grad), + (TF_SubOp $one, $rho) + ) + ), + (TF_AssignVariableOp $ms_resource, $ms_new), + // mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) + (TF_AddV2Op:$mom_new + (TF_MulOp $momentum, + (CreateTFReadVariableOp $src_op, $grad, $mom_resource)), + (TF_DivOp + (TF_MulOp $lr, $grad), + (TF_SqrtOp + (TF_AddV2Op $ms_new, $epsilon) + ) + ) + ), + (TF_AssignVariableOp $mom_resource, $mom_new), + // var <- var - mom + (TF_AssignSubVariableOp $var_resource, $mom_new) + ] + >; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc index 69dab58c3f5..c3d43c27ac5 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc @@ -16,12 +16,16 @@ limitations under the License. 
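For reference, the DecomposeResourceApplyRMSProp and DecomposeResourceApplyAdagrad patterns above implement the update rules spelled out in their comments. A minimal scalar sketch of those rules (illustrative only, not part of the patch; the AdagradV2 variable update shown here assumes the usual var <- var - lr * grad / (sqrt(accum) + epsilon) form, which with the injected zero epsilon reduces to the classic Adagrad step):

// Scalar sanity check of the decomposed optimizer updates.  Values are
// arbitrary and only serve to exercise the formulas.
#include <cmath>
#include <cstdio>

int main() {
  // RMSProp, matching DecomposeResourceApplyRMSProp:
  //   ms  <- rho * ms_{t-1} + (1 - rho) * grad * grad
  //   mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
  //   var <- var - mom
  double var = 1.0, ms = 0.5, mom = 0.1;
  const double lr = 0.01, rho = 0.9, momentum = 0.9, epsilon = 1e-7, grad = 2.0;
  ms = rho * ms + (1.0 - rho) * grad * grad;
  mom = momentum * mom + lr * grad / std::sqrt(ms + epsilon);
  var = var - mom;
  std::printf("rmsprop: ms=%f mom=%f var=%f\n", ms, mom, var);

  // Adagrad expressed as AdagradV2 with a zero epsilon, matching
  // DecomposeResourceApplyAdagrad (assumed AdagradV2 update form):
  //   accum <- accum + grad * grad
  //   var   <- var - lr * grad / (sqrt(accum) + 0)
  double accum = 0.5, var2 = 1.0;
  accum = accum + grad * grad;
  var2 = var2 - lr * grad / (std::sqrt(accum) + 0.0);
  std::printf("adagrad: accum=%f var=%f\n", accum, var2);
  return 0;
}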
#include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h" #include +#include #include #include #include "absl/memory/memory.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" @@ -43,130 +47,6 @@ namespace TF { namespace { -// All supported Einsum equations. -enum EinsumEquation { - BatchMatMul, - FourDMatrixDotProd, - ThreeDReshapeTail, - FourDBatchMatMul, - BroadcastMatMul, - ReduceSum, - TransposeMatMul, - BatchMatMulReducedDim, - TransposeReducedDim, - FourDReduceLast, - FourDTransposeAll, - UnsupportedEquation -}; - -// Tokens for parsing the given equation string. -enum EquationToken { - A, - B, - C, - D, - E, - COMMA, - ARROW, -}; -constexpr int kNumSupportedEquationVariables = 5; // A - E for now. - -bool tokenizeEquation(const llvm::StringRef& equation, - std::vector* tokens) { - std::map label_axis_mapping; - size_t index = 0; - int variable_count = 0; - llvm::Regex r("[[:alpha:]]"); - while (index < equation.size()) { - if (r.match(equation.substr(index, 1))) { - const char ltr = equation[index]; - auto itr = label_axis_mapping.find(ltr); - if (itr == label_axis_mapping.end() && - variable_count < kNumSupportedEquationVariables) { - label_axis_mapping[ltr] = EquationToken(variable_count); - tokens->push_back(EquationToken(variable_count)); - variable_count++; - } else if (itr != label_axis_mapping.end()) { - tokens->push_back(itr->second); - } else { - // Ran out of equation variables. - return false; - } - } else if (equation.substr(index, 1).contains(",")) { - tokens->push_back(COMMA); - } else if ((index < (equation.size() - 1)) && - (equation.substr(index, 2).contains("->"))) { - tokens->push_back(ARROW); - index++; - } else { - // Unallowed character encountered. 
- return false; - } - index++; - } - return true; -} - -EinsumEquation parseEquation(const std::vector& eqn) { - auto is_equal = [](const std::vector& eqn1, - const std::initializer_list& eqn2) { - return std::equal(eqn1.begin(), eqn1.end(), eqn2.begin(), eqn2.end()); - }; - // IJK,IKM->IJM - if (is_equal(eqn, {A, B, C, COMMA, A, C, D, ARROW, A, B, D})) { - return EinsumEquation::BatchMatMul; - } - // BFND,NDH->BFH - if (is_equal(eqn, {A, B, C, D, COMMA, C, D, E, ARROW, A, B, E})) { - return EinsumEquation::FourDMatrixDotProd; - } - // BFNH,BTNH->BNFT - if (is_equal(eqn, {A, B, C, D, COMMA, A, E, C, D, ARROW, A, C, B, E})) { - return EinsumEquation::FourDBatchMatMul; - } - // BFD,DNH->BFNH - if (is_equal(eqn, {A, B, C, COMMA, C, D, E, ARROW, A, B, D, E})) { - return EinsumEquation::ThreeDReshapeTail; - } - // BFH,HO->BFO - if (is_equal(eqn, {A, B, C, COMMA, C, D, ARROW, A, B, D})) { - return EinsumEquation::BroadcastMatMul; - } - // LBH,BL->BH - if (is_equal(eqn, {A, B, C, COMMA, B, A, ARROW, B, C})) { - return EinsumEquation::ReduceSum; - } - // LBH,BKL->BKH - if (is_equal(eqn, {A, B, C, COMMA, B, D, A, ARROW, B, D, C})) { - return EinsumEquation::TransposeMatMul; - } - // BIN,BINJ->BIJ - if (is_equal(eqn, {A, B, C, COMMA, A, B, C, D, ARROW, A, B, D})) { - return EinsumEquation::BatchMatMulReducedDim; - } - // BIJ,BINJ->BIN - if (is_equal(eqn, {A, B, C, COMMA, A, B, D, C, ARROW, A, B, D})) { - return EinsumEquation::TransposeReducedDim; - } - // ABCD,ADBE->ACBE - if (is_equal(eqn, {A, B, C, D, COMMA, A, D, B, E, ARROW, A, C, B, E})) { - return EinsumEquation::FourDReduceLast; - } - // ABCD,AECD->ACEB - if (is_equal(eqn, {A, B, C, D, COMMA, A, E, C, D, ARROW, A, C, E, B})) { - return EinsumEquation::FourDTransposeAll; - } - return EinsumEquation::UnsupportedEquation; -} - -EinsumEquation tokenizeAndParse(const llvm::StringRef& equation) { - std::vector tokens; - if (tokenizeEquation(equation, &tokens)) { - return parseEquation(tokens); - } - return EinsumEquation::UnsupportedEquation; -} - TF::TransposeOp createTransposeOp(Value value, Location loc, llvm::ArrayRef permutation, PatternRewriter* rewriter) { @@ -186,28 +66,6 @@ TF::TransposeOp createTransposeOp(Value value, Location loc, perm_op); } -TF::SumOp createSumOp(Value value, Location loc, - llvm::ArrayRef redux_axes, - PatternRewriter* rewriter) { - auto value_type = value.getType().cast(); - auto shape = value_type.getShape(); - auto redux_type = RankedTensorType::get( - {static_cast(redux_axes.size())}, rewriter->getIntegerType(32)); - auto redux_attr = DenseElementsAttr::get(redux_type, redux_axes); - auto redux_op = rewriter->create(loc, redux_type, redux_attr); - std::vector sum_shape(shape.size() - redux_axes.size()); - int count = 0; - for (int i = 0, end = shape.size(); i < end; ++i) { - if (std::find(redux_axes.begin(), redux_axes.end(), i) == - redux_axes.end()) { - sum_shape[count] = shape[i]; - count++; - } - } - auto sum_type = RankedTensorType::get(sum_shape, value_type.getElementType()); - return rewriter->create(loc, sum_type, value, redux_op); -} - TF::ReshapeOp createReshapeOp(Value value, ArrayRef shape, Type element_type, Location loc, PatternRewriter* rewriter) { @@ -222,241 +80,277 @@ TF::ReshapeOp createReshapeOp(Value value, ArrayRef shape, /*shape=*/shape_tensor); } +struct EinsumDimensionNumbers { + // Each field contains the list of dimensions appearing only in the specifed + // arguments of the einsum op with natural ordering. 
For example `rhs_out` + // contains the dimensions appearing in the RHS and the OUTPUT of the einsum + // but not in the LHS. + std::vector lhs; + std::vector rhs; + std::vector> lhs_rhs; + std::vector> lhs_out; + std::vector> rhs_out; + std::vector> lhs_rhs_out; +}; + +llvm::Optional> EquationToMap( + llvm::StringRef equation) { + llvm::SmallDenseMap map; + for (int64_t i = 0; i < equation.size(); ++i) { + if (!std::isalpha(equation[i])) { + // Unsupported character in the equation. + return llvm::None; + } + if (map.count(equation[i])) { + // Duplicate character in the equation. + return llvm::None; + } + map.try_emplace(equation[i], i); + } + return map; +} + +llvm::Optional GetEinsumDimensionNumbers( + llvm::StringRef equation) { + llvm::StringRef lhs_rhs; + llvm::StringRef out; + std::tie(lhs_rhs, out) = equation.split("->"); + if (lhs_rhs.empty() || out.empty()) return llvm::None; + + llvm::StringRef lhs; + llvm::StringRef rhs; + std::tie(lhs, rhs) = lhs_rhs.split(','); + if (lhs.empty() || rhs.empty()) return llvm::None; + + auto lhs_map_or = EquationToMap(lhs); + if (!lhs_map_or.hasValue()) return llvm::None; + auto lhs_map = lhs_map_or.getValue(); + + auto rhs_map_or = EquationToMap(rhs); + if (!rhs_map_or.hasValue()) return llvm::None; + auto rhs_map = rhs_map_or.getValue(); + + auto out_map_or = EquationToMap(out); + if (!out_map_or.hasValue()) return llvm::None; + auto out_map = out_map_or.getValue(); + + EinsumDimensionNumbers dnums; + for (int64_t i = 0, e = lhs.size(); i < e; ++i) { + auto rhs_index = rhs_map.find(lhs[i]); + auto out_index = out_map.find(lhs[i]); + if (rhs_index == rhs_map.end() && out_index == out_map.end()) { + dnums.lhs.emplace_back(i); + } else if (rhs_index == rhs_map.end()) { + dnums.lhs_out.emplace_back(i, out_index->second); + } else if (out_index == out_map.end()) { + dnums.lhs_rhs.emplace_back(i, rhs_index->second); + } else { + dnums.lhs_rhs_out.emplace_back(i, rhs_index->second, out_index->second); + } + } + for (int64_t i = 0, e = rhs.size(); i < e; ++i) { + auto lhs_index = lhs_map.find(rhs[i]); + auto out_index = out_map.find(rhs[i]); + if (lhs_index == lhs_map.end()) { + if (out_index == out_map.end()) { + dnums.rhs.emplace_back(i); + } else { + dnums.rhs_out.emplace_back(i, out_index->second); + } + } + } + for (int64_t i = 0, e = out.size(); i < e; ++i) { + auto lhs_index = lhs_map.find(out[i]); + auto rhs_index = rhs_map.find(out[i]); + if (lhs_index == lhs_map.end() && rhs_index == rhs_map.end()) { + // out only isn't supported + return llvm::None; + } + } + return dnums; +} + +std::vector inverseTransposeVector( + llvm::ArrayRef input, llvm::ArrayRef permutation) { + std::vector output(input.size()); + for (int64_t i = 0; i < input.size(); ++i) { + output[permutation[i]] = input[i]; + } + return output; +} + +// Computes the transpositions required to convert dnums to one supported by +// tf.BatchMatmulV2 and returns the new set of dimension numbers with them. 
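+// For illustration (assuming the equation "bij,bkj->bik"): 'b' lands in
+// lhs_rhs_out, 'i' in lhs_out, 'j' (the contracted label) in lhs_rhs and 'k'
+// in rhs_out, so only the RHS actually needs reordering, from [b, k, j] to
+// [b, j, k], before the operands can be fed to tf.BatchMatMulV2.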
+LogicalResult transposeForBatchMatmul( + const Location& loc, EinsumDimensionNumbers& dnums, Value* lhs, Value* rhs, + std::vector* out_inverse_transpose, PatternRewriter* rewriter) { + std::vector lhs_transpose; + std::vector rhs_transpose; + std::vector out_transpose; + lhs_transpose.reserve(dnums.lhs_rhs_out.size() + dnums.lhs_out.size() + + dnums.lhs_rhs.size()); + rhs_transpose.reserve(dnums.lhs_rhs_out.size() + dnums.rhs_out.size() + + dnums.lhs_rhs.size()); + out_transpose.reserve(dnums.lhs_rhs_out.size() + dnums.lhs_out.size() + + dnums.rhs_out.size()); + for (int64_t i = 0, e = dnums.lhs_rhs_out.size(); i < e; ++i) { + lhs_transpose.push_back(std::get<0>(dnums.lhs_rhs_out[i])); + rhs_transpose.push_back(std::get<1>(dnums.lhs_rhs_out[i])); + out_transpose.push_back(std::get<2>(dnums.lhs_rhs_out[i])); + dnums.lhs_rhs_out[i] = std::make_tuple(i, i, i); + } + + for (int64_t i = 0, e = dnums.lhs_out.size(); i < e; ++i) { + lhs_transpose.push_back(std::get<0>(dnums.lhs_out[i])); + out_transpose.push_back(std::get<1>(dnums.lhs_out[i])); + dnums.lhs_out[i] = + std::make_tuple(lhs_transpose.size() - 1, out_transpose.size() - 1); + } + for (int64_t i = 0, e = dnums.lhs_rhs.size(); i < e; ++i) { + lhs_transpose.push_back(std::get<0>(dnums.lhs_rhs[i])); + rhs_transpose.push_back(std::get<1>(dnums.lhs_rhs[i])); + dnums.lhs_rhs[i] = + std::make_tuple(lhs_transpose.size() - 1, rhs_transpose.size() - 1); + } + for (int64_t i = 0, e = dnums.rhs_out.size(); i < e; ++i) { + rhs_transpose.push_back(std::get<0>(dnums.rhs_out[i])); + out_transpose.push_back(std::get<1>(dnums.rhs_out[i])); + dnums.rhs_out[i] = + std::make_tuple(rhs_transpose.size() - 1, out_transpose.size() - 1); + } + + out_inverse_transpose->resize(out_transpose.size()); + for (int64_t i = 0, e = out_transpose.size(); i < e; ++i) { + out_inverse_transpose->at(out_transpose[i]) = i; + } + + *lhs = createTransposeOp(*lhs, loc, lhs_transpose, rewriter); + *rhs = createTransposeOp(*rhs, loc, rhs_transpose, rewriter); + return success(); +} + +// Reshapes LHS and RHS to have B0,...,Bn,L,C and B0,...,Bn,C,R shape +// respectively while assuming that the initial shape for them is +// B0,...,Bn,L0,...,Ln,C0,...,Cn and B0,...,Bn,C0,...,Cn,R0,...,Rn respectively. 
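+// For example (illustrative shapes): for "bxyc,bcz->bxyz" with operands of
+// shape [B, X, Y, C] and [B, C, Z], the LHS is reshaped to [B, X*Y, C], the
+// RHS stays [B, C, Z], and the resulting [B, X*Y, Z] matmul output is later
+// reshaped back to [B, X, Y, Z].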
+LogicalResult reshapeForBatchMatmul(const Location& loc, + EinsumDimensionNumbers& dnums, Value* lhs, + Value* rhs, std::vector* out_shape, + PatternRewriter* rewriter) { + RankedTensorType lhs_type = lhs->getType().cast(); + RankedTensorType rhs_type = rhs->getType().cast(); + + std::vector lhs_shape; + std::vector rhs_shape; + lhs_shape.reserve(dnums.lhs_rhs_out.size() + dnums.lhs_out.size() + 1); + rhs_shape.reserve(dnums.lhs_rhs_out.size() + 2); + for (auto i : dnums.lhs_rhs_out) { + int64_t b = lhs_type.getShape()[std::get<0>(i)]; + lhs_shape.push_back(b); + rhs_shape.push_back(b); + out_shape->push_back(b); + } + + if (dnums.lhs_out.empty()) { + lhs_shape.push_back(1); + out_shape->push_back(1); + dnums.lhs_out.emplace_back(lhs_shape.size() - 1, out_shape->size() - 1); + } else if (dnums.lhs_rhs_out.empty()) { + for (auto i : dnums.lhs_out) { + int64_t b = lhs_type.getShape()[std::get<0>(i)]; + lhs_shape.push_back(b); + out_shape->push_back(b); + } + } else { + int64_t lhs_out_size = 1; + for (auto i : dnums.lhs_out) { + lhs_out_size *= lhs_type.getShape()[std::get<0>(i)]; + } + lhs_shape.push_back(lhs_out_size); + out_shape->push_back(lhs_out_size); + } + + int64_t lhs_rhs_size = 1; + for (auto i : dnums.lhs_rhs) { + lhs_rhs_size *= lhs_type.getShape()[std::get<0>(i)]; + } + lhs_shape.push_back(lhs_rhs_size); + rhs_shape.push_back(lhs_rhs_size); + + int64_t rhs_size = 1; + for (auto i : dnums.rhs_out) { + rhs_size *= rhs_type.getShape()[std::get<0>(i)]; + } + rhs_shape.push_back(rhs_size); + out_shape->push_back(rhs_size); + + *lhs = createReshapeOp(*lhs, lhs_shape, lhs_type.getElementType(), loc, + rewriter); + *rhs = createReshapeOp(*rhs, rhs_shape, rhs_type.getElementType(), loc, + rewriter); + + dnums.lhs_rhs.assign( + {std::make_tuple(dnums.lhs_rhs_out.size() + dnums.lhs_out.size(), + dnums.lhs_rhs_out.size())}); + dnums.rhs_out.assign( + {std::make_tuple(dnums.lhs_rhs_out.size() + dnums.lhs_out.size(), + dnums.lhs_rhs_out.size() + dnums.lhs_out.size())}); + return success(); +} + +LogicalResult rewriteToBatchMatmul(TF::EinsumOp op, + EinsumDimensionNumbers dnums, + PatternRewriter& rewriter) { + if (!dnums.lhs.empty() || !dnums.rhs.empty()) return failure(); + + auto inputs = op.inputs(); + if (inputs.size() != 2) return failure(); + Value lhs = inputs.front(); + Value rhs = inputs.back(); + + RankedTensorType original_type = + op.getResult().getType().dyn_cast_or_null(); + if (!original_type) return failure(); + + std::vector out_transpose; + if (failed(transposeForBatchMatmul(op.getLoc(), dnums, &lhs, &rhs, + &out_transpose, &rewriter))) + return failure(); + + std::vector matmul_shape; + if (failed(reshapeForBatchMatmul(op.getLoc(), dnums, &lhs, &rhs, + &matmul_shape, &rewriter))) + return failure(); + + auto matmul_type = + RankedTensorType::get(matmul_shape, original_type.getElementType()); + Value out = rewriter.create( + op.getLoc(), matmul_type, lhs, rhs, rewriter.getBoolAttr(false), + rewriter.getBoolAttr(false)); + + out = createReshapeOp( + out, inverseTransposeVector(original_type.getShape(), out_transpose), + original_type.getElementType(), op.getLoc(), &rewriter); + out = createTransposeOp(out, op.getLoc(), out_transpose, &rewriter); + + rewriter.replaceOp(op, out); + return success(); +} + } // namespace LogicalResult ConvertTFEinsumOp::matchAndRewrite( TF::EinsumOp op, PatternRewriter& rewriter) const { - Type output_type = op.getResult().getType(); - Value lhs = op.getOperand(0); - Value rhs = op.getOperand(1); - Location loc = op.getLoc(); - if 
(!lhs.getType().isa()) { - // LHS must be a ranked tensor type - return failure(); - } - if (!rhs.getType().isa()) { - // RHS must be a ranked tensor type - return failure(); - } + const auto dnums_or = GetEinsumDimensionNumbers(op.equation()); + if (!dnums_or.hasValue()) return failure(); + const auto& dnums = dnums_or.getValue(); - auto lhs_type = lhs.getType().cast(); - auto rhs_type = rhs.getType().cast(); - auto lhs_shape = lhs_type.getShape(); - auto rhs_shape = rhs_type.getShape(); + RankedTensorType lhs = + op.getOperand(0).getType().dyn_cast_or_null(); + RankedTensorType rhs = + op.getOperand(1).getType().dyn_cast_or_null(); + if (!lhs || !rhs) return failure(); - // Currently only support static shapes. - if (!(lhs_type.hasStaticShape() && rhs_type.hasStaticShape())) { - return failure(); - } - - // Currently support use cases of LHS dims \in {3,4} RHS dims \in {2, 3, 4} - const int dims_lhs = lhs_shape.size(); - const int dims_rhs = rhs_shape.size(); - if (dims_lhs < 3 || dims_lhs > 4 || dims_rhs < 2 || dims_rhs > 4) { - return failure(); - } - - EinsumEquation einsum_eqn = tokenizeAndParse(op.equation()); - if (einsum_eqn == EinsumEquation::BatchMatMul) { - // Case "IJK,IKM->IJM" - auto bmm_op = rewriter.create( - loc, ArrayRef{output_type}, lhs, rhs, rewriter.getBoolAttr(false), - rewriter.getBoolAttr(false)); - rewriter.replaceOp(op, bmm_op.getResult()); - return success(); - } - if (einsum_eqn == EinsumEquation::BroadcastMatMul) { - // Case "BFH,HO->BFO" - auto bmm_op = rewriter.create( - loc, ArrayRef{output_type}, lhs, rhs, rewriter.getBoolAttr(false), - rewriter.getBoolAttr(false)); - rewriter.replaceOp(op, bmm_op.getResult()); - return success(); - } - if (einsum_eqn == EinsumEquation::ReduceSum) { - // Case "LBH,BL->BH" - // Transpose LHS - lhs = createTransposeOp(lhs, loc, {1, 2, 0}, &rewriter); - // Reshape RHS - auto rhs_element_type = rhs_type.getElementType(); - const int rhs_dim0 = rhs_shape[0]; - const int rhs_dim1 = rhs_shape[1]; - auto reshaped_rhs = createReshapeOp(rhs, {rhs_dim0, 1, rhs_dim1}, - rhs_element_type, loc, &rewriter); - auto mul_op = rewriter.create(loc, lhs, reshaped_rhs); - - auto sum_op = createSumOp(mul_op, loc, {2}, &rewriter); - rewriter.replaceOp(op, {sum_op.getResult()}); - return success(); - } - if (einsum_eqn == EinsumEquation::TransposeMatMul) { - // Case "LBH,BKL->BKH" - // Transpose LHS - lhs = createTransposeOp(lhs, loc, {1, 2, 0}, &rewriter); - // Transpose RHS - rhs = createTransposeOp(rhs, loc, {0, 2, 1}, &rewriter); - std::vector bmm_shape = {lhs_shape[1], lhs_shape[2], rhs_shape[1]}; - auto bmm_type = RankedTensorType::get(bmm_shape, rhs_type.getElementType()); - auto bmm_op = rewriter.create( - loc, ArrayRef{bmm_type}, lhs, rhs, rewriter.getBoolAttr(false), - rewriter.getBoolAttr(false)); - - auto trans_bmm = createTransposeOp(bmm_op, loc, {0, 2, 1}, &rewriter); - rewriter.replaceOp(op, {trans_bmm.getResult()}); - return success(); - } - if (einsum_eqn == EinsumEquation::ThreeDReshapeTail) { - // Case "BFD,DNH->BFNH" - auto lhs_type = lhs.getType().cast(); - auto lhs_shape = lhs_type.getShape(); - const int lhs_dim0 = lhs_shape[0]; - const int lhs_dim1 = lhs_shape[1]; - // Reshape RHS - auto rhs_type = rhs.getType().cast(); - auto rhs_shape = rhs_type.getShape(); - auto rhs_element_type = rhs_type.getElementType(); - const int rhs_dim0 = rhs_shape[0]; - const int rhs_dim1 = rhs_shape[1]; - const int rhs_dim2 = rhs_shape[2]; - auto reshaped_rhs = createReshapeOp(rhs, {rhs_dim0, rhs_dim1 * rhs_dim2}, - rhs_element_type, loc, 
&rewriter); - - std::vector bmm_shape = {lhs_dim0, lhs_dim1, rhs_dim1 * rhs_dim2}; - auto bmm_type = RankedTensorType::get(bmm_shape, rhs_type.getElementType()); - auto bmm_op = rewriter.create( - loc, ArrayRef{bmm_type}, lhs, reshaped_rhs, - rewriter.getBoolAttr(false), rewriter.getBoolAttr(false)); - auto bmm_element_type = bmm_type.getElementType(); - auto final_reshape = - createReshapeOp(bmm_op, {lhs_dim0, lhs_dim1, rhs_dim1, rhs_dim2}, - bmm_element_type, loc, &rewriter); - rewriter.replaceOp(op, {final_reshape.getResult()}); - return success(); - } - if (einsum_eqn == EinsumEquation::FourDMatrixDotProd) { - // Case "BFND,NDH->BFH" - // Reshape LHS - auto lhs_element_type = lhs_type.getElementType(); - const int lhs_dim0 = lhs_shape[0]; - const int lhs_dim1 = lhs_shape[1]; - const int lhs_dim2 = lhs_shape[2]; - const int lhs_dim3 = lhs_shape[3]; - auto reshaped_lhs = - createReshapeOp(lhs, {lhs_dim0, lhs_dim1, lhs_dim2 * lhs_dim3}, - lhs_element_type, loc, &rewriter); - // Reshape RHS - auto rhs_element_type = rhs_type.getElementType(); - const int rhs_dim0 = rhs_shape[0]; - const int rhs_dim1 = rhs_shape[1]; - const int rhs_dim2 = rhs_shape[2]; - auto reshaped_rhs = createReshapeOp(rhs, {rhs_dim0 * rhs_dim1, rhs_dim2}, - rhs_element_type, loc, &rewriter); - auto bmm_op = rewriter.create( - loc, ArrayRef{output_type}, reshaped_lhs, reshaped_rhs, - rewriter.getBoolAttr(false), rewriter.getBoolAttr(false)); - rewriter.replaceOp(op, {bmm_op.getResult()}); - return success(); - } - if (einsum_eqn == EinsumEquation::FourDBatchMatMul) { - // Case "BFNH,BTNH->BNFT" - // Transpose LHS - lhs = createTransposeOp(lhs, loc, {0, 2, 1, 3}, &rewriter); - // Transpose RHS - rhs = createTransposeOp(rhs, loc, {0, 2, 3, 1}, &rewriter); - auto bmm_op = rewriter.create( - loc, ArrayRef{output_type}, lhs, rhs, rewriter.getBoolAttr(false), - rewriter.getBoolAttr(false)); - rewriter.replaceOp(op, {bmm_op.getResult()}); - return success(); - } - if (einsum_eqn == EinsumEquation::BatchMatMulReducedDim) { - // Case "BIN,BINJ->BIJ" - // Reshape LHS - auto lhs_element_type = lhs_type.getElementType(); - const int lhs_dim0 = lhs_shape[0]; - const int lhs_dim1 = lhs_shape[1]; - const int lhs_dim2 = lhs_shape[2]; - const int rhs_dim3 = rhs_shape[3]; - - auto reshaped_lhs = createReshapeOp(lhs, {lhs_dim0, lhs_dim1, 1, lhs_dim2}, - lhs_element_type, loc, &rewriter); - std::vector bmm_shape = {lhs_dim0, lhs_dim1, 1, rhs_dim3}; - auto bmm_type = RankedTensorType::get(bmm_shape, rhs_type.getElementType()); - auto bmm_op = rewriter.create( - loc, ArrayRef{bmm_type}, reshaped_lhs, rhs, - rewriter.getBoolAttr(false), rewriter.getBoolAttr(false)); - - auto bmm_element_type = bmm_type.getElementType(); - auto final_reshape = createReshapeOp(bmm_op, {lhs_dim0, lhs_dim1, rhs_dim3}, - bmm_element_type, loc, &rewriter); - rewriter.replaceOp(op, {final_reshape.getResult()}); - return success(); - } - if (einsum_eqn == EinsumEquation::TransposeReducedDim) { - // Case "BIJ,BINJ->BIN" - // Reshape LHS - auto lhs_element_type = lhs_type.getElementType(); - const int lhs_dim0 = lhs_shape[0]; - const int lhs_dim1 = lhs_shape[1]; - const int lhs_dim2 = lhs_shape[2]; - const int rhs_dim2 = rhs_shape[2]; - - auto reshaped_lhs = createReshapeOp(lhs, {lhs_dim0, lhs_dim1, 1, lhs_dim2}, - lhs_element_type, loc, &rewriter); - // Transpose RHS - rhs = createTransposeOp(rhs, loc, {0, 1, 3, 2}, &rewriter); - std::vector bmm_shape = {lhs_dim0, lhs_dim1, 1, rhs_dim2}; - auto bmm_type = RankedTensorType::get(bmm_shape, rhs_type.getElementType()); - 
auto bmm_op = rewriter.create( - loc, ArrayRef{bmm_type}, reshaped_lhs, rhs, - rewriter.getBoolAttr(false), rewriter.getBoolAttr(false)); - - auto bmm_element_type = bmm_type.getElementType(); - auto final_reshape = createReshapeOp(bmm_op, {lhs_dim0, lhs_dim1, rhs_dim2}, - bmm_element_type, loc, &rewriter); - rewriter.replaceOp(op, {final_reshape.getResult()}); - return success(); - } - if (einsum_eqn == EinsumEquation::FourDReduceLast) { - // Case "acbe,aecd->abcd" - const int lhs_dim2 = lhs_shape[2]; - const int rhs_dim0 = rhs_shape[0]; - const int rhs_dim2 = rhs_shape[2]; - const int rhs_dim3 = rhs_shape[3]; - // Transpose RHS - rhs = createTransposeOp(rhs, loc, {0, 2, 1, 3}, &rewriter); - std::vector bmm_shape = {rhs_dim0, rhs_dim2, lhs_dim2, rhs_dim3}; - auto bmm_type = RankedTensorType::get(bmm_shape, rhs_type.getElementType()); - auto bmm_op = rewriter.create( - loc, ArrayRef{bmm_type}, lhs, rhs, rewriter.getBoolAttr(false), - rewriter.getBoolAttr(false)); - - auto trans_bmm = createTransposeOp(bmm_op, loc, {0, 2, 1, 3}, &rewriter); - rewriter.replaceOp(op, {trans_bmm.getResult()}); - return success(); - } - if (einsum_eqn == EinsumEquation::FourDTransposeAll) { - // Case "aecd,abcd->acbe" - const int lhs_dim0 = lhs_shape[0]; - const int lhs_dim1 = lhs_shape[1]; - const int lhs_dim2 = lhs_shape[2]; - const int rhs_dim1 = rhs_shape[1]; - // Transpose LHS - lhs = createTransposeOp(lhs, loc, {0, 2, 1, 3}, &rewriter); - // Transpose RHS - rhs = createTransposeOp(rhs, loc, {0, 2, 3, 1}, &rewriter); - std::vector bmm_shape = {lhs_dim0, lhs_dim2, lhs_dim1, rhs_dim1}; - auto bmm_type = RankedTensorType::get(bmm_shape, rhs_type.getElementType()); - auto bmm_op = rewriter.create( - loc, ArrayRef{bmm_type}, lhs, rhs, rewriter.getBoolAttr(false), - rewriter.getBoolAttr(false)); - - auto trans_bmm = createTransposeOp(bmm_op, loc, {0, 1, 3, 2}, &rewriter); - rewriter.replaceOp(op, {trans_bmm.getResult()}); - return success(); - } - - return failure(); + return rewriteToBatchMatmul(op, dnums, rewriter); } // Transform Einsum to other TF Ops for the supported variants. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc index d8678e620f4..a5d76619416 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc @@ -157,14 +157,14 @@ static LogicalResult LowerIfOp(IfOp op) { // Set up the 'then' block. Block* then_block = builder.createBlock(merge_block); - Operation* call_op = CallFn(loc, get_operand, op.then_func(), &builder); + Operation* call_op = CallFn(loc, get_operand, op.then_function(), &builder); auto get_then_result = [&](int i) { return call_op->getResult(i); }; JumpToBlock(loc, get_then_result, merge_block, &builder); // Set up the 'else' block. Block* else_block = builder.createBlock(merge_block); - call_op = CallFn(loc, get_operand, op.else_func(), &builder); + call_op = CallFn(loc, get_operand, op.else_function(), &builder); auto get_else_result = [&](int i) { return call_op->getResult(i); }; JumpToBlock(loc, get_else_result, merge_block, &builder); @@ -190,8 +190,8 @@ static LogicalResult LowerWhileOp(WhileOp op) { OpBuilder builder(op_inst); - auto cond_fn = op.cond_func(); - auto body_fn = op.body_func(); + auto cond_fn = op.cond_function(); + auto body_fn = op.body_function(); // Split the block containing the While op into two blocks. 
One containing // operations before the While op and other containing the rest. Create two diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc index 11d74e87f96..87733bbbf3f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc @@ -98,10 +98,10 @@ LogicalResult ConvertIfOp(IfOp if_op) { if_op.getLoc(), if_op.getResultTypes(), cond, if_op.is_stateless()); CopyDeviceAndUnderscoredAttributes(if_op, if_region); - CreateCall(if_op, if_op.then_func(), + CreateCall(if_op, if_op.then_function(), /*caller_region=*/if_region.then_branch(), if_op.input(), /*use_region_args=*/false); - CreateCall(if_op, if_op.else_func(), + CreateCall(if_op, if_op.else_function(), /*caller_region=*/if_region.else_branch(), if_op.input(), /*use_region_args=*/false); if_op.replaceAllUsesWith(if_region.getResults()); @@ -116,14 +116,14 @@ LogicalResult ConvertWhileOp(WhileOp while_op) { CopyDeviceAndUnderscoredAttributes(while_op, while_region); YieldOp cond_yield = - CreateCall(while_op, while_op.cond_func(), + CreateCall(while_op, while_op.cond_function(), /*caller_region=*/while_region.cond(), while_op.input(), /*use_region_args=*/true); Value i1_cond = ConvertConditionToBoolean(cond_yield, cond_yield.getOperand(0)); cond_yield.setOperand(0, i1_cond); - CreateCall(while_op, while_op.body_func(), + CreateCall(while_op, while_op.body_function(), /*caller_region=*/while_region.body(), while_op.input(), /*use_region_args=*/true); while_op.replaceAllUsesWith(while_region.getResults()); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc index 7563f606434..a18d893fac7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc @@ -20,6 +20,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" namespace mlir { @@ -39,6 +40,7 @@ Status MlirGraphOptimizationPass::Run(const ConfigProto& config_proto, VLOG(1) << "Run MLIR Graph Optimization Passes"; PassManager pm(module.getContext()); + ::tensorflow::applyTensorflowAndCLOptions(pm); // Run island coarsening before shape inference to allow more exact shape // inference using constant folding within islands. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc b/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc index e76a8da0b29..8123f50757e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc @@ -14,6 +14,8 @@ limitations under the License. 
==============================================================================*/ #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/FormatVariadic.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Function.h" // from @llvm-project @@ -33,6 +35,34 @@ namespace mlir { namespace TF { namespace { +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc" + +// Helper method that returns an op from 'transpose_ops' that match criteria +// for an 'operand' and 'permutation' +TransposeOp ReuseExistingTranspose(const OpOperand* operand, + const SmallVector& permutation, + Operation* op, ConstOp permutation_op, + SmallVector* transpose_ops) { + for (auto it = transpose_ops->begin(); it != transpose_ops->end(); ++it) { + auto tranpose_op = *it; + for (auto tranpose_operand : tranpose_op.getOperands()) { + auto ranked_tranpose_type = + tranpose_operand.getType().dyn_cast_or_null(); + if (!ranked_tranpose_type) continue; + if (ranked_tranpose_type.getRank() == permutation.size() && + operand->get().getType() == + ShuffleRankedTensorType(ranked_tranpose_type, permutation)) { + TransposeOp transpose = tranpose_op; + transpose.getOperation()->moveBefore(op); + transpose.setOperand(0, operand->get()); + transpose.setOperand(1, permutation_op); + transpose_ops->erase(it); + return transpose; + } + } + } + return nullptr; +} // LayoutAssignmentPass assigns optimal data layout (data format) for all // layout sensitive operations. @@ -79,18 +109,7 @@ class MoveTransposesPass clEnumValN(Direction::kEnd, "end", "end of the block"))}; }; -using Permutation = SmallVector; - -Permutation GetDataFormatPermutation(StringRef from_data_format, - StringRef to_data_format) { - if (from_data_format == "NHWC" && to_data_format == "NCHW") { - return {0, 3, 1, 2}; - } else if (from_data_format == "NCHW" && to_data_format == "NHWC") { - return {0, 2, 3, 1}; - } else { - llvm_unreachable("Unknown data format combination"); - } -} +using Permutation = SmallVector; void LayoutAssignmentPass::runOnFunction() { FuncOp func = getFunction(); @@ -131,7 +150,7 @@ void LayoutAssignmentPass::runOnFunction() { OpBuilder builder = OpBuilder::atBlockEnd(op->getBlock()); auto perm_attr = [&](Permutation permutation) -> DenseIntElementsAttr { - auto perm_ty = RankedTensorType::get({4}, builder.getIntegerType(32)); + auto perm_ty = RankedTensorType::get({4}, builder.getIntegerType(64)); return DenseIntElementsAttr::get(perm_ty, permutation); }; @@ -202,6 +221,27 @@ void MoveTransposeBefore(Operation* op, SmallVector* work_list) { // Nothing to do here. if (!permutation_op || transpose_ops.empty()) return; + SmallVector permutation; + auto perm_attr = permutation_op.value().cast(); + for (const auto& value : perm_attr.getIntValues()) + permutation.push_back(value.getSExtValue()); + + // We want to make sure the shape of the operand equals the transposed shape. + // mismatch can happen if 'op' supports broadcasting and the operands have + // different ranks. 
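+ // For example (illustrative): a broadcasting tf.AddV2 between a transposed
+ // 4-D activation and a rank-1 bias is left alone, because hoisting the
+ // transpose would also require permuting the lower-rank operand; the move is
+ // only attempted when every operand rank matches the transpose result rank.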
+ if (op->hasTrait()) { + auto transpose_op = *transpose_ops.begin(); + auto result_type = + transpose_op.getResult().getType().dyn_cast_or_null(); + auto is_valid_move = + llvm::all_of(op->getOperands(), [result_type](Value operand) -> bool { + auto operand_type = operand.getType().dyn_cast_or_null(); + return result_type && operand_type && result_type.hasRank() && + operand_type.hasRank() && + result_type.getRank() == operand_type.getRank(); + }); + if (!is_valid_move) return; + } // At this point we checked that we can safely move Transpose node before // `op`, and bypass all result transposes. @@ -228,16 +268,12 @@ void MoveTransposeBefore(Operation* op, SmallVector* work_list) { work_list->push_back(operand_op); // Try to reuse result transposes. - TransposeOp transpose; - if (!transpose_ops.empty()) { - transpose = transpose_ops.pop_back_val(); - transpose.getOperation()->moveBefore(op); - transpose.setOperand(0, operand.get()); - transpose.setOperand(1, permutation_op); - } else { + TransposeOp transpose = ReuseExistingTranspose( + &operand, permutation, op, permutation_op, &transpose_ops); + // If no transpose available for using, create new one. + if (!transpose) transpose = builder.create(loc, operand.get(), permutation_op); - } operand.set(transpose); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc index ad241ef9488..8ab348c1e5b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc @@ -88,7 +88,7 @@ class ConvertConvOp : public OpConversionPattern { const int input_channels = conv_op.lhs().getType().cast().getDimSize( input_feature_dimension); - int feature_group_count = conv_op.feature_group_count().getSExtValue(); + int feature_group_count = conv_op.feature_group_count(); const bool is_depthwise_conv = input_channels == feature_group_count; std::string padding; @@ -250,7 +250,7 @@ class ConvertSliceOp : public OpConversionPattern { strides.getSplatValue().cast().getInt() != 1) return failure(); - rewriter.setInsertionPointAfter(slice_op); + rewriter.setInsertionPointAfter(slice_op.getOperation()); auto start_indices = slice_op.start_indices(); auto limit_indices = slice_op.limit_indices(); std::vector size_values; @@ -614,7 +614,65 @@ class ConvertReduceOpToTfMin : public OpConversionPattern { }; }; +class ConvertIotaOpToTfRange : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::IotaOp iota_op, ArrayRef args, + ConversionPatternRewriter &rewriter) const final { + RankedTensorType type = + iota_op.getType().dyn_cast_or_null(); + if (!type) return failure(); + + const uint64_t dimension = iota_op.iota_dimension(); + Type element_type = type.getElementType(); + Attribute start, limit, delta; + if (element_type.isa()) { + start = rewriter.getFloatAttr(element_type, 0.0); + limit = rewriter.getFloatAttr(element_type, type.getShape()[dimension]); + delta = rewriter.getFloatAttr(element_type, 1.0); + } else if (element_type.isa()) { + start = rewriter.getIntegerAttr(element_type, 0); + limit = rewriter.getIntegerAttr(element_type, type.getShape()[dimension]); + delta = rewriter.getIntegerAttr(element_type, 1); + } else { + return failure(); + } + + auto range_type = + RankedTensorType::get({type.getShape()[dimension]}, element_type); + Value start_op = rewriter.create(iota_op.getLoc(), start); + Value limit_op = 
rewriter.create(iota_op.getLoc(), limit); + Value delta_op = rewriter.create(iota_op.getLoc(), delta); + Value result = rewriter.create(iota_op.getLoc(), range_type, + start_op, limit_op, delta_op); + + if (type.getRank() > 1) { + std::vector reshape_shape(type.getRank(), 1); + reshape_shape[iota_op.iota_dimension()] = type.getShape()[dimension]; + auto reshape_type = RankedTensorType::get(reshape_shape, element_type); + Value reshape_shape_op = rewriter.create( + iota_op.getLoc(), rewriter.getI64TensorAttr(reshape_shape)); + result = rewriter.create(iota_op.getLoc(), reshape_type, + result, reshape_shape_op); + + Value broadcast_shape_op = rewriter.create( + iota_op.getLoc(), rewriter.getI64TensorAttr(type.getShape())); + result = rewriter.create(iota_op.getLoc(), type, + result, broadcast_shape_op); + } + + rewriter.replaceOp(iota_op, result); + return success(); + } +}; + class LegalizeHloToTf : public PassWrapper { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + public: LegalizeHloToTf() = default; LegalizeHloToTf(const LegalizeHloToTf &) {} @@ -763,9 +821,10 @@ static PassRegistration pass( void PopulateLegalizeHloToTfPatterns(OwningRewritePatternList *patterns, MLIRContext *context) { - populateWithGenerated(context, patterns); + populateWithGenerated(context, *patterns); patterns->insert(context); + ConvertReduceOpToTfMin, ConvertReduceOpToTfSum, + ConvertIotaOpToTfRange>(context); } std::unique_ptr> CreateLegalizeHloToTfPass() { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc index 6686b340be9..6c1e6a827c7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc @@ -176,7 +176,48 @@ LogicalResult LiftVariables(ModuleOp module, Session* session) { if (resource_names.empty()) return success(); - return LiftVariablesFromSession(module, session, resource_names); + if (failed(LiftVariablesFromSession(module, session, resource_names))) + return failure(); + + // Now that we have all global tensors created, we set the corresponding + // bound_inputs' types correctly. + SymbolTable symbol_table(module); + for (auto func : module.getOps()) { + for (auto arg : func.getArguments()) { + unsigned arg_number = arg.getArgNumber(); + auto global_tensor = LookupBoundInputOfType( + func, arg_number, symbol_table); + if (!global_tensor) continue; + + auto arg_type = arg.getType().cast(); + assert(arg_type.getRank() == 0); + llvm::ArrayRef underlying_type = + arg_type.getElementType().cast().getSubtypes(); + + // If the arg type already matches the global_tensor type, we don't need + // to do anything. + if (!underlying_type.empty() && + underlying_type[0] == global_tensor.type()) { + assert(underlying_type.size() == 1); + continue; + } + + // Otherwise, set this argument's type to the global_tensor's type. + auto new_arg_type = mlir::RankedTensorType::get( + /*shape=*/{}, + mlir::TF::ResourceType::get( + /*subtypes=*/{global_tensor.type().cast()}, + module.getContext())); + + arg.setType(new_arg_type); + } + + // Update the function type. 
+ func.setType(mlir::FunctionType::get(func.getArgumentTypes(), + func.getType().getResults(), + module.getContext())); + } + return success(); } } // namespace tf_saved_model diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc index 6946dc65104..a462f967bef 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc @@ -56,6 +56,14 @@ static DenseIntElementsAttr GetI64ElementsAttrForSeq(int start, int end, return DenseIntElementsAttr::get(ty, vals); } +static APFloat ConvertToAPFloat(double val, Type type) { + if (type.getIntOrFloatBitWidth() == 32) { + return APFloat(static_cast(val)); + } + + return APFloat(val); +} + // Returns int, float, or complex DenseElementsAttr with scalar shape with the // given element type and the integer value. static DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) { @@ -121,6 +129,17 @@ Type InferExpandDimsType(Type ty, int64_t axis, Builder *builder) { return RankedTensorType::get(shape, ranked_ty.getElementType()); } +// Converts individual Values to a tensor of rank 1. Each input Value has rank 1 +// and size 1. +Value ValuesToRank1(PatternRewriter &rewriter, Location loc, Type dtype, + ArrayRef vals) { + int64_t length = vals.size(); + auto type = RankedTensorType::get({length}, dtype); + auto axis = rewriter.create( + loc, GetScalarOfType(rewriter.getIntegerType(64), 0)); + return rewriter.create(loc, type, ValueRange(vals), axis); +} + // Lowers AddN op to a sequence of AddV2 ops to accumulate operands. // // Note that to improve the parallelism, AddN op uses tree-based reduction. @@ -160,34 +179,37 @@ Type InferExpandDimsType(Type ty, int64_t axis, Builder *builder) { // %sum2 = "tf.AddV2"(%sum0, %sum1) // %result = "tf.AddV2"(%sum2, %4) // -class LowerAddNOp : public OpRewritePattern { +class LowerAddNOp : public RewritePattern { public: explicit LowerAddNOp(MLIRContext *context) - : OpRewritePattern(context) {} + : RewritePattern(TF::AddNOp::getOperationName(), + {TF::AddV2Op::getOperationName()}, 1, context) {} - LogicalResult matchAndRewrite(TF::AddNOp op, + LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { + auto addn_op = cast(op); + // TODO(hinsu): Support variant with TensorList type. tf.AddV2 doesn't // support variant type so variant types require special handling. - if (getElementTypeOrSelf(op.getType()).isa()) return failure(); - - llvm::SmallVector operands(op.inputs().begin(), - op.inputs().end()); + if (getElementTypeOrSelf(addn_op.getType()).isa()) + return failure(); + llvm::SmallVector operands(addn_op.inputs().begin(), + addn_op.inputs().end()); int64_t n = operands.size(); // Keep doing tree-based reduction when there are more than one operand. while (n > 1) { for (int64_t i = 0; i < n; i += 2) { // Add two adjacent operands if applicable. - operands[i / 2] = (i + 1 < n) - ? rewriter.create( - op.getLoc(), operands[i], operands[i + 1]) - : operands[i]; + operands[i / 2] = + (i + 1 < n) ? 
rewriter.create( + addn_op.getLoc(), operands[i], operands[i + 1]) + : operands[i]; } n = (n + 1) / 2; } - rewriter.replaceOp(op, operands[0]); + rewriter.replaceOp(addn_op, operands[0]); return success(); } }; @@ -273,7 +295,7 @@ class LowerDynamicStitchOp : public OpRewritePattern { reshaped_data.getType().cast().getShape()[0]; auto items = rewriter.create( loc, SmallVector(num_items, item_ty), reshaped_data, - /*axis=*/APInt(64, 0)); + /*axis=*/0); for (auto index_item : llvm::zip(index_attr, items.getResults())) { int64_t output_index = std::get<0>(index_item).getSExtValue(); Value item = std::get<1>(index_item); @@ -287,6 +309,114 @@ class LowerDynamicStitchOp : public OpRewritePattern { } }; +// This pass performs a manual conversion with FakeQuant, converting between +// floating point and quantized space. It is designed to reproduce TF's +// implementation, mirroring the previous XLA implementation. +// +// 1. Computing proper quantized bounds. This involves nudging the input bounds. +// 2. Converting the input bounds to quantized space, rounding values. +// 3. Convert back into floating point space. +class ConvertFakeQuantWithMinMaxVarsOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::FakeQuantWithMinMaxVarsOp op, + PatternRewriter &rewriter) const override { + auto input = op.inputs(); + auto input_ty = input.getType().cast(); + auto element_ty = input_ty.getElementType(); + auto scalar_ty = RankedTensorType::get({}, element_ty); + + auto num_bits = op.num_bits(); + auto narrow_range = op.narrow_range(); + const double bits_min = narrow_range ? 1 : 0; + const double bits_max = (1 << num_bits) - 1; + + auto float_min = op.min(); + auto float_max = op.max(); + + auto float_diff = + rewriter.create(op.getLoc(), float_max, float_min); + + // Compute the range when quantized. + auto quant_min = rewriter.create( + op.getLoc(), DenseElementsAttr::get( + scalar_ty, ConvertToAPFloat(bits_min, element_ty))); + + auto quant_max = rewriter.create( + op.getLoc(), DenseElementsAttr::get( + scalar_ty, ConvertToAPFloat(bits_max, element_ty))); + + auto quant_diff = rewriter.create( + op.getLoc(), + DenseElementsAttr::get( + scalar_ty, ConvertToAPFloat(bits_max - bits_min, element_ty))); + + auto quant_to_float = + rewriter.create(op.getLoc(), float_diff, quant_diff); + + auto float_to_quant = + rewriter.create(op.getLoc(), quant_diff, float_diff); + + // During quantization, the quantized min/max values may not line up + // perfectly with the specified min/max. Nudge them into the right range. + auto min_scaled = + rewriter.create(op.getLoc(), float_min, quant_to_float); + auto min_scaled_sub = + rewriter.create(op.getLoc(), quant_min, min_scaled); + + auto mid_rounded = + rewriter.create(op.getLoc(), scalar_ty, min_scaled_sub); + + auto nudged_zero_point_val = rewriter.create( + op.getLoc(), scalar_ty, mid_rounded, quant_min, quant_max); + + auto quant_min_sub = rewriter.create(op.getLoc(), quant_min, + nudged_zero_point_val); + auto quant_max_sub = rewriter.create(op.getLoc(), quant_max, + nudged_zero_point_val); + + auto nudged_float_min = + rewriter.create(op.getLoc(), quant_min_sub, quant_to_float); + + auto nudged_float_max = + rewriter.create(op.getLoc(), quant_max_sub, quant_to_float); + + // Now quantize the input value with the approximated min/max values. 
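+ // Roughly, the steps below compute
+ //   q   = floor((clamp(x, nudged_min, nudged_max) - nudged_min) * float_to_quant + 0.5)
+ //   out = q * quant_to_float + nudged_min
+ // i.e. a round trip through the nudged quantized grid.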
+ + // Move the input value into quantized space + Value quantized_input = rewriter.create( + op.getLoc(), input_ty, input, nudged_float_min, nudged_float_max); + + quantized_input = rewriter.create( + op.getLoc(), input_ty, quantized_input, nudged_float_min); + + quantized_input = rewriter.create( + op.getLoc(), input_ty, quantized_input, float_to_quant); + + // Round the quantized input always to the positive direction. + auto half_val = rewriter.create( + op.getLoc(), + DenseElementsAttr::get(scalar_ty, ConvertToAPFloat(0.5, element_ty))); + + quantized_input = rewriter.create(op.getLoc(), input_ty, + quantized_input, half_val); + + quantized_input = + rewriter.create(op.getLoc(), quantized_input); + + // Convert back into floating point spae. + Value output = rewriter.create(op.getLoc(), input_ty, + quantized_input, quant_to_float); + + output = rewriter.create(op.getLoc(), input_ty, output, + nudged_float_min); + + rewriter.replaceOp(op, {output}); + return success(); + } +}; + // Lowers InvertPermutation op to TensorScatterUpdate op. // // Example: @@ -347,6 +477,210 @@ class LowerInvertPermutationOp } }; +// Approximates lgamma using Lanczos' approximation from +// "A Precision Approximation of the Gamma Function". SIAM Journal on Numerical +// Analysis series B. Vol. 1: +// lgamma(z + 1) = (log(2) + log(pi)) / 2 + (z + 1/2) * log(t(z)) - t(z) + A(z) +// t(z) = z + kLanczosGamma + 1/2 +// A(z) = kBaseLanczosCoeff +// + sigma(k = 1, n, kLanczosCoefficients[i] / (z + k)) +// +// Coefficients for the Lanczos approximation of the gamma function. The +// coefficients are uniquely determined by the choice of g and n +// (kLanczosGamma and kLanczosCoefficients.size() + 1). The coefficients below +// correspond to [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were +// evaluated and [7, 9] seemed to be the least sensitive to the quality of the +// log function. In particular, [5, 7] is the only choice where -1.5e-5 <= +// lgamma(2) <= 1.5e-5 for a particularly inaccurate log function. +static constexpr double kLanczosGamma = 7; // aka g +static constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478; +static constexpr std::array kLanczosCoefficients = { + 676.520368121885098567009190444019, -1259.13921672240287047156078755283, + 771.3234287776530788486528258894, -176.61502916214059906584551354, + 12.507343278686904814458936853, -0.13857109526572011689554707, + 9.984369578019570859563e-6, 1.50563273514931155834e-7}; + +class LowerLgammaOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::LgammaOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value input = op.x(); + TensorType original_tensor_type = op.x().getType().cast(); + + // The approximation is not precise enough for float16. Do the computation + // in float32 for that case. + TensorType tensor_type = original_tensor_type; + FloatType float_type = tensor_type.getElementType().cast(); + bool needs_cast = float_type.getWidth() < 32; + if (needs_cast) { + MLIRContext *context = rewriter.getContext(); + float_type = FloatType::getF32(context); + if (original_tensor_type.hasRank()) { + tensor_type = + RankedTensorType::get(original_tensor_type.getShape(), float_type); + } else { + tensor_type = UnrankedTensorType::get(float_type); + } + input = rewriter.create(loc, tensor_type, input); + } + + // Helper lambda function for creating a ConstOp for a tensor filled with + // the given constant float value. 
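+ // (Illustrative: create_const_op(0.5) below yields a tf.Const splat of
+ // `tensor_type` holding 0.5, which becomes the one_half constant.)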
+ auto create_const_op = [&rewriter, loc, tensor_type, + float_type](double value) { + return rewriter.create( + loc, DenseElementsAttr::get(tensor_type, + FloatAttr::get(float_type, value))); + }; + + Value one_half = create_const_op(0.5); + Value one = create_const_op(1.0); + Value infinity = create_const_op(std::numeric_limits::infinity()); + Value pi = create_const_op(M_PI); + Value log_pi = create_const_op(std::log(M_PI)); + Value log_sqrt_two_pi = create_const_op((std::log(2) + std::log(M_PI)) / 2); + Value lanczos_gamma_plus_one_half = create_const_op(kLanczosGamma + 0.5); + Value log_lanczos_gamma_plus_one_half = + create_const_op(std::log(kLanczosGamma + 0.5)); + Value base_lanczos_coeff = create_const_op(kBaseLanczosCoeff); + + Value minus_input = rewriter.create(loc, input); + Value input_minus_one = rewriter.create(loc, input, one); + + // If the input is less than 0.5 use Euler's reflection formula: + // gamma(x) = pi / (sin(pi * x) * gamma(1 - x)) + Value need_to_reflect = rewriter.create(loc, input, one_half); + Type tensor_bool_type = need_to_reflect.getType(); + Value z = rewriter.create(loc, need_to_reflect, minus_input, + input_minus_one); + + Value x = base_lanczos_coeff; + for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) { + Value lanczos_coefficient = create_const_op(kLanczosCoefficients[i]); + Value index = create_const_op(static_cast(i)); + Value z_plus_index = rewriter.create(loc, z, index); + Value z_plus_index_plus_one = + rewriter.create(loc, z_plus_index, one); + Value incr = rewriter.create(loc, lanczos_coefficient, + z_plus_index_plus_one); + x = rewriter.create(loc, x, incr); + } + + // To improve accuracy on platforms with less-precise log implementations, + // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on + // the device. + // log(t) = log(kLanczosGamma + 0.5 + z) + // = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5)) + Value t = rewriter.create(loc, lanczos_gamma_plus_one_half, z); + Value z_div_lanczos_gamma_plus_one_half = + rewriter.create(loc, z, lanczos_gamma_plus_one_half); + Value log1p_z_div_lanczos_gamma_plus_one_half = + rewriter.create(loc, z_div_lanczos_gamma_plus_one_half); + Value log_t = + rewriter.create(loc, log_lanczos_gamma_plus_one_half, + log1p_z_div_lanczos_gamma_plus_one_half); + + // Compute the final result (modulo reflection). t(z) may be large, and we + // need to be careful not to overflow to infinity in the first term of + // + // (z + 1/2) * log(t(z)) - t(z). + // + // Therefore we compute this as + // + // (z + 1/2 - t(z) / log(t(z))) * log(t(z)). + // + // log_y = log_sqrt_two_pi + (z + one_half - t / log_t) * log_t + Log(x); + Value t_div_log_t = rewriter.create(loc, t, log_t); + Value one_half_minus_t_div_log_t = + rewriter.create(loc, one_half, t_div_log_t); + Value z_plus_one_half_minus_t_div_log_t = + rewriter.create(loc, z, one_half_minus_t_div_log_t); + Value z_plus_one_half_minus_t_div_log_t_mul_log_t = + rewriter.create(loc, z_plus_one_half_minus_t_div_log_t, + log_t); + Value log_x = rewriter.create(loc, x); + Value log_y_rhs = rewriter.create( + loc, z_plus_one_half_minus_t_div_log_t_mul_log_t, log_x); + Value log_y = rewriter.create(loc, log_sqrt_two_pi, log_y_rhs); + + // Compute the reflected value, used when x < 0.5: + // + // lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))). + // + // (The abs is because lgamma is the log of the absolute value of the gamma + // function.) + // + // We have to be careful when computing the final term above. 
gamma(x) goes + // to +/-inf at every integer x < 0, and this is controlled by the + // sin(pi * x) term. The slope is large, so precision is particularly + // important. + // + // Because abs(sin(pi * x)) has period 1, we can equivalently use + // abs(sin(pi * frac(x))), where frac(x) is the fractional part of x. This + // is more numerically accurate: It doesn't overflow to inf like pi * x can, + // and if x is an integer, it evaluates to 0 exactly, which is significant + // because we then take the log of this value, and log(0) is inf. + // + // We don't have a frac(x) primitive in XLA and computing it is tricky, but + // because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for + // our purposes to use abs(frac(x)) = abs(x) - floor(abs(x)). + // + // Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is close + // to 1. To remedy this, we can use the fact that sin(pi * x) in the domain + // [0, 1] is symmetric across the line Y=0.5. + Value abs_input = rewriter.create(loc, input); + Value abs_input_floor = rewriter.create(loc, abs_input); + Value abs_frac_input = + rewriter.create(loc, abs_input, abs_input_floor); + + // Convert values of abs_frac_input > 0.5 to (1 - frac_input) to improve + // precision of pi * abs_frac_input for values of abs_frac_input close to 1. + Value one_minus_abs_frac_input = + rewriter.create(loc, one, abs_frac_input); + Value abs_frac_input_gt_one_half = + rewriter.create(loc, abs_frac_input, one_half); + Value reduced_frac_input = rewriter.create( + loc, abs_frac_input_gt_one_half, one_minus_abs_frac_input, + abs_frac_input); + Value pi_mul_reduced_frac_input = + rewriter.create(loc, pi, reduced_frac_input); + Value sin_pi_mul_reduced_frac_input = + rewriter.create(loc, pi_mul_reduced_frac_input); + Value reflection_denom = + rewriter.create(loc, sin_pi_mul_reduced_frac_input); + + // Avoid computing -inf - inf, which is nan. If reflection_denom is +/-inf, + // then it "wins" and the result is +/-inf. + Value is_finite = rewriter.create(loc, tensor_bool_type, + reflection_denom); + Value neg_reflection_denom = + rewriter.create(loc, reflection_denom); + Value log_pi_minus_reflection_denom = + rewriter.create(loc, log_pi, reflection_denom); + Value reflection_if_finite = + rewriter.create(loc, log_pi_minus_reflection_denom, log_y); + Value reflection = rewriter.create( + loc, is_finite, reflection_if_finite, neg_reflection_denom); + + Value result = rewriter.create(loc, need_to_reflect, + reflection, log_y); + + // lgamma(+/-inf) = +inf. + Value is_inf = rewriter.create(loc, tensor_bool_type, input); + result = rewriter.create(loc, is_inf, infinity, result); + + if (needs_cast) { + result = rewriter.create(loc, original_tensor_type, result); + } + + rewriter.replaceOp(op, result); + return success(); + } +}; + // Lowers Pack op to ConcatV2 op after changing shape of the inputs with // ExpandDims op. // @@ -369,7 +703,7 @@ class LowerPackOp : public OpRewritePattern { loc, DenseElementsAttr::get( RankedTensorType::get({}, rewriter.getIntegerType(64)), op.axis())); - int64_t axis = op.axis().getSExtValue(); + int64_t axis = op.axis(); Type prev_input_ty, inferred_ty; SmallVector expanded_inputs; @@ -393,6 +727,187 @@ class LowerPackOp : public OpRewritePattern { } }; +// Lowers SpaceToBatchND by reducing to reshape(transpose(reshape(pad(input)))). 
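+// (Concrete illustration: an input of shape [1, 4, 4, 1] with
+// block_shape = [2, 2] and zero paddings produces an output of shape
+// [4, 2, 2, 1], following the steps spelled out below.)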
+// +// Before rewrite: +// output = SpaceToBatchND(input, block_shape, paddings) +// Let: +// [batch] + spatial_shape + remaining_shape = input.shape +// M = spatial_shape.rank +// After rewrite: +// padded = zero-pad input with paddings +// The spatial_shape component of input.shape pads with paddings[*, 0] +// before each dimension, and paddings[*, 1] after each dimension. +// reshaped = reshape padded to: +// [batch] +// + [padded.shape[1]/block_shape[0], block_shape[0], ..., +// padded.shape[M]/block_shape[M-1], block_shape[M-1]] +// + remaining_shape +// permuted = transpose reshaped to: +// block_shape +// + [batch] +// + [padded.shape[1]/block_shape[0], ..., padded.shape[M]/block_shape[M-1]] +// + remaining_shape +// result = reshape permuted to: +// [batch * product(block_shape)] +// + [padded.shape[1]/block_shape[0], ..., padded.shape[M]/block_shape[M-1]] +// + remaining_shape +class LowerSpaceToBatchNDOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::SpaceToBatchNDOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto input_type = op.input().getType().cast(); + if (!input_type.hasStaticShape()) { + return failure(); + } + ArrayRef input_shape = input_type.getShape(); + auto block_shape_type = op.block_shape().getType().cast(); + if (!block_shape_type.hasStaticShape()) { + return failure(); + } + auto paddings_type = op.paddings().getType().cast(); + + int64_t input_rank = input_type.getRank(); + int64_t block_rank = block_shape_type.getNumElements(); + int64_t remaining_rank = input_rank - 1 - block_rank; + if (remaining_rank < 0) { + // TODO(b/157475606): Move this check to ::Verify + return failure(); + } + + auto block_shape_i64_type = RankedTensorType::get( + block_shape_type.getShape(), rewriter.getIntegerType(64)); + auto block_shape_i64 = rewriter.create( + loc, block_shape_i64_type, op.block_shape()); + + auto paddings_i64_type = RankedTensorType::get(paddings_type.getShape(), + rewriter.getIntegerType(64)); + auto paddings_i64 = + rewriter.create(loc, paddings_i64_type, op.paddings()); + + auto pad00 = rewriter.create( + loc, DenseElementsAttr::get( + RankedTensorType::get({1, 2}, rewriter.getIntegerType(64)), + {0, 0})); + SmallVector full_paddings_list{pad00, paddings_i64}; + full_paddings_list.append(remaining_rank, pad00); + auto full_paddings_type = + RankedTensorType::get({input_rank, 2}, rewriter.getIntegerType(64)); + auto zero_i64 = rewriter.create( + loc, GetScalarOfType(rewriter.getIntegerType(64), 0)); + // Extends paddings to all dimensions of input by adding 0s to non-block + // dimensions. 
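+ // E.g. (illustrative) for a rank-4 input and a rank-2 block_shape, the
+ // concatenation is [[0, 0], paddings[0], paddings[1], [0, 0]], giving
+ // full_paddings of shape [4, 2].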
+ auto full_paddings = rewriter.create( + loc, full_paddings_type, full_paddings_list, zero_i64); + + SmallVector padded_shape(input_rank, ShapedType::kDynamicSize); + auto padded_type = + RankedTensorType::get(padded_shape, rewriter.getF32Type()); + // padded = pad(input, full_paddings) + auto padded = + rewriter.create(loc, padded_type, op.input(), full_paddings); + + auto paddings_sum_type = + RankedTensorType::get({input_rank}, rewriter.getIntegerType(64)); + auto one_i64 = rewriter.create( + loc, GetScalarOfType(rewriter.getIntegerType(64), 1)); + // paddings_sum = paddings[*,0] + paddings[*,1] + auto paddings_sum = rewriter.create(loc, paddings_sum_type, + full_paddings, one_i64); + + // input_shape_tensor = input.shape + auto input_shape_tensor = rewriter.create( + loc, + DenseElementsAttr::get( + RankedTensorType::get({input_rank}, rewriter.getIntegerType(64)), + input_shape)); + + // padded_shape_tensor is the shape of padded. + auto padded_shape_tensor = + rewriter.create(loc, paddings_sum, input_shape_tensor); + + auto zero_i32 = rewriter.create( + loc, GetScalarOfType(rewriter.getIntegerType(32), 0)); + SmallVector padded_shape_splits_types( + input_rank, RankedTensorType::get({1}, rewriter.getIntegerType(64))); + SmallVector padded_shape_splits( + rewriter + .create(loc, padded_shape_splits_types, zero_i32, + padded_shape_tensor) + .output()); + + SmallVector block_shape_splits_types( + block_rank, RankedTensorType::get({1}, rewriter.getIntegerType(64))); + SmallVector block_shape_splits( + rewriter + .create(loc, block_shape_splits_types, zero_i32, + block_shape_i64) + .output()); + + SmallVector outer_shape_vals; + for (int64_t i = 0; i < block_rank; ++i) { + // TODO(b/157475606): Insert tf.Assert that the following division has + // remainder 0. 
+ outer_shape_vals.push_back(rewriter.create( + loc, padded_shape_splits[1 + i], block_shape_splits[i])); + } + + SmallVector reshaped_shape_vals{padded_shape_splits[0]}; + for (int64_t i = 0; i < block_rank; ++i) { + reshaped_shape_vals.push_back(outer_shape_vals[i]); + reshaped_shape_vals.push_back(block_shape_splits[i]); + } + for (int64_t i = 1 + block_rank; i < input_rank; ++i) { + reshaped_shape_vals.push_back(padded_shape_splits[i]); + } + auto reshaped_shape = ValuesToRank1( + rewriter, loc, rewriter.getIntegerType(64), reshaped_shape_vals); + + SmallVector permutation_vals; + for (int64_t i = 0; i < block_rank; ++i) { + permutation_vals.push_back(2 + 2 * i); + } + permutation_vals.push_back(0); + for (int64_t i = 0; i < block_rank; ++i) { + permutation_vals.push_back(1 + 2 * i); + } + for (int64_t i = 1 + block_rank; i < input_rank; ++i) { + permutation_vals.push_back(block_rank + i); + } + auto permutation = rewriter.create( + loc, GetI64ElementsAttr(permutation_vals, &rewriter)); + + auto output_batch = padded_shape_splits[0]; + for (int64_t i = 0; i < block_rank; ++i) { + output_batch = + rewriter.create(loc, output_batch, block_shape_splits[i]); + } + SmallVector output_shape_vals{output_batch}; + for (int64_t i = 0; i < block_rank; ++i) { + output_shape_vals.push_back(outer_shape_vals[i]); + } + for (int64_t i = 1 + block_rank; i < input_rank; ++i) { + output_shape_vals.push_back(padded_shape_splits[i]); + } + auto output_shape = ValuesToRank1( + rewriter, loc, rewriter.getIntegerType(64), output_shape_vals); + auto reshaped = rewriter.create(loc, padded, reshaped_shape); + auto permuted = + rewriter.create(loc, reshaped, permutation); + + // Sometimes the result type is more specific than what the reshape builder + // can infer. + auto result_type = op.getResult().getType(); + rewriter.replaceOpWithNewOp(op, result_type, permuted, + output_shape); + + return success(); + } +}; + // Lowers `TF::SparseMatMulOp` to `TF::MatMulOp`, ignoring the sparseness hints, // since we currently don't have an implementation that can use this // information. Adds appropriate casts where necessary to align element types @@ -447,8 +962,7 @@ class Lower_UnaryOpsComposition LogicalResult matchAndRewrite(TF::_UnaryOpsCompositionOp op, PatternRewriter &rewriter) const override { Value result = op.x(); - for (StringRef op_name : - op.op_names().getAsRange()) { + for (StringRef op_name : op.op_names().getAsValueRange()) { std::string full_name = "tf." + op_name.str(); // All ops in the sequences have the same result type as the original // result type. 
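// A compact, self-contained sketch of the shape and permutation bookkeeping
// that LowerSpaceToBatchNDOp above expresses with TF ops, assuming static
// shapes and hypothetical plain-integer containers (illustration only, not
// part of the pass):

#include <cstdint>
#include <vector>

struct SpaceToBatchShapes {
  std::vector<int64_t> reshaped;     // shape fed to the first reshape
  std::vector<int64_t> permutation;  // permutation fed to the transpose
  std::vector<int64_t> output;       // shape fed to the final reshape
};

static SpaceToBatchShapes ComputeSpaceToBatchShapes(
    const std::vector<int64_t>& padded_shape,   // input shape after zero-pad
    const std::vector<int64_t>& block_shape) {  // length M (block_rank)
  const int64_t input_rank = padded_shape.size();
  const int64_t block_rank = block_shape.size();
  SpaceToBatchShapes s;
  // reshaped = [batch, p1/b0, b0, ..., pM/b(M-1), b(M-1)] + remaining_shape.
  s.reshaped.push_back(padded_shape[0]);
  for (int64_t i = 0; i < block_rank; ++i) {
    s.reshaped.push_back(padded_shape[1 + i] / block_shape[i]);
    s.reshaped.push_back(block_shape[i]);
  }
  for (int64_t i = 1 + block_rank; i < input_rank; ++i)
    s.reshaped.push_back(padded_shape[i]);
  // permutation = block dims, then batch, then outer spatial dims, then rest.
  for (int64_t i = 0; i < block_rank; ++i) s.permutation.push_back(2 + 2 * i);
  s.permutation.push_back(0);
  for (int64_t i = 0; i < block_rank; ++i) s.permutation.push_back(1 + 2 * i);
  for (int64_t i = 1 + block_rank; i < input_rank; ++i)
    s.permutation.push_back(block_rank + i);
  // output = [batch * prod(block_shape)] + outer spatial dims + remaining.
  int64_t batch = padded_shape[0];
  for (int64_t b : block_shape) batch *= b;
  s.output.push_back(batch);
  for (int64_t i = 0; i < block_rank; ++i)
    s.output.push_back(padded_shape[1 + i] / block_shape[i]);
  for (int64_t i = 1 + block_rank; i < input_rank; ++i)
    s.output.push_back(padded_shape[i]);
  return s;
}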
@@ -466,10 +980,11 @@ class Lower_UnaryOpsComposition void PopulateLoweringTFPatterns(MLIRContext *context, OwningRewritePatternList *patterns) { - patterns->insert( - context); - populateWithGenerated(context, patterns); + patterns->insert(context); + populateWithGenerated(context, *patterns); } } // namespace TF diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td index f7a867f3130..bddc863ee60 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td @@ -24,6 +24,10 @@ class GetScalarOfType : NativeCodeCall< class GetScalarOfFloatType : NativeCodeCall< "GetScalarOfFloatType(getElementTypeOrSelf($0)," # value # ")">; +def GetScalarInfOfType : NativeCodeCall< + "GetScalarOfFloatType(getElementTypeOrSelf($0), " + "std::numeric_limits::infinity())">; + def GetScalarNanOfType : NativeCodeCall< "GetScalarOfFloatType(getElementTypeOrSelf($0), " "std::numeric_limits::quiet_NaN())">; @@ -154,13 +158,22 @@ foreach fromToBinPair = [[TF_DivNoNanOp, TF_DivOp], def LowerFillOp : Pat<(TF_FillOp $dims, $value), (TF_BroadcastToOp $value, $dims)>; +//===----------------------------------------------------------------------===// +// Inf op patterns. +//===----------------------------------------------------------------------===// + +def LowerIsInfOp : Pat<(TF_IsInfOp $x), + (TF_EqualOp (TF_AbsOp:$abs $x), + (TF_ConstOp:$inf (GetScalarInfOfType $x)), + /*incompatible_shape_error*/ConstBoolAttrTrue)>; + //===----------------------------------------------------------------------===// // NaN op patterns. //===----------------------------------------------------------------------===// def LowerIsNanOp : Pat<(TF_IsNanOp $x), - (TF_EqualOp $x, (TF_ConstOp:$nan (GetScalarNanOfType $x)), - /*incompatible_shape_error*/ConstBoolAttrTrue)>; + (TF_NotEqualOp $x, $x, + /*incompatible_shape_error*/ConstBoolAttrTrue)>; //===----------------------------------------------------------------------===// // L2Loss op patterns. @@ -198,6 +211,25 @@ def : Pat<(TF_PadOp TensorOf<[AnySignlessInteger, AnyFloat]>:$input, $paddings), def LowerReciprocal : Pat<(TF_ReciprocalOp $x), (TF_DivOp (TF_ConstOp (GetScalarOfType<1> $x)), $x)>; +//===----------------------------------------------------------------------===// +// Round op patterns. +//===----------------------------------------------------------------------===// + + +// Rounds on integers should just be bypassed. +def : Pat<(TF_RoundOp:$res TF_IntTensor:$input), (TF_IdentityOp $input)>; + +// Implements TF Round on floats using basic operations. +def : Pat<(TF_RoundOp:$res TF_FloatTensor:$input), + (TF_SelectOp + (TF_LessOp + (TF_SubOp $input, (TF_FloorOp:$floor $input)), + (TF_ConstOp (GetScalarOfFloatType<"0.5"> $input))), + $floor, + (TF_AddOp + (TF_ConstOp (GetScalarOfType<1> $input)), $floor))>; + + //===----------------------------------------------------------------------===// // Rsqrt op patterns. //===----------------------------------------------------------------------===// @@ -217,12 +249,22 @@ def : Pat<(TF_RsqrtGradOp $lhs, $rhs), // TODO(hinsu): Support complex input types. 
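// For reference, the scalar arithmetic that the IsInf, IsNan, and Round
// lowerings above expand to, sketched in plain C++ (illustration only; the
// names are hypothetical and this is not part of the pattern file):
//
//   #include <cmath>
//   #include <limits>
//
//   static bool LoweredIsInf(float x) {
//     return std::fabs(x) == std::numeric_limits<float>::infinity();
//   }
//   static bool LoweredIsNan(float x) { return x != x; }  // only NaN != NaN
//   static float LoweredRound(float x) {
//     float f = std::floor(x);
//     return (x - f < 0.5f) ? f : f + 1.0f;  // exact halves map to f + 1 here
//   }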
def LowerTanhGradOp : - Pat<(TF_TanhGradOp TF_FpTensor:$y, TF_FpTensor:$dy), + Pat<(TF_TanhGradOp TF_FloatTensor:$y, TF_FloatTensor:$dy), (TF_MulOp $dy, (TF_SubOp (TF_ConstOp (GetScalarOfType<1> $y)), (TF_SquareOp $y)))>; + //===----------------------------------------------------------------------===// +// LowerFakeQuantWithMinMaxArgs op patterns. +//===----------------------------------------------------------------------===// + +def LowerFakeQuantWithMinMaxArgs : + Pat<(TF_FakeQuantWithMinMaxArgsOp TF_FloatTensor: $input, + $min, $max, $bits, $narrow_range), + (TF_FakeQuantWithMinMaxVarsOp $input, + (TF_ConstOp $min), (TF_ConstOp $max), $bits, $narrow_range)>; + // ZerosLike op patterns. //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc index 4438f19bb74..ac844b925ce 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" +#include "tensorflow/core/lib/monitoring/gauge.h" namespace mlir { namespace TFDevice { @@ -37,6 +38,11 @@ namespace { constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation"; constexpr char kAllowSoftPlacementAttr[] = "allow_soft_placement"; +auto* auto_outside_compilation_gauge = + tensorflow::monitoring::Gauge::New( + "/tensorflow/core/use_auto_outside_compilation", + "Tracks if auto outside compilation is enabled"); + // This pass marks unsupported ops in a device cluster with // `_xla_outside_compilation` attribute so the operations will run on the host // instead of the device. Unsupported ops are ops that can not be code @@ -47,6 +53,15 @@ struct MarkOpsForOutsideCompilation void runOnOperation() override; }; +// Adds any canonicalization patterns to list of supported `patterns`. +// TODO(b/161726307): Move or import the relevant patterns to LowerTF pass and +// remove this. +void AddCanonicalizationPatterns(MLIRContext* context, + OwningRewritePatternList* patterns) { + for (auto* op : context->getRegisteredOperations()) + op->getCanonicalizationPatterns(*patterns, context); +} + // TODO(b/159128666): Check the control flow legalization passes instead once // added. void AddSupportedControlFlowOps(MLIRContext* context, @@ -68,16 +83,71 @@ void AddRewrittenEmbeddingOps(MLIRContext* context, TF::SendTPUEmbeddingGradientsOp::getOperationName(), context)); } +// Stack, TensorList and TensorArray ops are rewritten during the second phase +// of the bridge (compilation of TPUCompile op). They would not match any +// legalization/canonicalization pattern and have to be manually added to the +// list of supported ops. +void AddRewrittenCompositeOps(MLIRContext* context, + llvm::DenseSet* supported_ops) { +#define GET_OPERATION_NAME(op) OperationName(op::getOperationName(), context) + llvm::SmallDenseSet allowlist_ops = { + // Stack ops. + GET_OPERATION_NAME(TF::StackV2Op), + GET_OPERATION_NAME(TF::StackPushV2Op), + GET_OPERATION_NAME(TF::StackPopV2Op), + // Tensor Array ops. 
+ GET_OPERATION_NAME(TF::TensorArrayV3Op), + GET_OPERATION_NAME(TF::TensorArrayReadV3Op), + GET_OPERATION_NAME(TF::TensorArrayWriteV3Op), + GET_OPERATION_NAME(TF::TensorArrayConcatV3Op), + GET_OPERATION_NAME(TF::TensorArraySplitV3Op), + GET_OPERATION_NAME(TF::TensorArraySizeV3Op), + GET_OPERATION_NAME(TF::TensorArrayGradV3Op), + GET_OPERATION_NAME(TF::TensorArrayGatherV3Op), + GET_OPERATION_NAME(TF::TensorArrayScatterV3Op), + GET_OPERATION_NAME(TF::TensorListFromTensorOp), + // Tensor List Ops. + GET_OPERATION_NAME(TF::EmptyTensorListOp), + GET_OPERATION_NAME(TF::TensorListReserveOp), + GET_OPERATION_NAME(TF::TensorListFromTensorOp), + GET_OPERATION_NAME(TF::TensorListPushBackOp), + GET_OPERATION_NAME(TF::TensorListPopBackOp), + GET_OPERATION_NAME(TF::TensorListGetItemOp), + GET_OPERATION_NAME(TF::TensorListSetItemOp), + GET_OPERATION_NAME(TF::TensorListLengthOp), + GET_OPERATION_NAME(TF::TensorListElementShapeOp), + GET_OPERATION_NAME(TF::TensorListGatherOp), + GET_OPERATION_NAME(TF::TensorListScatterIntoExistingListOp), + }; +#undef GET_OPERATION_NAME + + supported_ops->insert(allowlist_ops.begin(), allowlist_ops.end()); +} + +bool IsStringType(Type type) { + if (type.isa()) return true; + + auto sub_type = type.dyn_cast(); + if (!sub_type) return false; + + bool has_string = llvm::any_of(sub_type.GetSubtypes(), [](TensorType type) { + return type.getElementType().isa(); + }); + return has_string; +} + bool HasStringOperand(Operation& op) { for (auto operand : op.getOperands()) { - if (getElementTypeOrSelf(operand).isa()) return true; + auto operand_type = getElementTypeOrSelf(operand); + if (IsStringType(operand_type)) return true; } return false; } bool HasStringResult(Operation& op) { for (auto result : op.getResults()) { - if (getElementTypeOrSelf(result).isa()) return true; + auto result_type = getElementTypeOrSelf(result); + if (IsStringType(result_type)) return true; } return false; } @@ -135,18 +205,10 @@ LogicalResult MarkUncompilableOps( op->getContext())); outside_compiled_cluster_counter++; } - if (llvm::isa(op)) { - if (HasCapturedStringOperand(op)) { - op->setAttr( - kXlaOutsideCompilationAttr, - StringAttr::get( - llvm::formatv("auto{0}", outside_compiled_cluster_counter) - .str(), - op->getContext())); - outside_compiled_cluster_counter++; - } - } }); + if (outside_compiled_cluster_counter > 0) { + auto_outside_compilation_gauge->GetCell()->Set(true); + } return success(); } @@ -179,6 +241,7 @@ void MarkOpsForOutsideCompilation::runOnOperation() { OwningRewritePatternList patterns; mhlo::PopulateLegalizeTfPatterns(module.getContext(), &patterns); TF::PopulateLoweringTFPatterns(module.getContext(), &patterns); + AddCanonicalizationPatterns(module.getContext(), &patterns); // `supported_ops` contains the name of all of the ops that can potentially be // lowered into HLO on the device. This doesn't always mean that the op can @@ -186,10 +249,12 @@ void MarkOpsForOutsideCompilation::runOnOperation() { // be lowered in a subsequent pass. 
llvm::DenseSet supported_ops; for (auto& pattern : patterns) { - supported_ops.insert(*pattern->getRootKind()); + Optional root_kind = pattern->getRootKind(); + if (root_kind.hasValue()) supported_ops.insert(root_kind.getValue()); } AddSupportedControlFlowOps(module.getContext(), &supported_ops); AddRewrittenEmbeddingOps(module.getContext(), &supported_ops); + AddRewrittenCompositeOps(module.getContext(), &supported_ops); auto result = module.walk([&](tf_device::ClusterOp cluster) { // Only if `allow_soft_placement` attribute is true should we mark ops diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc index 24e77d31e7c..29ecc38de0b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc @@ -37,7 +37,7 @@ struct TFOptimizePass : public PassWrapper { void runOnFunction() override { OwningRewritePatternList patterns; auto func = getFunction(); - populateWithGenerated(&getContext(), &patterns); + populateWithGenerated(&getContext(), patterns); applyPatternsAndFoldGreedily(func, patterns); } }; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc index 6fee693554e..b81e390580d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc @@ -109,13 +109,14 @@ class ResourceAnalyzer { return; } if (auto if_op = dyn_cast(op)) { - for (auto callee : {if_op.then_func(), if_op.else_func()}) { + for (auto callee : {if_op.then_function(), if_op.else_function()}) { PropagatePotentiallyWrittenUpFromCallee(callee, if_op.input()); } return; } if (auto while_op = dyn_cast(op)) { - for (auto callee : {while_op.cond_func(), while_op.body_func()}) { + for (auto callee : + {while_op.cond_function(), while_op.body_function()}) { PropagatePotentiallyWrittenUpFromCallee(callee, while_op.input()); } return; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc b/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc index 1332c8b6e59..86eea50d744 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc @@ -16,66 +16,62 @@ limitations under the License. // This pass forms `tf_executor.island` per region of // `tf_device.parallel_execute`. 
// -// For example: +// For example, the following: +// +// %0 = tf_executor.island { +// tf_executor.yield +// } // %1:2 = tf_executor.island { // %2 = "tf.opA"(%arg0) : (tensor) -> tensor // tf_executor.yield %2 : tensor // } -// tf_executor.island() { -// "tf_device.parallel_execute"() ({ -// %3 = "tf.opB"() : () -> tensor -// tf_device.return %3 : tensor -// }, -// { +// %3:2 = tf_executor.island(%0) { +// %4 = "tf_device.parallel_execute"() ( { +// %5 = "tf.opB"() : () -> tensor +// tf_device.return %5 : tensor +// }, { // %5 = "tf.opC"(%1#0) : (tensor) -> tensor // tf_device.return // }) {} : () -> (tensor) +// tf_executor.yield %4 : tensor +// } +// tf_executor.fetch %3#0 : tensor +// +// gets lowered to: +// +// %0 = tf_executor.island { // tf_executor.yield // } -// tf_executor.fetch +// %1:2 = tf_executor.island { +// %2 = "tf.opA"(%arg0) : (tensor) -> tensor +// tf_executor.yield %2 : tensor +// } // -// Would become: -// %1:2 = tf_executor.island { -// %2 = "tf.opA"(%arg0) : (tensor) -> tensor -// tf_executor.yield %2 : tensor -// } +// // Island for the first region of above parallel_execute. +// %3:2 = tf_executor.island(%0) { +// %4 = "tf.opB"() : () -> tensor +// tf_executor.yield %4 : tensor +// } // -// // Input barrier sink island that forwards all inputs. -// %output_0, %control_1 = tf_executor.island { -// tf_executor.yield %1#0: tensor -// } +// // Island for the second region of above parallel_execute. +// %5 = tf_executor.island(%0) { +// %6 = "tf.opC"(%1#0) : (tensor) -> tensor +// tf_executor.yield +// } // -// // Island for the first region of above parallel_execute. -// %output_2, %control_3 = tf_executor.island(%control_1) { -// %3 = "tf.opB"() : () -> tensor -// tf_executor.yield %3 : tensor -// } -// -// // Island for the second region of above parallel_execute. -// %control_5 = tf_executor.island { -// %5 = "tf.opC"(%output_0) : (tensor) -> tensor -// tf_executor.yield -// } -// -// // Output barrier sink island that forwards all outputs. -// %output_5, %control_6 = tf_executor.island(%control_5) { -// tf_executor.yield %output_2 -// } +// tf_executor.fetch %3#0, %5 : tensor, !tf_executor.control // // When tf_device.parallel_execute op is enclosed after tf_device.replicate, // then this pass will run following `replicate-to-island` pass and // `tf-executor-break-up-islands` pass. #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" @@ -89,175 +85,117 @@ struct ParallelExecuteToIslandsPass }; // Convert parallel_execute op to a set of islands where each region of -// parallel_execute op becomes a separate island. This ensures that -// regions of parallel_execute op gets executed concurrently. -LogicalResult ExpandParallelExecuteToIslands( - tf_executor::IslandOp island_op, tf_executor::IslandOp input_sink_island, +// parallel_execute op becomes a separate island. This ensures that the regions +// of the parallel_execute op gets executed concurrently. 
+void ExpandParallelExecuteToIslands( + tf_executor::IslandOp island_op, tf_device::ParallelExecuteOp parallel_execute_op, OpBuilder* builder, - llvm::SmallVector* islands) { - const int num_executions = - parallel_execute_op.getOperation()->getNumRegions(); - llvm::SmallVector executions; - executions.reserve(num_executions); - builder->setInsertionPoint(island_op); + llvm::SmallVectorImpl& executes) { + const int num_regions = parallel_execute_op.getOperation()->getNumRegions(); + executes.reserve(num_regions); - auto control_type = tf_executor::ControlType::get(island_op.getContext()); - for (int i : llvm::seq(0, num_executions)) { - auto execute_region = - parallel_execute_op.GetRegionBlockWithIndex(i).getParent(); + for (int i : llvm::seq(0, num_regions)) { + Block& execute_block = parallel_execute_op.GetRegionBlockWithIndex(i); - // If region does not have any inputs, then add explicit control dependency - // from the input sink island. This guarantees that all inputs of - // parallel_execute op must be materialized before any of the islands are - // executed. - llvm::SetVector region_inputs; - getUsedValuesDefinedAbove(*execute_region, region_inputs); - llvm::SmallVector execution_control_inputs; - if (region_inputs.empty() && input_sink_island) - execution_control_inputs.emplace_back(input_sink_island.control()); - - // Collect result types and operands. - Operation* terminator = execute_region->front().getTerminator(); - llvm::SmallVector output_types(terminator->getOperandTypes()); - - // Replace terminator with YieldOp as island op always ends with yield op. + // Replace terminator with tf_executor.YieldOp. + Operation* terminator = execute_block.getTerminator(); builder->setInsertionPoint(terminator); - builder->create(terminator->getLoc(), - terminator->getOperands()); + auto yield = builder->create( + terminator->getLoc(), terminator->getOperands()); terminator->erase(); // Create new island for each region. builder->setInsertionPoint(island_op); - auto execution_island = builder->create( - island_op.getLoc(), output_types, control_type, - execution_control_inputs); + auto execute_island = builder->create( + island_op.getLoc(), yield.getOperandTypes(), + island_op.control().getType(), island_op.controlInputs()); - // Move over tf_device.parallel_execute body region into newly a - // created island. - execution_island.body().takeBody(*execute_region); - islands->push_back(execution_island); + // Move over tf_device.parallel_execute body region into newly the created + // island. + execute_island.body().takeBody(*execute_block.getParent()); + executes.push_back(execute_island); } - - return success(); } -// Creates an island that works as input sync point for islands. This guarantees -// that all (implicitly captured) inputs of parallel_execute are materialized -// before any of the islands are executed. -tf_executor::IslandOp CreateInputBarrierIsland( - OpBuilder* builder, tf_executor::IslandOp island_op) { - builder->setInsertionPoint(island_op); - - llvm::SetVector all_inputs; - getUsedValuesDefinedAbove(island_op.body(), all_inputs); - - // Filter out values that are arguments and doesn't need to be part of the - // entry barrier. 
- llvm::SmallVector island_inputs; - llvm::SmallVector input_types; - island_inputs.reserve(all_inputs.size()); - input_types.reserve(all_inputs.size()); - for (Value val : all_inputs) { - if (!val.isa()) { - island_inputs.push_back(val); - input_types.push_back(val.getType()); - } - } - if (island_inputs.empty() && island_op.controlInputs().empty()) return {}; - - // Create new island for that forwards all inputs. - auto control_type = tf_executor::ControlType::get(island_op.getContext()); - auto input_sink_island = builder->create( - island_op.getLoc(), input_types, control_type, island_op.controlInputs()); - input_sink_island.body().push_back(new Block); - - for (auto input_index_and_value : llvm::enumerate(island_inputs)) { - int index = input_index_and_value.index(); - Value input_value = input_index_and_value.value(); - replaceAllUsesInRegionWith(input_value, input_sink_island.getResult(index), - island_op.body()); - } - - // Create YieldOp for the new input sink island. - builder->setInsertionPointToEnd(&input_sink_island.GetBody()); - builder->create(island_op.getLoc(), - llvm::to_vector<8>(island_inputs)); - return input_sink_island; -} - -// Creates an islands that works as output sync point. This guarantees that -// execution of all islands must be completed before op following -// parallel_execute runs. -tf_executor::IslandOp CreateOutputBarrierIsland( - OpBuilder* builder, tf_executor::IslandOp island_op, - llvm::SmallVectorImpl* islands) { - // Add control dependency to island operand if island output has no uses. - llvm::SmallVector island_operands; - for (auto& island : *islands) - if (island.use_empty()) island_operands.push_back(island.control()); - - // Create single island forwarding all island results. - builder->setInsertionPoint(island_op); - auto island_output_sink = builder->create( - island_op.getLoc(), llvm::to_vector<8>(island_op.getResultTypes()), - island_operands); - island_output_sink.body().push_back(new Block); - return island_output_sink; -} - -LogicalResult CreateIslandsFromParallelExecute( +void CreateIslandsFromParallelExecute( tf_executor::IslandOp island_op, tf_device::ParallelExecuteOp parallel_execute_op) { OpBuilder builder(island_op); - auto input_sink_island = CreateInputBarrierIsland(&builder, island_op); - // Create N islands where N is the number of regions inside parallel_execute - // op. - llvm::SmallVector islands; - auto result = ExpandParallelExecuteToIslands( - island_op, input_sink_island, parallel_execute_op, &builder, &islands); - if (failed(result)) return result; + // Create islands for each region of the parallel_execute op. + llvm::SmallVector executes; + ExpandParallelExecuteToIslands(island_op, parallel_execute_op, &builder, + executes); - // Remap all results of parallel_execute op with outputs from newly - // created islands. + // Remap all results of parallel_execute op with outputs from newly created + // islands. 
llvm::SmallVector parallel_execute_outputs; parallel_execute_outputs.reserve( parallel_execute_op.getOperation()->getNumResults()); - for (auto island : islands) - for (auto output_value : island.outputs()) - parallel_execute_outputs.emplace_back(output_value); + for (auto& execute : executes) + parallel_execute_outputs.append(execute.outputs().begin(), + execute.outputs().end()); - parallel_execute_op.getOperation()->replaceAllUsesWith( - parallel_execute_outputs); + for (auto result : llvm::zip(island_op.outputs(), parallel_execute_outputs)) + std::get<0>(result).replaceAllUsesWith(std::get<1>(result)); - auto island_output_sink = - CreateOutputBarrierIsland(&builder, island_op, &islands); + // Add sink island to pin all islands as a control dependency if there is a + // control dependency leading from the parallel_execute originally. + if (!island_op.control().use_empty()) { + llvm::SmallVector island_operands; + for (auto& execute : executes) island_operands.push_back(execute.control()); + + builder.setInsertionPoint(island_op); + auto island_sink = builder.create( + island_op.getLoc(), llvm::ArrayRef{}, + island_op.control().getType(), island_operands); + island_sink.body().push_back(new Block); + builder.setInsertionPointToEnd(&island_sink.GetBody()); + builder.create(island_op.getLoc(), + llvm::ArrayRef{}); + island_op.control().replaceAllUsesWith(island_sink.control()); + } + + // Islands with no uses should be pinned to a graph fetch so they still + // execute. + llvm::SmallVector unused_execute_controls; + for (auto& execute : executes) + if (execute.use_empty()) + unused_execute_controls.push_back(execute.control()); + + if (!unused_execute_controls.empty()) { + auto graph_op = island_op.getParentOfType(); + tf_executor::FetchOp fetch = graph_op.GetFetch(); + auto fetches = llvm::to_vector<8>(fetch.getOperands()); + fetches.append(unused_execute_controls.begin(), + unused_execute_controls.end()); + builder.setInsertionPoint(fetch); + builder.create(fetch.getLoc(), fetches); + fetch.erase(); + } - // Move island YieldOp over to new single island and remap island results. - island_op.GetYield().getOperation()->moveBefore( - &island_output_sink.GetBody(), island_output_sink.GetBody().begin()); - island_op.replaceAllUsesWith(island_output_sink); island_op.erase(); - - return success(); -} - -// Finds islands with a single `tf_device.parallel_execute` and create -// individual islands per region of parallel_execute. -void LowerSingleIslandParallelExecuteToIslands( - tf_executor::IslandOp island_op) { - if (!hasSingleElement(island_op.GetBody().without_terminator())) return; - - if (auto parallel_execute_op = llvm::dyn_cast( - &island_op.GetBody().front())) - CreateIslandsFromParallelExecute(island_op, parallel_execute_op); } void ParallelExecuteToIslandsPass::runOnFunction() { - getFunction().walk([&](tf_executor::IslandOp island_op) { - LowerSingleIslandParallelExecuteToIslands(island_op); + // Find islands with a single `tf_device.parallel_execute` and create + // individual islands per execute region of the parallel_execute. 
+ llvm::SmallVector parallel_execute_op_islands; + getFunction().walk([&](tf_executor::GraphOp graph_op) { + for (auto island_op : graph_op.getOps()) { + if (!island_op.WrapsSingleOp()) continue; + + if (isa(&island_op.GetBody().front())) + parallel_execute_op_islands.push_back(island_op); + } }); + + for (tf_executor::IslandOp island_op : parallel_execute_op_islands) { + auto parallel_execute_op = + cast(island_op.GetBody().front()); + CreateIslandsFromParallelExecute(island_op, parallel_execute_op); + } } } // anonymous namespace diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/parallelize_embedding_params_ops_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/parallelize_embedding_params_ops_pass.cc index 527af0934ea..352604955c0 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/parallelize_embedding_params_ops_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/parallelize_embedding_params_ops_pass.cc @@ -39,6 +39,10 @@ namespace { struct ParallelizeEmbeddingParamsOpsPass : public PassWrapper { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + void runOnFunction() override; }; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index d93d9ddccaf..a4ddb713ec0 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -167,6 +167,12 @@ void PopulateLegalizeHloToTfPatterns(OwningRewritePatternList* patterns, // future these fusions may be codegen'd automatically. std::unique_ptr> CreateFusedKernelMatcherPass(); +// Fuses operations defining `ContractionFusableInterface` interface into the +// contraction operations (MatMul, Conv2D, etc...). This is a more general +// version of `CreateFusedKernelMatcherPass` that relies on codegen to compose +// contraction fusions together. +std::unique_ptr> CreateContractionFusionPass(); + // Creates function pass to select device index/fold tf.DeviceIndex. std::unique_ptr> CreateDeviceIndexSelectorPass(); @@ -276,6 +282,11 @@ namespace TFTPU { // `_tpu_replicate` attribute. std::unique_ptr> CreateTPUClusterFormationPass(); +// Creates a pass that cleans up `_tpu_replicate` attribute on operations +// that are inside a cluster. +std::unique_ptr> +CreateTPUClusterCleanupAttributesPass(); + // Creates a pass that removes Identity/IdentityN ops from a cluster. std::unique_ptr> CreateTPUIdentityPruningPass(); @@ -287,6 +298,10 @@ std::unique_ptr> CreateTPUDynamicLayoutPass(); // `tf_device.launch_func` `padding_map` attribute to its encapsulated function. std::unique_ptr> CreateTPUDynamicPaddingMapperPass(); +// Creates a pass that adds `tf.ReadVariableOp` to a TPU cluster for resources +// the cluster only writes to. +std::unique_ptr> CreateTPUResourceReadForWritePass(); + // Creates a pass that rewrites `tf_device.launch_func` on TPUs into TPU runtime // ops. std::unique_ptr> CreateTPURewritePass(); @@ -295,18 +310,29 @@ std::unique_ptr> CreateTPURewritePass(); // computation. std::unique_ptr> CreateTPUShardingIdentificationPass(); +// Creates a pass that moves `tf.AssignVariableOp` into a +// `tf_device.parallel_execute` region if the `tf.AssignVariableOp` is the +// only consumer of a `tf_device.parallel_execute` result. +std::unique_ptr> +CreateTPUParallelExecuteSinkResourceWritePass(); + // Creates a pass that merges device variable reads/updates into the surrounded // TPUExecute node. 
This allows the execute node to perform in-place variable // updates. std::unique_ptr> CreateTPUMergeVariablesWithExecutePass(); +// Creates a pass that wraps ReadVariableOp/AssignVariable op that consumes a +// packed tensor to have same device placement as underlying TPU device. +std::unique_ptr> CreateTPUColocateCompositeResourceOps(); + // Creates a pass that adds ops which perform formatting on variables at // run-time according to compilation result. std::unique_ptr> CreateTPUVariableReformattingPass(); // Creates a pass that groups outside compiled operations (CPU ops inside TPU // cluster) into clusters that can be extracted and run on the CPU. -std::unique_ptr> CreateTPUOutsideCompilationClusterPass(); +std::unique_ptr> +CreateTPUOutsideCompilationClusterPass(); // Creates a pass that extracts outside compilation (CPU ops inside TPU cluster) // at head/tail of TPU cluster to run before/after TPU computation. @@ -329,6 +355,7 @@ std::unique_ptr> CreateTPUExtractOutsideCompilationPass(); // Populates the supplied passmanager with the passes required to run the +// bridge. void CreateTPUBridgePipeline(OpPassManager& pm); // Populates the supplied passmanager with the passes required to run the diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc index 031d57e99ba..96ff2890558 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc @@ -151,7 +151,7 @@ bool IsOpReplicateInvariant(Region* replicate_region, Operation* op) { // invariant. Shape ops are rewritten to be invariant when possible, prior to // hoisting ops. void HoistReplicateInvariantOps(tf_device::ReplicateOp replicate_op) { - const int num_replicas = replicate_op.n().getLimitedValue(); + const int num_replicas = replicate_op.n(); Block* replicate_block = &replicate_op.GetBody(); replicate_op.walk([&](TF::ShapeOp shape_op) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc index d99279c0014..5b70729ee80 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc @@ -376,7 +376,7 @@ LogicalResult CreateIslandsFromReplicate(const Dialect* tf_dialect, tf_executor::IslandOp island_op, tf_device::ReplicateOp replicate_op) { OpBuilder builder(island_op); - const int num_replicas = replicate_op.n().getLimitedValue(); + const int num_replicas = replicate_op.n(); // Create islands per replica. 
llvm::SmallVector replicas; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc index c1ca98bf1f1..648805febfe 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc @@ -283,12 +283,13 @@ void ResourceDeviceInference::runOnOperation() { if (auto while_op = dyn_cast(op)) { if (failed(propagate_operands_to_callee_arguments( while_op, while_op.getOperands(), - {while_op.body_func(), while_op.cond_func()}, func_res))) + {while_op.body_function(), while_op.cond_function()}, + func_res))) return WalkResult::interrupt(); } else if (auto if_op = dyn_cast(op)) { if (failed(propagate_operands_to_callee_arguments( - if_op, if_op.input(), {if_op.then_func(), if_op.else_func()}, - func_res))) + if_op, if_op.input(), + {if_op.then_function(), if_op.else_function()}, func_res))) return WalkResult::interrupt(); } else if (auto call = dyn_cast(op)) { auto func = dyn_cast(call.resolveCallable()); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc index 77f672f5ee4..c357abd10da 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc @@ -15,11 +15,15 @@ limitations under the License. // This pass lifts resource variable operations outside of device computation. +#include #include +#include "llvm/ADT/BitVector.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" @@ -32,20 +36,24 @@ limitations under the License. #include "mlir/IR/Function.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Verifier.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" @@ -137,14 +145,37 @@ struct ResourceOpLiftingPass void runOnOperation() override; }; -// Removes identity nodes in the block. 
The device computation does not need -// such nodes to carry information. -void RemoveIdentity(Block* block) { - for (auto& op : llvm::make_early_inc_range(*block)) { - if (isa(&op)) { - op.replaceAllUsesWith(op.getOperands()); - op.erase(); - } +bool IsResource(Value value) { + return getElementTypeOrSelf(value.getType()).isa(); +} + +// Get the type of the data contained in a resource. Returns null if there is +// no single type in the resource. +Type GetResourceSubtype(Value value) { + auto resource_type = + getElementTypeOrSelf(value.getType()).dyn_cast(); + auto subtypes = resource_type.getSubtypes(); + if (subtypes.size() == 1) return subtypes[0]; + return nullptr; +} + +// Replaces all `tf.VarIsInitializedOp` in a block with a constant true. +// TODO(b/171039585): Replace this with proper analysis of +// `tf.VarIsInitializedOp` in regards to resource writes and control flow. +void SetAllVarIsInitializedToTrue(Block* block) { + auto builder = OpBuilder::atBlockBegin(block); + TF::ConstOp const_true = nullptr; + for (auto op : + llvm::make_early_inc_range(block->getOps())) { + builder.setInsertionPoint(op); + if (!const_true) + const_true = builder.create( + op.getLoc(), + DenseIntElementsAttr::get( + RankedTensorType::get(/*shape=*/{}, builder.getI1Type()), true)); + + op.is_initialized().replaceAllUsesWith(const_true); + op.erase(); } } @@ -187,157 +218,447 @@ void ForwardStoreToLoad(Block* block) { } } -// Moves resource load operations with the provided `move_load` function. This -// assumes load-store forwarding has been performed on this block such that -// all loads of same resource are on its initial values. A `skip_load` functions -// is used to indicate whether a load should be skipped. If there are multiple -// loads on the same resource, only the first one will be moved, and the later -// ones will be removed and replaced with the first one. -void HoistResourceLoads( - Block* block, llvm::function_ref skip_load, - llvm::function_ref move_load) { - llvm::SmallDenseMap resource_to_read_ops; +//===----------------------------------------------------------------------===// +// RegionResourceHoister +//===----------------------------------------------------------------------===// +// Helper class to hoist resource ops out of regions attached to an op. +class RegionResourceHoister { + public: + explicit RegionResourceHoister(Operation* op) : op_(op) {} + + // Analyzes attached regions to record resources read and written. + LogicalResult Analyze(); + + // Returns all resources accessed by the regions attached the op. + auto& GetResources() { return resources_; } + + // Returns if the given value is a resouce that needs lifting. + bool Contains(Value resource) const { + return resources_.find(resource) != resources_.end(); + } + + // Drops the given resource from lifting. + void DropResource(Value resource) { + resources_.erase(resource); + written_resources_.remove(resource); + } + + // Replaces all resource loads in all regions attached to the op. + void ReplaceResourceLoads(bool read_only) { + llvm::for_each(op_->getRegions(), [&](Region& region) { + ReplaceResourceLoads(region, read_only); + }); + } + + static LogicalResult ReplaceOpWithNewOp(Operation* op); + + private: + // Returns if any resources need lifting. + bool NeedsLifting() const { return !resources_.empty(); } + + // Returns the number of results generated by the lifted op. + int GetLiftedNumResults() const { return num_new_results_; } + + // Generates hoisted reads for resources that need them before the op. 
+ void GenerateHoistedReads(); + + // Replaces all resource loads in the given region with hoisted loads. If + // `read_only` is true, limit this replacement to read only resources. + void ReplaceResourceLoads(Region& region, bool read_only); + + // Appends final values writte to resources to the region returns for the + // given set of regions. + void AppendResourceStoreValueToReturn(RegionRange regions); + + // Performs the final replacement of the op. + void ReplaceOpWithNewOp(); + + // Returns is this resource was written to in any of the regions. + bool IsWritten(Value resource) const { + return written_resources_.contains(resource); + } + + static LogicalResult HoistResourcesOutOfIfCaseCluster(Operation* op); + static LogicalResult HoistResourcesOutOfWhileRegion(TF::WhileRegionOp op); + + Operation* op_; + + // Per resource information about accesses to that resource. + struct ResourceInfo { + // Is this resource read in any of the regions? + bool is_read; + // Is this resource written in any of the regions? + bool is_written; + // Is this resource written in all of the regions? + bool is_written_all; + // The hoisted read used to replace region reads. + Value hoisted_read; + // the type of the data held by the resource. + Type data_type; + // For written resources, the result # of the lifted op which will hold the + // value of the resource. This result will be used to generates writes to + // the resource after the lifted op. + int result_index; + // Attributes on the read operation. + DictionaryAttr read_attrs; + // Attributes on the write operation. + DictionaryAttr write_attrs; + + ResourceInfo() + : is_read(false), + is_written(false), + is_written_all(false), + hoisted_read(nullptr), + data_type(nullptr), + result_index(-1) {} + + bool IsResultIndexAssigned() { return result_index != -1; } + + // Refine the resource type using the given type `type`. + void RefineType(Type type) { + if (!data_type) { + data_type = type; + } else { + data_type = TF::GetCastCompatibleType(data_type, type, + /*may_ignore_ref_type_a=*/false); + assert(data_type != nullptr && "Resource used with incompatible types"); + } + } + }; + llvm::MapVector resources_; + llvm::SetVector written_resources_; + // number of new results after lifting. + int num_new_results_; +}; + +// Analyzes resources that are read or written within attached regions. +LogicalResult RegionResourceHoister::Analyze() { + // Hoisting of child regions might have created opportunity for store-load + // forwarding. + for (Region& region : op_->getRegions()) { + ForwardStoreToLoad(®ion.front()); + } + + llvm::SetVector all_resources; + bool is_func = false; + // For functions, the resources to analyze are the function arguments. + // Otherwise, its the region captures. + if (FuncOp func = dyn_cast(op_)) { + is_func = true; + Region& body = func.getBody(); + for (BlockArgument arg : body.getArguments()) { + if (IsResource(arg)) all_resources.insert(arg); + } + } else { + getUsedValuesDefinedAbove(op_->getRegions(), all_resources); + all_resources.remove_if([](Value value) { return !IsResource(value); }); + } + + num_new_results_ = op_->getNumResults(); + + for (auto resource : all_resources) { + ResourceInfo info; + info.data_type = GetResourceSubtype(resource); + llvm::BitVector written_regions(op_->getNumRegions()); + bool unsupported_use = false; + for (OpOperand& use : resource.getUses()) { + Operation* user = use.getOwner(); + // If the user is not in one of the regions, we are not interested in it. 
+ // Since all the sub-regions within this region (i.e., regions attached to + // op's in this region) have themselves gone through lifting, all resource + // users are expected to be operations in this region and and not embedded + // within other sub-regions attached to op's in this region. So the check + // for whether a user is in one of the regions attached to this op is + // straightforward. + if (user->getParentRegion()->getParentOp() != op_) continue; + + // For functions, if the resource is used as a return operand, use that + // as its result index. + if (is_func && isa(user)) { + assert(!info.IsResultIndexAssigned() && + "Expect resource argument to returned no more than once"); + info.result_index = use.getOperandNumber(); + continue; + } + + auto read = dyn_cast(user); + auto write = dyn_cast(user); + if (!read && !write) { + unsupported_use = true; + break; + } + + if (read && !info.is_read) { + info.is_read = true; + info.RefineType(read.value().getType()); + info.read_attrs = user->getAttrDictionary(); + } + + if (write) { + info.is_written = true; + info.RefineType(write.value().getType()); + info.write_attrs = user->getAttrDictionary(); + written_regions.set(user->getParentRegion()->getRegionNumber()); + } + } + + // If the resource is used in an op that we do not understand, skip + // lifting for that resource. + if (unsupported_use) continue; + + info.is_written_all = written_regions.count() == op_->getNumRegions(); + + // If the resource is written in some but not all regions, we would need + // a read for the value before these regions. Note that this is applicable + // only to multi-region ops: + // If/Case: If not all regions write to the resource, post hoisting the read + // value need to be routed through all paths that don't write. + // While: since while condition cannot write, any resource written in the + // while body will need to be read as well in case the while body is never + // executed. + // Both cases are handled by the condition below. + if (info.is_written && !info.is_written_all) info.is_read = true; + + // Allocate a result index for written resources that don't have one. + if (info.is_written) { + written_resources_.insert(resource); + if (!info.IsResultIndexAssigned()) info.result_index = num_new_results_++; + } + + resources_.insert({resource, info}); + } + return success(); +} + +// Generates hoisted reads for all resources that need them just before the op. +void RegionResourceHoister::GenerateHoistedReads() { + OpBuilder builder(op_); + for (auto& resource_it : GetResources()) { + Value resource = resource_it.first; + auto& info = resource_it.second; + + if (info.is_read) { + Operation* read = builder.create( + op_->getLoc(), info.data_type, resource); + read->setAttrs(info.read_attrs); + info.hoisted_read = read->getResult(0); + } + } +} + +// Replaces all resource reads with the hoisted read. +void RegionResourceHoister::ReplaceResourceLoads(Region& region, + bool read_only) { + assert(llvm::hasSingleElement(region) && "Expected single block region"); // Only iterate through ops directly in the body as we can't handle // ops nested deeper in regions. 
- for (Operation& op : llvm::make_early_inc_range(*block)) { - auto read_variable_op = dyn_cast(&op); - if (!read_variable_op) continue; - if (skip_load(read_variable_op)) continue; + auto all_reads = region.front().getOps(); + for (auto read_op : llvm::make_early_inc_range(all_reads)) { + Value resource = read_op.resource(); + if (!Contains(resource)) continue; - Value resource = read_variable_op.resource(); - auto p = resource_to_read_ops.insert({resource, read_variable_op}); - if (p.second) { - move_load(read_variable_op); - continue; + ResourceInfo& info = resources_[resource]; + // If replacing loads for read only resources, skip if the resource + // was written to. + if (read_only && info.is_written) continue; + + read_op.replaceAllUsesWith(info.hoisted_read); + read_op.erase(); + } +} + +// For written resources, add its value at the end of each region to that +// regions return value. For a region, its value at the end may be a value +// written to that resource in that region, or its hoisted read value if the +// resource is not written in that region. The return value can be vended out +// either as an existing return value, or a newly allocated return value. +void RegionResourceHoister::AppendResourceStoreValueToReturn( + RegionRange regions) { + for (Region* region : regions) { + assert(llvm::hasSingleElement(*region) && "Expected single block region"); + Block& front = region->front(); + auto old_return = front.getTerminator(); + assert(old_return->getNumOperands() == op_->getNumResults()); + auto new_return_operands = llvm::to_vector<4>(old_return->getOperands()); + new_return_operands.resize(num_new_results_); + + // initialize return values for written resources to be the hosited reads. + for (Value resource : written_resources_) { + const ResourceInfo& info = resources_[resource]; + new_return_operands[info.result_index] = info.hoisted_read; } - // Getting here means a load operation of this resource has been hoisted out - // before. Use hoisted load result to replace all uses of current op result - // and erase op. - op.replaceAllUsesWith(p.first->second); - op.erase(); - } -} + // Only iterate through ops directly in the body as op's embedded in child + // regions should have been lifted out. + auto assign_ops = front.getOps(); + for (auto assign_variable_op : llvm::make_early_inc_range(assign_ops)) { + Value resource = assign_variable_op.resource(); + if (!IsWritten(resource)) continue; -// If there are any stores to resource defined outside of the block then the -// stored values must be returned so that new values can be used by sunk -// resource stores. -// Returns true if any resource variable stored values are appended, otherwise -// false. -bool AppendResourceStoreValueToReturn(Block* body) { - bool has_resource_store = false; - auto old_return = body->getTerminator(); - - llvm::SmallVector new_return_operands(old_return->getOperands()); - - // Only iterate through ops directly in the body as we can't handle ops nested - // deeper in regions. - for (auto assign_variable_op : body->getOps()) { - Value resource = assign_variable_op.resource(); - if (!resource) continue; - - // Skip resources created inside of the body. - if (resource.getParentRegion() == body->getParent()) continue; - - // TODO(ycao): Prevent same value from being returned multiple times. - // TODO(ycao): Do not return resource store value if it is defined outside - // of cluster. 
- new_return_operands.push_back(assign_variable_op.value()); - has_resource_store = true; - } - - // If no resource stores are found, no need to update return op. - if (!has_resource_store) return false; - - OpBuilder builder(old_return); - builder.create(old_return->getLoc(), - new_return_operands); - old_return->erase(); - return true; -} - -// Moves resource store operations to after cluster. This assumes load-store -// forwarding has been performed on this cluster such that there is at most one -// resource store operation carrying its final value. -tf_device::ClusterOp SinkResourceStores(tf_device::ClusterOp cluster, - OpBuilder* builder) { - // Update ReturnOp inside cluster's body to output final values of updated - // external resources. - if (!AppendResourceStoreValueToReturn(&cluster.GetBody())) return cluster; - - auto new_return_op = cluster.GetBody().getTerminator(); - llvm::SmallVector new_return_types(new_return_op->getOperandTypes()); - - builder->setInsertionPoint(cluster); - auto new_cluster = builder->create( - cluster.getLoc(), new_return_types, - /*operands=*/llvm::SmallVector(), cluster.getAttrs()); - new_cluster.body().takeBody(cluster.body()); - - // Replace uses of old cluster results with those of new_cluster. - for (auto result : llvm::zip(cluster.getResults(), new_cluster.getResults())) - std::get<0>(result).replaceAllUsesWith(std::get<1>(result)); - - // Create a mapping from operands of new_return_op operands to new_cluster - // results. - BlockAndValueMapping mapper; - for (auto operand_result : - llvm::zip(new_return_op->getOperands(), new_cluster.getResults())) - mapper.map(std::get<0>(operand_result), std::get<1>(operand_result)); - - // Clone all resource store ops and map their operands to values returned from - // new_cluster. - for (Operation& op : llvm::make_early_inc_range(new_cluster.GetBody())) { - if (isa(op)) { - builder->clone(op, mapper); - op.erase(); + // TODO(ycao): Prevent same value from being returned multiple times. + // TODO(ycao): Do not return resource store value if it is defined outside + // of cluster. Both of these can be post-resource-op-lifting cleanup + // passes. + int result_index = resources_[resource].result_index; + new_return_operands[result_index] = assign_variable_op.value(); + assign_variable_op.erase(); } + old_return->setOperands(new_return_operands); } - - cluster.erase(); - return new_cluster; } -// Hoists resource variable loads and sinks stores from cluster. -LogicalResult HoistResourceOpsFromCluster(tf_device::ClusterOp cluster, - ModuleOp module) { - OpBuilder builder(module); +// Replace the old op with a new op (with potentially additional results), and +// add stores to written resources after the new op. +void RegionResourceHoister::ReplaceOpWithNewOp() { + auto new_result_types = llvm::to_vector<4>(op_->getResultTypes()); + int result_region = isa(op_) ? 1 : 0; + Operation* terminator = op_->getRegion(result_region).front().getTerminator(); + auto extra_result_types = + terminator->getOperands().drop_front(op_->getNumResults()).getTypes(); + new_result_types.insert(new_result_types.end(), extra_result_types.begin(), + extra_result_types.end()); + OpBuilder builder(op_); + // Clone ths old operation but with new result types. + Operation* new_op = Operation::create( + op_->getLoc(), op_->getName(), new_result_types, op_->getOperands(), + op_->getAttrs(), op_->getSuccessors(), op_->getNumRegions()); + builder.insert(new_op); - // Remove identity nodes to avoid aliasing. 
- RemoveIdentity(&cluster.GetBody()); + // Move regions to the new op. + for (auto it : llvm::zip(op_->getRegions(), new_op->getRegions())) { + Region& old_region = std::get<0>(it); + Region& new_region = std::get<1>(it); + new_region.takeBody(old_region); + } - // Perform store-load forwarding. So that each resource is only loaded with - // its initial value and is only stored with its final value. - ForwardStoreToLoad(&cluster.GetBody()); + // Insert stores to all written resources. + for (Value resource : written_resources_) { + ResourceInfo& info = resources_[resource]; + Value value_to_write = new_op->getResult(info.result_index); + Operation* write = builder.create( + op_->getLoc(), resource, value_to_write); + write->setAttrs(info.write_attrs); + } - // Move loads of external resources, if any, to before cluster. - // (Skipping resources created inside of cluster.) - HoistResourceLoads( - &cluster.GetBody(), - /*skip_load=*/ - [&](TF::ReadVariableOp read) { - return read.resource().getParentRegion() == &cluster.body(); - }, - /*move_load=*/ - [&](TF::ReadVariableOp read) { - read.getOperation()->moveBefore(cluster); - }); + // As a part of lifting, we either reuse an existing slot for resource type + // results or add a new slot. Resource type results should not have any uses + // to begin with. So we can safely replace each old op result with the + // corresponding new op result. + int old_num_results = op_->getNumResults(); + op_->replaceAllUsesWith(new_op->getResults().take_front(old_num_results)); + op_->erase(); + op_ = nullptr; +} - // Move stores of external resources, if any, to after cluster. - auto new_cluster = SinkResourceStores(cluster, &builder); +// Lift resource load and stores out of regions attached to `op`, where op is +// an If/case/cluster op. +LogicalResult RegionResourceHoister::HoistResourcesOutOfIfCaseCluster( + Operation* op) { + RegionResourceHoister hoister(op); + if (failed(hoister.Analyze())) return failure(); - llvm::SetVector captured_values; - getUsedValuesDefinedAbove(new_cluster.body(), new_cluster.body(), - captured_values); + // If there are no resource region captures, then nothing to do. + if (!hoister.NeedsLifting()) return success(); + // Start the transformation. For each region, replace the resource read with + // the value read before the op. + hoister.GenerateHoistedReads(); + hoister.ReplaceResourceLoads(/*read_only=*/false); + hoister.AppendResourceStoreValueToReturn(op->getRegions()); + hoister.ReplaceOpWithNewOp(); return success(); } +// Lift resource loads and stores out of WhileRegion +LogicalResult RegionResourceHoister::HoistResourcesOutOfWhileRegion( + TF::WhileRegionOp op) { + // For WhileRegion, post canonicalization all resource used within the + // body and condition regions are replaced with captured values, so we do not + // need to take into account the body and condition region arguments. + RegionResourceHoister hoister(op); + + if (failed(hoister.Analyze())) return failure(); + + // If there are no resource region captures, then nothing to do. + if (!hoister.NeedsLifting()) return success(); + + // The resources captured for While loop fall into two categories: + // (a) read-only. These reads can be replaced by a hoisted read created + // before the WhileOp (similar to if and case). + // (b) written: since the value is written in the loop (which can only in + // loop body, all these will become loop variables. 
Since all resource + // variables are removed from the loop variabled during + // canonicalizationW, we need to create new operand/result slots. The + // input operands for these slots are the read values + // prior to the op, and all references to these are replaced by the + // corresponding slot argument. We need to generate writes following + // the while for these resources. + // + // Note that for WhileRegion ops, if a resource is written, it will be written + // only in the body and not the condition, so the hoister analysis will infer + // it as needing a read as well. + + // Generate hoisted reads before the while. + hoister.GenerateHoistedReads(); + + // Replace just the read-only resources with the hoisted reads. + hoister.ReplaceResourceLoads(/*read_only=*/true); + + // For written resources, add additional operands to the while op. + int num_old_results = op.getNumResults(); + int num_new_results = hoister.GetLiftedNumResults(); + int num_extra_results = num_new_results - num_old_results; + + SmallVector new_result_types; + SmallVector new_while_operands; + new_result_types.resize(num_extra_results); + new_while_operands.resize(num_extra_results); + + for (auto& it : hoister.GetResources()) { + if (!it.second.is_written) continue; + int index = it.second.result_index - num_old_results; + new_result_types[index] = it.second.data_type; + new_while_operands[index] = it.second.hoisted_read; + } + op.getOperation()->insertOperands(op.getNumOperands(), new_while_operands); + + // Patch the cond and body regions to have additional arguments, and replace + // the remaining resource reads (which will be resource reads for written + // resources) with these arguments. + for (Region* region : op.getRegions()) { + region->addArguments(new_result_types); + // Point hoisted read for written resources to the region's arguments. + for (auto& it : hoister.GetResources()) { + if (!it.second.is_written) continue; + it.second.hoisted_read = region->getArgument(it.second.result_index); + } + hoister.ReplaceResourceLoads(*region, /*read_only=*/false); + } + + // Add additional return values to body return. These correspond to values + // written to resources in the body region. + hoister.AppendResourceStoreValueToReturn(op.getRegions().drop_front()); + + // Finally, create a new while with additional return values. + hoister.ReplaceOpWithNewOp(); + return success(); +} + +// Lift resources out of the regions attached to `op` +LogicalResult RegionResourceHoister::ReplaceOpWithNewOp(Operation* op) { + if (auto while_op = dyn_cast(op)) + return HoistResourcesOutOfWhileRegion(while_op); + return HoistResourcesOutOfIfCaseCluster(op); +} + // Holds information about a function's use of a resource argument. struct ResourceArgUseInfo { + // Data type of the data contained in the resource. Type data_type; + // Is the resource argument used in an assign op? bool updated; + // Is the resource argument used in a read or assign op? 
bool used; }; @@ -348,12 +669,12 @@ struct ResourceArgUseInfo { LogicalResult FindResourceArgUseInfo( FuncOp func_op, llvm::SmallDenseMap* result) { auto return_op = func_op.front().getTerminator(); - for (auto arg : func_op.getArguments()) { - if (!getElementTypeOrSelf(arg.getType()).isa()) continue; + for (auto arg : TF::filter_resources(func_op.getArguments())) { ResourceArgUseInfo info; info.used = false; info.updated = false; bool read_or_assigned = false; + bool used_in_unsupported_op = false; for (auto user : arg.getUsers()) { if (user == return_op) continue; info.used = true; @@ -362,14 +683,21 @@ LogicalResult FindResourceArgUseInfo( info.data_type = read.getType(); continue; } + if (auto assign = llvm::dyn_cast(user)) { read_or_assigned = true; info.updated = true; info.data_type = assign.value().getType(); continue; } + + used_in_unsupported_op = true; + break; } - if (!info.used || read_or_assigned) (*result)[arg.getArgNumber()] = info; + + // If the arg is used in an unsupported op, skip lifting it. + if (used_in_unsupported_op) continue; + (*result)[arg.getArgNumber()] = info; } return success(); } @@ -455,59 +783,59 @@ void RemoveUnusedResourceArgumentsAndForwardedRetvals( // signature. resource_data_types is the (index, data type) pair for each // resource argument. handle_updated_arg_value is a caller-provided function // that handles the updated value for an resource argument. -void LiftArgRetResourcesForFunction( +LogicalResult LiftArgRetResourcesForFunction( FuncOp func_op, const llvm::SmallDenseMap& resource_data_types, llvm::function_ref handle_updated_arg_value) { - ForwardStoreToLoad(&func_op.front()); - // Maps a resource argument to the first read. - llvm::SmallDenseMap resource_arg_read; - // Maps a resource argument to the last write. - llvm::SmallDenseMap resource_arg_write; - // Use HoistResourceLoads to CSE loads and the `move_load` function only - // records the remaining load to resource_arg_read. - HoistResourceLoads( - &func_op.front(), - /*skip_load=*/ - [&](TF::ReadVariableOp read) { - return !read.resource().isa(); - }, - /*move_load=*/ - [&](TF::ReadVariableOp read) { - resource_arg_read[read.resource()] = read; - }); - // Record the stores in resource_arg_read. - for (auto& op : llvm::make_early_inc_range(func_op.front())) { - auto write = llvm::dyn_cast(&op); - if (!write) continue; - auto arg = write.resource().dyn_cast(); - if (!arg) continue; - // After ForwardStoreToLoad(), there should be just one store for each - // resource. - resource_arg_write[arg] = write; - } - // Now change the input types to non-resource and remove the internal loads. - auto new_types = llvm::to_vector<8>(func_op.getType().getInputs()); - for (auto& entry : resource_data_types) { - auto arg = func_op.getArgument(entry.getFirst()); - auto read_it = resource_arg_read.find(arg); - auto write_it = resource_arg_write.find(arg); - arg.setType(entry.getSecond()); - new_types[arg.getArgNumber()] = entry.getSecond(); - if (read_it != resource_arg_read.end()) { - read_it->getSecond().replaceAllUsesWith(arg); - read_it->getSecond().erase(); - } - if (write_it != resource_arg_write.end()) { - handle_updated_arg_value(arg.getArgNumber(), - write_it->getSecond().value()); - write_it->getSecond().erase(); + RegionResourceHoister hoister(func_op); + if (failed(hoister.Analyze())) return failure(); + + // Each of these resources could be read or written in the function. If its + // read, we need to replace the resource arg with a value arg to get the + // read value. 
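  // For example (an illustrative case, not taken from this change): a function
  // argument of type tensor<*x!tf.resource<tensor<f32>>> that is only read has
  // its type rewritten to tensor<f32>, and the tf.ReadVariableOp in the body is
  // replaced by the argument itself.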
If its written, we need to replace the write with an additional + // value to be written. + + // Now create read values that will be used to replace each resource that + // is read in the function body. These read vaulues are just the same argument + // with type replaced. + llvm::SmallVector skipped_args; + for (auto& it : hoister.GetResources()) { + BlockArgument arg = it.first.dyn_cast(); + assert(arg && "Expect resources for FuncOp to be its arguments"); + auto type_iter = resource_data_types.find(arg.getArgNumber()); + if (type_iter == resource_data_types.end()) { + // Skip lifting the resource if it's not present in the data type map. + // This indicates that the resource is not to be lifted because it is used + // in an unsupported op in some other function. + skipped_args.push_back(arg); + } else { + arg.setType(type_iter->second); + it.second.hoisted_read = arg; } } - func_op.setType(FunctionType::get( - new_types, - llvm::to_vector<4>(func_op.front().getTerminator()->getOperandTypes()), - func_op.getContext())); + + // Drop all the args that have to be skipped. + for (Value arg : skipped_args) hoister.DropResource(arg); + + hoister.ReplaceResourceLoads(/*read_only=*/false); + + // For writes, invoke the callback and then erase the write. + auto assign_ops = func_op.front().getOps(); + for (auto assign_variable_op : llvm::make_early_inc_range(assign_ops)) { + Value resource = assign_variable_op.resource(); + if (!hoister.Contains(resource)) continue; + + auto arg = resource.dyn_cast(); + handle_updated_arg_value(arg.getArgNumber(), assign_variable_op.value()); + assign_variable_op.erase(); + } + + func_op.setType( + FunctionType::get(func_op.front().getArgumentTypes(), + func_op.front().getTerminator()->getOperandTypes(), + func_op.getContext())); + + return success(); } // Returns a vector filtered from range where the unused elements (specified by @@ -556,29 +884,7 @@ void AddLoadsStoresOutsideControlFlowOp( // Lifts loads/stores from while loop's body and cond functions. LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) { - // Remove identity nodes to avoid aliasing. - RemoveIdentity(&body.front()); - RemoveIdentity(&cond.front()); auto return_op = body.front().getTerminator(); - // Sanity check: body resource input/output should alias each other. - for (auto arg : body.getArguments()) { - if (!getElementTypeOrSelf(arg.getType()).isa()) continue; - if (return_op->getOperand(arg.getArgNumber()) != arg) { - return return_op->emitOpError( - "resource used in while loop is only supported when the ") - << "resource input and output alias each other in the loop body."; - } - } - // FindResourceArgUseInfo will check supported resource ops (read and assign), - // but loop condition has additional requirement that it cannot write - // resources. - if (cond.walk([&](TF::AssignVariableOp assign) { - assign.emitOpError("found resource write in loop condition."); - return WalkResult::interrupt(); - }) - .wasInterrupted()) { - return failure(); - } llvm::SmallDenseMap body_use_info; llvm::SmallDenseMap cond_use_info; if (failed(FindResourceArgUseInfo(body, &body_use_info)) || @@ -589,12 +895,7 @@ LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) { auto resource_arg_uses = MergeArgResourceUseInfo(body_use_info, cond_use_info); if (resource_arg_uses.empty()) return success(); - for (const auto& entry : resource_arg_uses) { - // Replace output resource uses with the input, so that we can later freely - // change the output type. 
- while_op.getResult(entry.getFirst()) - .replaceAllUsesWith(while_op.getOperand(entry.getFirst())); - } + // Remove unused resources in functions. llvm::SmallVector old_to_new_indices; llvm::SmallDenseMap remaining_resource_data_types; @@ -647,50 +948,8 @@ LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) { // Lifts loads/stores from an IfOp or CaseOp's branches. template LogicalResult HandleCaseOrIfOp(CaseOrIfOp op, ArrayRef branches) { - // Remove identity nodes to avoid aliasing. - for (auto func : branches) RemoveIdentity(&func.front()); - - // Sanity check: branch return of resources should be aliases of inputs. If - // so, replace the output uses with the input so that we can remove these - // outputs. - for (OpResult result : op.getResults()) { - if (!getElementTypeOrSelf(result.getType()).isa()) - continue; - unsigned result_index = result.getResultNumber(); - constexpr unsigned kUnassigned = -1; - unsigned common_aliasing_arg_num = kUnassigned; - for (auto func : branches) { - auto retval = func.front().getTerminator()->getOperand(result_index); - assert(result.getType() == retval.getType()); - auto aliasing_arg = retval.dyn_cast(); - if (!aliasing_arg) - return op.emitOpError("unsupported output: ") - << "resource does not alias input"; - if (common_aliasing_arg_num == kUnassigned) - common_aliasing_arg_num = aliasing_arg.getArgNumber(); - if (aliasing_arg.getArgNumber() != common_aliasing_arg_num) - return op.emitOpError("unsupported output: ") - << "resource does not alias a single input"; - } - assert(common_aliasing_arg_num != kUnassigned); - result.replaceAllUsesWith(op.getOperand(common_aliasing_arg_num + 1)); - } - - // Erase the resource outputs from the branches. - int64_t non_resource_results = 0; - llvm::SmallVector old_to_new_output_indices; - bool output_removed = false; - for (auto result : op.getResults()) { - if (!getElementTypeOrSelf(result.getType()) - .template isa()) { - old_to_new_output_indices.push_back(non_resource_results++); - continue; - } - old_to_new_output_indices.push_back(-1); - for (auto func : branches) - func.front().getTerminator()->eraseOperand(non_resource_results); - output_removed = true; - } + // For canonicalized If/Case, there should not be any resource outputs + int64_t non_resource_results = op.getNumResults(); llvm::SmallDenseMap resource_arg_uses; if (failed(FindResourceArgUseInfo(branches.front(), &resource_arg_uses))) @@ -705,7 +964,7 @@ LogicalResult HandleCaseOrIfOp(CaseOrIfOp op, ArrayRef branches) { MergeArgResourceUseInfo(resource_arg_uses, branch_use_info); } - if (resource_arg_uses.empty() && !output_removed) return success(); + if (resource_arg_uses.empty()) return success(); // Remove unused resources in functions. llvm::SmallDenseMap remaining_resource_data_types; RemoveUnusedResourceArgumentsAndForwardedRetvals( @@ -780,12 +1039,7 @@ LogicalResult HandleCaseOrIfOp(CaseOrIfOp op, ArrayRef branches) { AddLoadsStoresOutsideControlFlowOp(new_op, arg_data_type_and_updated_output_index); // Replace uses. - for (int64_t i = 0, end = old_to_new_output_indices.size(); i < end; ++i) { - if (old_to_new_output_indices[i] >= 0) { - op.getResult(i).replaceAllUsesWith( - new_op.getResult(old_to_new_output_indices[i])); - } - } + op.replaceAllUsesWith(new_op.getResults().take_front(op.getNumResults())); op.erase(); return success(); } @@ -811,8 +1065,6 @@ struct PartitionedCallLiftingInfo { // happens on a clone, which will be stored in `result`. 
LogicalResult HandlePartitionedCallOpCallee( FuncOp callee, PartitionedCallLiftingInfo* result) { - // Remove identity nodes to avoid aliasing. - RemoveIdentity(&callee.front()); // Sanity check: return of resources should be aliases of inputs. Such outputs // will be removed later. int64_t non_resource_results = 0; @@ -932,8 +1184,8 @@ void UpdatePartitionedCallOpWithNewCallee( call_op.erase(); } -LogicalResult HoistForFunctionalControlFlow( - Block*, ModuleOp, +LogicalResult HoistForControlFlow( + Block*, ModuleOp, bool, llvm::SmallDenseMap*); // A templated routine for handling both PartitionedCallOp and @@ -942,14 +1194,17 @@ LogicalResult HoistForFunctionalControlFlow( // flow, then performs lifting on the callee. template LogicalResult HandlePartitionedCallOp( - CallOpType call_op, FuncOp callee, ModuleOp module, + CallOpType call_op, FuncOp callee, ModuleOp module, bool vars_initialized, llvm::SmallDenseMap* lifted_callees) { auto emplace_res = lifted_callees->try_emplace(callee.getName(), PartitionedCallLiftingInfo()); if (emplace_res.second) { // Unseen callee. Perform resource lifting on it. - HoistForFunctionalControlFlow(&callee.front(), module, lifted_callees); + if (failed(HoistForControlFlow(&callee.front(), module, vars_initialized, + lifted_callees))) + return failure(); + if (failed(HandlePartitionedCallOpCallee( callee, &emplace_res.first->getSecond()))) { return failure(); @@ -961,50 +1216,49 @@ LogicalResult HandlePartitionedCallOp( // Hoists resource loads/stores from control flow ops in `block` outside the // body/cond/branch/callee functions. -LogicalResult HoistForFunctionalControlFlow( - Block* block, ModuleOp module, +LogicalResult HoistForControlFlow( + Block* block, ModuleOp module, bool vars_initialized, llvm::SmallDenseMap* lifted_partitioned_call_callees) { - // Remove identity nodes to avoid aliasing. - RemoveIdentity(block); + if (vars_initialized) SetAllVarIsInitializedToTrue(block); + for (Operation& op : llvm::make_early_inc_range(*block)) { if (auto while_op = llvm::dyn_cast(&op)) { - auto body = while_op.body_func(); - auto cond = while_op.cond_func(); + auto body = while_op.body_function(); + auto cond = while_op.cond_function(); // Recursively handle the nested control flow. - HoistForFunctionalControlFlow(&body.front(), module, - lifted_partitioned_call_callees); - HoistForFunctionalControlFlow(&cond.front(), module, - lifted_partitioned_call_callees); + HoistForControlFlow(&body.front(), module, vars_initialized, + lifted_partitioned_call_callees); + HoistForControlFlow(&cond.front(), module, vars_initialized, + lifted_partitioned_call_callees); if (failed(HandleWhileLoop(while_op, body, cond))) return failure(); } else if (auto if_op = llvm::dyn_cast(&op)) { - auto then_branch = if_op.then_func(); - auto else_branch = if_op.else_func(); + auto then_branch = if_op.then_function(); + auto else_branch = if_op.else_function(); // Recursively handle the nested control flow. 
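      // (The branch bodies are processed first, so any control flow nested
      // inside them has already had its resource accesses hoisted by the time
      // HandleCaseOrIfOp lifts this If itself.)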
- HoistForFunctionalControlFlow(&then_branch.front(), module, - lifted_partitioned_call_callees); - HoistForFunctionalControlFlow(&else_branch.front(), module, - lifted_partitioned_call_callees); + HoistForControlFlow(&then_branch.front(), module, vars_initialized, + lifted_partitioned_call_callees); + HoistForControlFlow(&else_branch.front(), module, vars_initialized, + lifted_partitioned_call_callees); if (failed(HandleCaseOrIfOp(if_op, {then_branch, else_branch}))) return failure(); } else if (auto case_op = llvm::dyn_cast(&op)) { SmallVector branch_functions; - branch_functions.reserve(case_op.branches().size()); - for (const Attribute& branch : case_op.branches()) { - FuncOp func = - module.lookupSymbol(branch.cast()); + case_op.get_branch_functions(branch_functions); + for (FuncOp func : branch_functions) { // Recursively handle the nested control flow. - HoistForFunctionalControlFlow(&func.front(), module, - lifted_partitioned_call_callees); - branch_functions.push_back(func); + HoistForControlFlow(&func.front(), module, vars_initialized, + lifted_partitioned_call_callees); } if (failed(HandleCaseOrIfOp(case_op, branch_functions))) return failure(); } else if (auto call_op = llvm::dyn_cast(&op)) { auto callee = call_op.func(); - if (!callee) + if (!callee) { return call_op.emitOpError( "resource lifting does not support call with nested references."); + } if (failed(HandlePartitionedCallOp(call_op, callee, module, + vars_initialized, lifted_partitioned_call_callees))) { // Nested control flow handling is done in HandlePartitionedCallOp(). return failure(); @@ -1012,29 +1266,23 @@ LogicalResult HoistForFunctionalControlFlow( } else if (auto call_op = llvm::dyn_cast(&op)) { if (failed(HandlePartitionedCallOp(call_op, call_op.func(), module, + vars_initialized, lifted_partitioned_call_callees))) { return failure(); } + } else if (isa(op)) { + for (Region& region : op.getRegions()) + HoistForControlFlow(®ion.front(), module, vars_initialized, + lifted_partitioned_call_callees); + LogicalResult result = RegionResourceHoister::ReplaceOpWithNewOp(&op); + if (failed(result)) return failure(); } } - // Remove unused local variables. + // After we have hoisted operations in the block, we may have added new read + // and writes of resources to this block. Clean them up by doing store-load + // forwarding. 
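// A rough sketch of the forwarding referred to above, assuming a single block
// in which only tf.ReadVariableOp/tf.AssignVariableOp touch the resources (the
// real ForwardStoreToLoad defined earlier in this file is the authoritative,
// more careful version). The helper name is hypothetical.
void ForwardStoresToLoadsSketch(Block* block) {
  // Latest value written to each resource seen so far in program order.
  llvm::SmallDenseMap<Value, Value, 4> last_stored;
  for (Operation& op : llvm::make_early_inc_range(*block)) {
    if (auto write = llvm::dyn_cast<TF::AssignVariableOp>(&op)) {
      last_stored[write.resource()] = write.value();
    } else if (auto read = llvm::dyn_cast<TF::ReadVariableOp>(&op)) {
      auto it = last_stored.find(read.resource());
      if (it == last_stored.end()) continue;
      // The read would just return the previously stored value; forward it.
      read.replaceAllUsesWith(it->second);
      read.erase();
    }
  }
}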
ForwardStoreToLoad(block); - llvm::SmallVector local_vars; - for (Operation& op : *block) { - if (auto local_var = llvm::dyn_cast(&op)) { - local_vars.push_back(local_var); - } - } - for (auto local_var : local_vars) { - if (llvm::all_of(local_var.resource().getUsers(), - [](const Operation* user) { - return isa(user); - })) { - for (auto user : local_var.resource().getUsers()) user->erase(); - local_var.erase(); - } - } return success(); } @@ -1045,19 +1293,23 @@ void ResourceOpLiftingPass::runOnOperation() { llvm::SmallDenseMap lifted_partitioned_call_callees; ModuleOp module = getOperation(); - auto result = module.walk([&](FuncOp func_op) { + + if (failed(TF::CleanupAndCanonicalizeForResourceOpLifting(module))) + return signalPassFailure(); + + auto walk_result = module.walk([&](FuncOp func_op) { return func_op.walk([&](tf_device::ClusterOp cluster) { - if (failed(HoistForFunctionalControlFlow( - &cluster.GetBody(), module, &lifted_partitioned_call_callees)) || - failed(HoistResourceOpsFromCluster(cluster, module))) { - return WalkResult::interrupt(); - } + LogicalResult result = HoistForControlFlow( + &cluster.GetBody(), module, /*vars_initialized=*/true, + &lifted_partitioned_call_callees); + if (failed(result)) return WalkResult::interrupt(); + result = RegionResourceHoister::ReplaceOpWithNewOp(cluster); + if (failed(result)) return WalkResult::interrupt(); return WalkResult::advance(); }); }); - if (result.wasInterrupted()) { - signalPassFailure(); - } + + if (walk_result.wasInterrupted()) return signalPassFailure(); } struct ResourceOpLiftingForMainFunctionPass @@ -1107,11 +1359,20 @@ LogicalResult ResourceLiftingForFunctionalControlFlow(FuncOp function) { << function.getBlocks().size(); } + if (failed(TF::CleanupAndCanonicalizeForResourceOpLifting(function))) + return failure(); + llvm::SmallDenseMap lifted_partitioned_call_callees; - return HoistForFunctionalControlFlow(&function.front(), - cast(function.getParentOp()), - &lifted_partitioned_call_callees); + if (failed(HoistForControlFlow( + &function.front(), cast(function.getParentOp()), + /*vars_initialized=*/false, &lifted_partitioned_call_callees))) + return failure(); + + // Clean up and canonicalize to remove dead local variables as some local + // variables might be dead after hoisting resource loads/stores from control + // flow ops. + return TF::CleanupAndCanonicalizeForResourceOpLifting(function); } } // namespace TF diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc new file mode 100644 index 00000000000..b635096cc9b --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc @@ -0,0 +1,459 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.h" + +#include "llvm/ADT/BitVector.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +namespace mlir { +namespace { + +bool IsResource(Value value) { + return getElementTypeOrSelf(value.getType()).isa(); +} + +// Removes identity nodes in the block. The device computation does not need +// such nodes to carry information. +void RemoveIdentity(Block &block) { + for (auto &op : llvm::make_early_inc_range(block)) { + if (isa(&op)) { + op.replaceAllUsesWith(op.getOperands()); + op.erase(); + } + } +} + +// Eliminate local variables that are only assigned to but never read, and thus +// are dead. +void RemoveDeadLocalVariables(Block &block) { + llvm::SmallVector local_vars; + for (Operation &op : block) { + if (auto local_var = llvm::dyn_cast(&op)) { + local_vars.push_back(local_var); + } + } + for (auto local_var : local_vars) { + auto users = local_var.resource().getUsers(); + if (llvm::all_of(users, [](const Operation *user) { + return isa(user); + })) { + for (auto user : llvm::make_early_inc_range(users)) user->erase(); + local_var.erase(); + } + } +} + +LogicalResult CleanupAndCanonicalize(Operation *parent_op); + +// Eliminates unusued results from an operation `op` by cloning it with reduced +// result types and doing appropriate use replacements. `results_to_eliminate` +// is a bitvector of result positions to eliminate. If its null, then all unused +// results of the operation will be eliminated. +void EliminateUnusedResults( + Operation *op, const llvm::BitVector *results_to_eliminate = nullptr) { + auto can_eliminate = [&](OpResult &result) -> bool { + if (!result.use_empty()) return false; + if (results_to_eliminate) + return results_to_eliminate->test(result.getResultNumber()); + else + return true; + }; + SmallVector new_result_types; + for (OpResult result : op->getResults()) { + if (can_eliminate(result)) continue; + new_result_types.push_back(result.getType()); + } + + // Rebuild the new operation with lesser number of results. + OpBuilder builder(op); + Operation *new_op = Operation::create( + op->getLoc(), op->getName(), new_result_types, op->getOperands(), + op->getAttrs(), op->getSuccessors(), op->getNumRegions()); + builder.insert(new_op); + + // Move region bodies to the new operation. + for (auto it : llvm::zip(op->getRegions(), new_op->getRegions())) { + Region &old_region = std::get<0>(it); + Region &new_region = std::get<1>(it); + new_region.takeBody(old_region); + } + + // Replace used results and erase the old op. + int next_result_idx = 0; + for (OpResult result : op->getResults()) { + if (can_eliminate(result)) continue; + result.replaceAllUsesWith(new_op->getResult(next_result_idx++)); + } + op->erase(); +} + +// Clones a function if it cannot be patched in place. Clone if there are +// multiple uses or unknown uses (for external functions). The cloned function +// will be marked as private. 
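+// Callers that receive a clone are expected to repoint their function
+// reference attribute at it, e.g. (a sketch only; "then_branch" is tf.If's
+// attribute name and is used here purely for illustration):
+//
+//   op->setAttr("then_branch",
+//               FlatSymbolRefAttr::get(cloned.getName(), op->getContext()));
+//
+// EliminateUnusedResultsForIfCase below does this generically by scanning the
+// op's symbol reference attributes.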
+FuncOp CloneFunctionIfNeeded(FuncOp func) { + ModuleOp module = func.getParentOfType(); + auto func_uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion()); + if (func_uses.hasValue() && llvm::hasSingleElement(func_uses.getValue())) + return func; + FuncOp cloned = func.clone(); + cloned.setVisibility(SymbolTable::Visibility::Private); + cloned.setName(func.getName().str() + "_lifted"); + SymbolTable(module).insert(cloned); + return cloned; +} + +// Eliminates unused results for If/Case operations. Also patches up the +// branch functions to (a) drop the ununsed return values, and (b) as a result +// if some argument becomes unused in all branches, drop that argument and the +// corresponding if/case input operand. +void EliminateUnusedResultsForIfCase(Operation *op, ArrayRef branches) { + // Clone branch functions if needed since we will be mutating them. + SmallVector cloned_branches; + cloned_branches.reserve(branches.size()); + for (FuncOp func : branches) { + FuncOp cloned = CloneFunctionIfNeeded(func); + cloned_branches.push_back(cloned); + if (cloned == func) continue; + // Patch up the op attribute to point to the new function. + for (NamedAttribute attr : op->getAttrs()) { + auto symref = attr.second.dyn_cast(); + if (!symref) continue; + if (symref.getValue() != func.getName()) continue; + op->setAttr(attr.first, + FlatSymbolRefAttr::get(cloned.getName(), op->getContext())); + break; + } + } + + // Traverse results backward so that indices to be deleted stay unchanged. + for (OpResult result : llvm::reverse(op->getResults())) { + if (!result.use_empty()) continue; + int result_idx = result.getResultNumber(); + for (FuncOp func : cloned_branches) + func.front().getTerminator()->eraseOperand(result_idx); + } + + // Check which function arguments are unused in all branches. We can drop + // those as well. + int num_args = cloned_branches[0].getNumArguments(); + llvm::BitVector used_args(num_args); + for (FuncOp func : branches) { + for (BlockArgument arg : func.getArguments()) { + if (!arg.use_empty()) used_args.set(arg.getArgNumber()); + } + } + + // There are some unused args that we can drop. Also drop the corresponding + // input operand. + if (used_args.count() != num_args) { + // Traverse arguments backward so that indices to be deleted stay unchanged. + for (int idx = num_args - 1; idx >= 0; --idx) { + if (used_args.test(idx)) continue; + for (FuncOp func : cloned_branches) func.eraseArgument(idx); + // For if/case, arg #i of attached function corresponds to operand #i+1 + op->eraseOperand(idx + 1); + } + } + + // Patch up function types (with less number of return values and potentially + // less number of arguments) + for (FuncOp func : cloned_branches) { + func.setType(FunctionType::get( + func.front().getArgumentTypes(), + func.front().getTerminator()->getOperandTypes(), func.getContext())); + } + + EliminateUnusedResults(op); +} + +// Eliminated unused results from a functional while. 
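+// A while result can only be dropped when it is a genuine pass-through: for
+// example, if %while#2 is unused, the cond function never uses its argument
+// #2, and the body function's only use of argument #2 is to return it
+// unchanged in position #2, then operand/result #2 and the matching cond/body
+// arguments can all be erased without changing behavior.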
+void EliminateUnusedResultsForWhile(TF::WhileOp op) { + FuncOp cond = op.cond_function(); + FuncOp body = op.body_function(); + + llvm::BitVector can_eliminate(op.getNumResults()); + for (OpResult result : llvm::reverse(op.getResults())) { + if (!result.use_empty()) continue; + int result_idx = result.getResultNumber(); + BlockArgument cond_arg = cond.getArgument(result_idx); + BlockArgument body_arg = cond.getArgument(result_idx); + Operation *body_ret = body.front().getTerminator(); + // We can eliminate a result if its unused and the corresponding argument + // is unused in cond and the only use in body is use it as a return value. + if (cond_arg.use_empty() && body_arg.hasOneUse() && + body_arg.use_begin()->getOperandNumber() == result_idx && + body_arg.use_begin()->getOwner() == body_ret) { + can_eliminate.set(result_idx); + } + } + + if (can_eliminate.empty()) return; + + FuncOp cloned_cond = CloneFunctionIfNeeded(cond); + FuncOp cloned_body = CloneFunctionIfNeeded(body); + op.condAttr(FlatSymbolRefAttr::get(cloned_cond.getName(), op.getContext())); + op.bodyAttr(FlatSymbolRefAttr::get(cloned_body.getName(), op.getContext())); + + // Drop cond/body args and return value. WhileOp result will be dropped later + // in EliminateUnusedResults. Traverse in reverse order so that indices to be + // deleted stay unchanged. + for (int idx = op.getNumResults() - 1; idx >= 0; --idx) { + if (!can_eliminate.test(idx)) continue; + cloned_cond.eraseArgument(idx); + cloned_body.front().getTerminator()->eraseOperand(idx); + cloned_body.eraseArgument(idx); + } + + // Patch up branch function types. + for (FuncOp func : {cloned_cond, cloned_body}) { + func.setType(FunctionType::get( + func.front().getArgumentTypes(), + func.front().getTerminator()->getOperandTypes(), func.getContext())); + } + EliminateUnusedResults(op, &can_eliminate); +} + +// For resource results, replace all uses with the resource input to which the +// result is tied to. After this, resource outputs of this op are expected to be +// unused. +LogicalResult ForwardCommonArgToOutput(Operation *op, ArrayRef branches, + ValueRange branch_args, + bool &has_resource_result) { + // For while, the branch inputs and outputs need to match. + bool io_match = isa(op); + + has_resource_result = false; + // Check if the same input argument number is passed through all functions. + for (OpResult result : op->getResults()) { + if (!IsResource(result)) continue; + + has_resource_result = true; + int result_idx = result.getResultNumber(); + Optional common_arg_index; + for (FuncOp func : branches) { + auto ret = func.front().getTerminator(); + auto block_arg = ret->getOperand(result_idx).dyn_cast(); + if (!block_arg) { + return op->emitOpError("result #") + << result_idx << " not tied to function argument for branch @" + << func.getName(); + } + if (!common_arg_index.hasValue()) { + common_arg_index = block_arg.getArgNumber(); + } else if (common_arg_index.getValue() != block_arg.getArgNumber()) { + return op->emitError("result #") + << result_idx + << " is not tied to the same argument across all branches"; + } + } + + if (io_match && result_idx != common_arg_index.getValue()) { + return op->emitOpError("Result #") + << result_idx << " is tied to argument #" + << common_arg_index.getValue(); + } + + // Forward the corresponding input to the output + result.replaceAllUsesWith(branch_args[common_arg_index.getValue()]); + } + return success(); +} + +// Canonicalizes a function if. 
Forwards input argument to resource results and +// then deletes the resource results. +LogicalResult CanonicalizeFunctionalIfCase(Operation *op, + ArrayRef branches, + ValueRange branch_args) { + for (FuncOp func : branches) { + if (failed(CleanupAndCanonicalize(func))) return failure(); + } + + bool has_resource_result = false; + if (failed(ForwardCommonArgToOutput(op, branches, branch_args, + has_resource_result))) + return failure(); + + // If no resource type results were found, no further cleanup needed. + if (!has_resource_result) return success(); + + // Drop unused results. + EliminateUnusedResultsForIfCase(op, branches); + return success(); +} + +// Canonicalizes a functional while. Forwards common argument to results and +// drop resource results if posible. +LogicalResult CanonicalizeFunctionalWhile(TF::WhileOp op) { + for (FuncOp func : {op.cond_function(), op.body_function()}) { + if (failed(CleanupAndCanonicalize(func))) return failure(); + } + + // For while, just use the body function to forward operand to result. + bool has_resource_result = false; + if (failed(ForwardCommonArgToOutput(op, {op.body_function()}, + op.getOperands(), has_resource_result))) + return failure(); + // If no resource type results were found, no further cleanup needed. + if (!has_resource_result) return success(); + + // Drop unused results. + EliminateUnusedResultsForWhile(op); + return success(); +} + +// Canonicalizes region based if/case and cluster operations. If the same +// captured resource typed value is used for all region results, then that value +// is forwared to the result and the result is dropped. +LogicalResult CanonicalizeRegionIfCaseCluster(Operation *op) { + // Check if the same value is used for all region results for this output. + bool has_resource_result = false; + for (OpResult result : op->getResults()) { + if (!IsResource(result)) continue; + has_resource_result = true; + int result_idx = result.getResultNumber(); + + Value ret0 = + op->getRegion(0).front().getTerminator()->getOperand(result_idx); + for (Region ®ion : op->getRegions().drop_front()) { + Value ret = region.front().getTerminator()->getOperand(result_idx); + if (ret != ret0) { + return op->emitError("Result #") + << result_idx + << " not tied to the same capture across all regions"; + } + } + result.replaceAllUsesWith(ret0); + } + + if (!has_resource_result) return success(); + + // Eliminate unused region results. Traverse in reverse order so that + // indices to be deleted stay unchanged. + for (OpResult result : llvm::reverse(op->getResults())) { + if (!result.use_empty()) continue; + int result_idx = result.getResultNumber(); + for (Region ®ion : op->getRegions()) + region.front().getTerminator()->eraseOperand(result_idx); + } + EliminateUnusedResults(op); + return success(); +} + +// Canonicalizes a region based while. If the same value is passed through +// the body, the result is replaced with the operand and all argument/results +// and retuns values corresponding to that result are dropped. +LogicalResult CanonicalizeWhileRegion(TF::WhileRegionOp op) { + Region &body = op.body(); + Region &cond = op.cond(); + llvm::BitVector can_eliminate(op.getNumResults()); + + // Traverse in reverse order so that indices to be deleted stay unchanged. 
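+  // (Erasing a result or argument renumbers everything after it: with results
+  // {#0, #1, #2}, erasing #1 first turns the old #2 into #1, so a later erase
+  // of "#2" would hit the wrong value. Walking back-to-front keeps the indices
+  // still to be visited stable.)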
+ for (OpResult result : llvm::reverse(op.getResults())) { + if (!IsResource(result)) continue; + int result_idx = result.getResultNumber(); + auto body_arg = body.front() + .getTerminator() + ->getOperand(result_idx) + .dyn_cast(); + if (!body_arg || body_arg.getArgNumber() != result_idx) { + return op.emitOpError("Result #") << result_idx << " is not tied to arg #" + << result_idx << " of the body"; + } + body.getArgument(result_idx).replaceAllUsesWith(op.getOperand(result_idx)); + cond.getArgument(result_idx).replaceAllUsesWith(op.getOperand(result_idx)); + body.front().getTerminator()->eraseOperand(result_idx); + body.eraseArgument(result_idx); + cond.eraseArgument(result_idx); + result.replaceAllUsesWith(op.getOperand(result_idx)); + op.getOperation()->eraseOperand(result_idx); + can_eliminate.set(result_idx); + } + EliminateUnusedResults(op, &can_eliminate); + return success(); +} + +// Removes identities and canonicalizes all operations within `parent_op`. +LogicalResult CleanupAndCanonicalize(Operation *parent_op) { + auto walk_result = parent_op->walk([](Operation *op) { + // Cleanup code in attached regions. + for (Region ®ion : op->getRegions()) { + if (!llvm::hasSingleElement(region)) return WalkResult::interrupt(); + RemoveIdentity(region.front()); + RemoveDeadLocalVariables(region.front()); + } + + LogicalResult result = success(); + + // While condition cannot write to resource variables. + auto check_while_cond = [&](TF::AssignVariableOp assign) { + op->emitOpError("found resource write in loop condition."); + return WalkResult::interrupt(); + }; + + if (auto if_op = dyn_cast(op)) { + result = CanonicalizeFunctionalIfCase( + op, {if_op.then_function(), if_op.else_function()}, if_op.input()); + } else if (auto case_op = dyn_cast(op)) { + SmallVector branches; + case_op.get_branch_functions(branches); + result = CanonicalizeFunctionalIfCase(case_op, branches, case_op.input()); + } else if (auto while_op = dyn_cast(op)) { + if (while_op.cond_function().walk(check_while_cond).wasInterrupted()) + return WalkResult::interrupt(); + result = CanonicalizeFunctionalWhile(while_op); + } else if (isa( + op)) { + result = CanonicalizeRegionIfCaseCluster(op); + } else if (auto while_region = dyn_cast(op)) { + if (while_region.cond().walk(check_while_cond).wasInterrupted()) + return WalkResult::interrupt(); + // For while region, the body input and output arg should match. + CanonicalizeWhileRegion(while_region); + } else if (auto call = dyn_cast(op)) { + FuncOp func = dyn_cast(call.resolveCallable()); + if (!func) return WalkResult::interrupt(); + result = CleanupAndCanonicalize(func); + } + return failed(result) ? 
WalkResult::interrupt() : WalkResult::advance(); + }); + + return failure(walk_result.wasInterrupted()); +} + +} // anonymous namespace + +namespace TF { + +LogicalResult CleanupAndCanonicalizeForResourceOpLifting(FuncOp func) { + return CleanupAndCanonicalize(func); +} + +LogicalResult CleanupAndCanonicalizeForResourceOpLifting(ModuleOp module) { + auto walk_result = module.walk([](tf_device::ClusterOp cluster) { + if (failed(CleanupAndCanonicalize(cluster))) return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return failure(walk_result.wasInterrupted()); +} + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.h b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.h new file mode 100644 index 00000000000..626ef91bcf6 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.h @@ -0,0 +1,47 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_RESOURCE_OP_LIFTING_CLEANUP_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_RESOURCE_OP_LIFTING_CLEANUP_H_ + +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project + +// Performs IR cleanup and canonicalization in preparation for Resource Op +// Lifting pass. It does several things: +// - Eliminate identity nodes to remove (most) of resource aliasing +// - Canonicalize functional control flow. For functional control flow we +// expect that any resource output of these ops matches the corresponding +// input, and then forward that input to the output. Fails if this is not the +// case. If successful, the following invariants will hold true: +// (a) For if/case, any resource type results will be deleted. +// (b) For while, any resource type results will be unused. +// - Canonicalize region based control flow. Again, any resource outputs are +// expected to be resolved to be one of the captured resource inputs. Fails +// if this is not the case. If successful, the following invariants will hold +// true: +// (a) For if/case, any resource type results will be deleted. +// (b) For while, any resource type results will be unused. 
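+// Typical use (a sketch mirroring ResourceOpLiftingPass::runOnOperation in
+// resource_op_lifting.cc): run the module-level cleanup before hoisting and
+// abort the pass if canonicalization fails.
+//
+//   if (failed(mlir::TF::CleanupAndCanonicalizeForResourceOpLifting(module)))
+//     return signalPassFailure();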
+namespace mlir { +namespace TF { +LogicalResult CleanupAndCanonicalizeForResourceOpLifting(ModuleOp module); +LogicalResult CleanupAndCanonicalizeForResourceOpLifting(FuncOp func); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_RESOURCE_OP_LIFTING_CLEANUP_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index 88ad787df3e..e802353b84c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "llvm/ADT/Hashing.h" +#include "llvm/ADT/None.h" #include "llvm/ADT/PointerUnion.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" @@ -41,6 +42,7 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project #include "mlir/Interfaces/FoldInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project @@ -51,10 +53,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" -#include "tensorflow/core/framework/op.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/types.pb.h" @@ -68,7 +67,7 @@ using tensorflow::shape_inference::ShapeHandle; namespace mlir { namespace TF { namespace { -Optional> InferShapeForFunctionReturnType(FuncOp func) { +Optional InferShapeForFunctionReturnType(FuncOp func) { // Find any return ops. SmallVector return_ops; for (Block& block : func) { @@ -111,17 +110,17 @@ Optional> InferShapeForFunctionReturnType(FuncOp func) { } } - return llvm::to_vector<4>(return_op.getOperandTypes()); + return TypeRange(return_op.getOperandTypes()); } // Returns if the shape inference pass supports an op outside the TF dialect. bool IsSupportedNonTFOp(Operation* op) { - return isa(op); + return isa(op); } // Returns whether a cast back would need to be inserted, e.g., whether the @@ -156,57 +155,6 @@ void UpdateTypeAndInsertIncompatibleUseCasts(Dialect* tf_dialect, Type new_type, result.setType(new_type); } -// Extracts a PartialTensorShape from the MLIR type. -Optional GetShapeFromMlirType(Type t) { - if (auto ranked_type = t.dyn_cast()) { - // Convert the MLIR shape indices (int64_t) to TensorFlow indices - // (int64). - ArrayRef shape = ranked_type.getShape(); - SmallVector tf_shape(shape.begin(), shape.end()); - return tensorflow::PartialTensorShape({tf_shape.data(), tf_shape.size()}); - } - return None; -} - -// Gets the subtype's shape and data type for `type`. Templated to support both -// ResourceType and VariantType. 
-template -std::unique_ptr>> -GetSubtypesHelper(Type type) { - auto type_with_subtypes = - type.cast().getElementType().dyn_cast(); - if (!type_with_subtypes || type_with_subtypes.getSubtypes().empty()) { - return nullptr; - } - auto shapes_and_types = absl::make_unique>>(); - for (auto subtype : type_with_subtypes.getSubtypes()) { - auto shape = GetShapeFromMlirType(subtype); - // handle_shapes_and_types requires all shapes to be known. So if any - // subtype is unknown, clear the vector. - if (!shape) { - shapes_and_types = nullptr; - break; - } - tensorflow::DataType dtype; - auto status = - tensorflow::ConvertToDataType(subtype.getElementType(), &dtype); - assert(status.ok() && "Unknown element type"); - shapes_and_types->emplace_back(*shape, dtype); - } - return shapes_and_types; -} - -// Gets the subtype's shape and data type for `type`. -std::unique_ptr>> -GetSubtypes(Type type) { - auto subclasses = GetSubtypesHelper(type); - if (subclasses) return subclasses; - return GetSubtypesHelper(type); -} - // Returns whether type can be further refined. bool CanBeRefined(Type type) { auto shape_type = type.dyn_cast(); @@ -293,8 +241,8 @@ bool InferShapeForCast(CastOp op, Dialect* tf_dialect) { // function result types. bool InferShapeForIf(IfOp op) { bool changed = false; - auto then_results = op.then_func().getType().getResults(); - auto else_results = op.else_func().getType().getResults(); + auto then_results = op.then_function().getType().getResults(); + auto else_results = op.else_function().getType().getResults(); for (auto it : llvm::zip(op.getResults(), then_results, else_results)) { // If then and else types do not match, skip refinement for that result. if (std::get<1>(it) != std::get<2>(it)) continue; @@ -745,6 +693,11 @@ bool ShapeInference::InferShapeForNonTFDialectOperation(Operation* op) { return RefineTypeForPassThroughOperands(op, terminator->getOperands(), op->getResults()); } + if (auto cluster_op = dyn_cast(op)) { + auto terminator = cluster_op.GetBody().getTerminator(); + return RefineTypeForPassThroughOperands(op, terminator->getOperands(), + op->getResults()); + } if (op->hasTrait()) { return RefineShapeForPassThroughOps(op); } @@ -794,182 +747,54 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) { if (auto if_region = dyn_cast(op)) return InferShapeForIfRegion(if_region); - StringRef op_name = op->getName().getStringRef(); - // Drop the `tf.` prefix to query TF registry. - auto node_name = - op_name.drop_front(TensorFlowDialect::getDialectNamespace().size() + 1); - - // Get information from the registry and check if we have a shape function for - // this op. - const tensorflow::OpRegistrationData* op_reg_data = - tensorflow::OpRegistry::Global()->LookUp(node_name.data()); - if (!op_reg_data) { - LLVM_DEBUG(llvm::dbgs() << "Skipping inference for unregistered op '" - << op->getName() << "'.\n"); - return false; - } - if (op_reg_data->shape_inference_fn == nullptr) { - LLVM_DEBUG(llvm::dbgs() - << "Skipping inference for op without shape function '" - << op->getName() << "'.\n"); - return false; - } - - // Convert the operation to a NodeDef to be able to use the InferenceContext - // and the TensorFlow shape function. 
- auto node_def_or = tensorflow::ConvertTFDialectOpToNodeDef( - op, node_name, /*ignore_unregistered_attrs=*/true); - if (!node_def_or.ok()) { - LLVM_DEBUG(llvm::dbgs() - << "Error converting op '" << *op << "' to NodeDef: " - << node_def_or.status().error_message() << "\n"); - return false; - } - std::unique_ptr node_def = - std::move(node_def_or).ValueOrDie(); - - // Collect an array with input values for constant operands and input shapes - // for all the operands. - std::vector input_tensors(op->getNumOperands()); - std::vector input_shapes( - op->getNumOperands()); - std::vector tensors(op->getNumOperands()); - std::vector>>> - handle_shapes_and_types(op->getNumOperands()); - for (auto it : llvm::enumerate(op->getOperands())) { - Value operand = it.value(); - size_t index = it.index(); - - // If the operand is constant, then convert it to Tensor. + // Return operand as a constant attribute. + auto operand_as_constant_fn = [&](Value operand) { ValuePort vp(operand); Attribute attr = ComputeOutputComponent(vp); if (!attr && matchPattern(operand, m_Constant(&attr))) RecordValue(vp, attr); - if (attr) { - tensorflow::Tensor* input_tensor = &tensors[index]; - auto status = - tensorflow::ConvertToTensor(attr.cast(), input_tensor); - if (status.ok()) { - input_tensors[index] = input_tensor; - } else { - LLVM_DEBUG(llvm::dbgs() - << "Error converting input " << index << " of op '" << *op - << "' to Tensor: " << status.error_message() << "\n"); - } - } + return attr; + }; - Type operand_type = operand.getType(); - if (auto shape = GetShapeFromMlirType(operand_type)) { - input_shapes[index] = *shape; - } - // Collect the handle shapes and types for a resource/variant. - handle_shapes_and_types[index] = GetSubtypes(operand_type); - } + // Return op result as a shape. + auto op_result_as_shape_fn = [&](InferenceContext& context, + OpResult op_result) { + return ComputeOutputAsShape(op_result, &context); + }; - // Perform the shape inference using an InferenceContext with the input - // shapes. This object is abstracting the information that the ShapeInference - // function operates on. - InferenceContext c(graph_version_, *node_def, op_reg_data->op_def, - input_shapes, input_tensors, - /*input_tensors_as_shapes=*/{}, handle_shapes_and_types); - auto status = c.Run(op_reg_data->shape_inference_fn); - if (!status.ok()) { - LLVM_DEBUG(llvm::dbgs() << "Shape inference error for '" << *op - << "': " << status.error_message() << "\n"); + // Return result element type at `index`. + auto result_element_type_fn = [&](int index) { + return op->getResult(index).getType().cast().getElementType(); + }; + + llvm::SmallVector inferred_return_shapes; + if (failed(InferReturnTypeComponentsForTFOp( + /*location=*/None, op, graph_version_, operand_as_constant_fn, + op_result_as_shape_fn, result_element_type_fn, + inferred_return_shapes))) return false; - } - - // Determine if, during shape computation, the shape functions attempted to - // query an input operand as shape where the input was not known/constant. 
- bool requires_inputs = - any_of(llvm::seq(0, c.num_inputs()), [&](int input) { - return c.requested_input_tensor_as_partial_shape(input) && - !input_tensors[input]; - }); - if (requires_inputs) { - LLVM_DEBUG(llvm::dbgs() << "\trequired input\n"); - std::vector input_tensors_as_shapes; - for (int input : llvm::seq(0, c.num_inputs())) { - if (c.requested_input_tensor_as_partial_shape(input) && - !input_tensors[input]) { - LLVM_DEBUG(llvm::dbgs() << "Requesting " << input << " as shape\n"); - auto op_result = op->getOperand(input).dyn_cast(); - if (!op_result) continue; - // Resize on first valid shape computed. - input_tensors_as_shapes.resize(c.num_inputs()); - auto handle = ComputeOutputAsShape(op_result, &c); - LLVM_DEBUG(llvm::dbgs() << "Requested " << input << " as shape " - << (handle.Handle() ? "found" : "not found")); - if (handle.Handle()) input_tensors_as_shapes[input] = handle; - } - } - - // Attempt to compute the unknown operands as shapes. - // Note: in the case where no partial outputs could be computed, this would - // be empty. - if (!input_tensors_as_shapes.empty()) { - c.set_input_tensors_as_shapes(input_tensors_as_shapes); - auto status = c.Run(op_reg_data->shape_inference_fn); - if (!status.ok()) { - LLVM_DEBUG(llvm::dbgs() << "Shape inference error for '" << *op - << "': " << status.error_message() << "\n"); - return false; - } - } - } - - assert(c.num_outputs() == op->getNumResults() && - "inference context matches the MLIR number of results."); // Update the shape for each of the operation result if the InferenceContext // has more precise shapes recorded. bool changed = false; - for (int output : llvm::seq(0, c.num_outputs())) { - // Skip already statically shaped results. - Value result = op->getResult(output); - if (!CanBeRefined(result.getType())) continue; - auto shaped_type = result.getType().cast(); + for (auto result : llvm::zip(op->getResults(), inferred_return_shapes)) { + Value op_result = std::get<0>(result); + if (!CanBeRefined(op_result.getType())) continue; - ShapeHandle shape_handle = c.output(output); - LLVM_DEBUG(llvm::dbgs() << "Inferred output " << output << " : " - << c.DebugString(shape_handle) << "\n"); - auto get_tensor_type = [&c](const ShapeHandle& sh, - Type element_type) -> TensorType { - if (!c.RankKnown(sh)) return UnrankedTensorType::get(element_type); - // Convert the shape from TensorFlow (int64) to MLIR (int64_t). - SmallVector shape; - for (int dim : llvm::seq(0, c.Rank(sh))) - shape.push_back(c.Value(c.Dim(sh, dim))); - return RankedTensorType::get(shape, element_type); - }; - auto new_element_type = shaped_type.getElementType(); - // Populate the handle shapes for a resource/variant. 
- if (new_element_type.isa()) { - auto handle_shapes_types = c.output_handle_shapes_and_types(output); - if (handle_shapes_types) { - SmallVector subtypes; - OpBuilder b(op); - for (const auto& shape_n_type : *handle_shapes_types) { - Type element_type; - auto status = - tensorflow::ConvertDataType(shape_n_type.dtype, b, &element_type); - assert(status.ok() && "Unknown element type"); - subtypes.push_back(get_tensor_type(shape_n_type.shape, element_type)); - } - if (new_element_type.isa()) { - new_element_type = TF::ResourceType::get(subtypes, op->getContext()); - } else { - new_element_type = TF::VariantType::get(subtypes, op->getContext()); - } - } - } - auto new_type = get_tensor_type(shape_handle, new_element_type); - if (result.getType() == new_type) continue; + ShapedTypeComponents inferred = std::get<1>(result); + TensorType inferred_type; + if (inferred.hasRank()) + inferred_type = + RankedTensorType::get(inferred.getDims(), inferred.getElementType()); + else + inferred_type = UnrankedTensorType::get(inferred.getElementType()); - UpdateTypeAndInsertIncompatibleUseCasts(tf_dialect_, new_type, op, result); + if (op_result.getType() == inferred_type) continue; + UpdateTypeAndInsertIncompatibleUseCasts(tf_dialect_, inferred_type, op, + op_result); changed = true; } + if (changed) LLVM_DEBUG(llvm::dbgs() << "Modified after shape inference: '" << *op << "'\n"); @@ -980,7 +805,6 @@ LogicalResult ShapeInference::PropagateShapeToFunctions( ModuleOp module, Operation::operand_type_range input_types, ArrayRef functions, int64_t max_iteration) { bool all_succeeded = true; - auto types = llvm::to_vector<4>(input_types); // If shape propagation fails for one function, return failure, but do not // early exit and attempt to propagate shapes for all provided functions to // have a best-effort propagation. @@ -997,8 +821,8 @@ LogicalResult ShapeInference::PropagateShapeToFunctions( } FunctionType func_type = func.getType(); - func.setType( - FunctionType::get(types, func_type.getResults(), func.getContext())); + func.setType(FunctionType::get(input_types, func_type.getResults(), + func.getContext())); auto res = PropagateShapeToRegions(input_types, {&func.getBody()}, max_iteration); @@ -1009,7 +833,7 @@ LogicalResult ShapeInference::PropagateShapeToFunctions( auto new_return_types = InferShapeForFunctionReturnType(func); if (new_return_types) - func.setType(FunctionType::get(types, new_return_types.getValue(), + func.setType(FunctionType::get(input_types, new_return_types.getValue(), func.getContext())); } return success(all_succeeded); @@ -1019,16 +843,17 @@ LogicalResult ShapeInference::PropagateShapeToRegions( Operation::operand_type_range input_types, ArrayRef regions, int64_t max_iteration) { bool all_succeeded = true; - auto types = llvm::to_vector<4>(input_types); // If shape propagation fails for one region, return failure, but do not // early exit and attempt to propagate shapes for all provided regions to // have a best-effort propagation. for (auto region : regions) { // Refine region arguments. Block& entry = region->front(); - assert(types.size() == entry.getNumArguments()); - for (auto arg_and_idx : llvm::enumerate(entry.getArguments())) { - arg_and_idx.value().setType(types[arg_and_idx.index()]); + assert(llvm::size(input_types) == entry.getNumArguments()); + for (auto it : llvm::zip(entry.getArguments(), input_types)) { + BlockArgument arg = std::get<0>(it); + Type type = std::get<1>(it); + arg.setType(type); } // Propagate shapes into the region. 
@@ -1099,20 +924,17 @@ LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions( if (auto if_op = dyn_cast(op)) { return PropagateShapeToFunctions( module, drop_begin(if_op.getOperandTypes(), 1), - {if_op.then_func(), if_op.else_func()}, max_iteration); + {if_op.then_function(), if_op.else_function()}, max_iteration); } else if (auto case_op = dyn_cast(op)) { SmallVector branches; - for (Attribute branch : case_op.branches()) { - auto sym = branch.cast(); - branches.push_back(SymbolTable::lookupNearestSymbolFrom(op, sym)); - } + case_op.get_branch_functions(branches); return PropagateShapeToFunctions(module, drop_begin(case_op.getOperandTypes(), 1), branches, max_iteration); } else if (auto while_op = dyn_cast(op)) { return PropagateShapeToFunctions( module, while_op.getOperandTypes(), - {while_op.cond_func(), while_op.body_func()}, max_iteration); + {while_op.cond_function(), while_op.body_function()}, max_iteration); } else if (auto call_op = dyn_cast(op)) { if (auto func = dyn_cast(call_op.resolveCallable())) { PropagateConstantToCallee(call_op, func, module); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc index d3755a4a7d0..05eef4d5045 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc @@ -139,8 +139,7 @@ void ModifyFunctionSignature( handle_new_size_vars(func.getArguments().drop_front(original_arg_count)); } func.setType(FunctionType::get( - new_input_types, - llvm::to_vector<8>(func.front().getTerminator()->getOperandTypes()), + new_input_types, func.front().getTerminator()->getOperandTypes(), func.getContext())); } @@ -163,7 +162,7 @@ LogicalResult HandleWhileOp( const llvm::SmallDenseMap& data_var_to_size_var, llvm::StringMap* decomposed_partitioned_call_callees) { - auto body = while_op.body_func(); + auto body = while_op.body_function(); llvm::SmallDenseMap body_map; auto find_arg_stack_type = [&](int64_t index) -> llvm::Optional { auto it = data_var_to_size_var.find(while_op.getOperand(index)); @@ -187,7 +186,7 @@ LogicalResult HandleWhileOp( return failure(); } // Cond should not change stacks in the arguments, so use an empty map. 
- auto cond = while_op.cond_func(); + auto cond = while_op.cond_function(); ModifyFunctionSignature(cond, nullptr, find_arg_stack_type); llvm::SmallDenseMap empty_map; if (failed(DecomposeStackOpsInternal(&cond.front(), module, &empty_map, @@ -231,8 +230,8 @@ LogicalResult HandleIfOp( const llvm::SmallDenseMap& data_var_to_size_var, llvm::StringMap* decomposed_partitioned_call_callees) { - auto then_func = if_op.then_func(); - auto else_func = if_op.else_func(); + auto then_func = if_op.then_function(); + auto else_func = if_op.else_function(); llvm::SmallDenseMap then_map; llvm::SmallDenseMap else_map; @@ -465,6 +464,38 @@ LogicalResult HandleStackPopV2Op( return success(); } +LogicalResult HandleRegionControlFlowOps( + Operation& op, ModuleOp module, + llvm::SmallDenseMap* data_var_to_size_var, + llvm::StringMap* + decomposed_partitioned_call_callees) { + for (OpOperand& operand : op.getOpOperands()) { + if (getElementTypeOrSelf(operand.get().getType()).isa()) { + return op.emitOpError() + << "found unexpected type " << operand.get().getType() + << " of operand #" << operand.getOperandNumber() + << ", resource type operands are expected to have been " + "canonicalized away for region based control flow ops"; + } + } + for (OpResult result : op.getResults()) { + if (getElementTypeOrSelf(result.getType()).isa()) { + return op.emitOpError() + << "found unexpected type " << result.getType() << " of result #" + << result.getResultNumber() + << ", resource type results are expected to have been " + "canonicalized away for region based control flow ops"; + } + } + for (Region& region : op.getRegions()) { + if (failed(DecomposeStackOpsInternal(®ion.front(), module, + data_var_to_size_var, + decomposed_partitioned_call_callees))) + return failure(); + } + return success(); +} + // Decomposes stack ops on a region and recursively decomposes called functions. // data_var_to_size_var: a mapping from stacks' buffer local variables to size // local variables. @@ -506,6 +537,13 @@ LogicalResult DecomposeStackOpsInternal( decomposed_partitioned_call_callees))) { return failure(); } + } else if (llvm::isa(op) || + llvm::isa(op) || + llvm::isa(op)) { + if (failed( + HandleRegionControlFlowOps(op, module, data_var_to_size_var, + decomposed_partitioned_call_callees))) + return failure(); } else if (auto pcall = llvm::dyn_cast(&op)) { if (!pcall.func()) { return pcall.emitOpError( diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc index b3a05c06a67..680d5334ceb 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc @@ -132,6 +132,15 @@ llvm::Optional> GetTensorArrayElementShape( return llvm::None; } return t; + } else if (auto scatter = + llvm::dyn_cast(user)) { + // TensorArrayScatter writes vector of tensors to TensorArray. We can + // deduce the shape of TensorArray by dropping the 0th dim of + // TensorArrayScatter `value`. 
+ auto t = scatter.value().getType().dyn_cast(); + if (!t || t.getShape().empty()) return llvm::None; + return RankedTensorType::get(t.getShape().drop_front(), + t.getElementType()); } return llvm::None; }); @@ -139,6 +148,26 @@ llvm::Optional> GetTensorArrayElementShape( return llvm::to_vector<8>(elem_type->getShape()); } +void ReplaceAllUsesWithCast(Value old_val, Value new_val) { + if (old_val.use_empty()) return; + auto cast_op = + OpBuilder(old_val.getDefiningOp()) + .create(old_val.getLoc(), old_val.getType(), new_val); + old_val.replaceAllUsesWith(cast_op); +} + +void ReplaceAllUsesExceptTerminator(Value old_val, Value new_val) { + if (old_val.getType() == new_val.getType()) { + old_val.replaceAllUsesWith(new_val); + return; + } + Operation* old_op = old_val.getDefiningOp(); + Operation* terminator_op = + old_op->getParentOfType().front().getTerminator(); + llvm::SmallPtrSet exceptions = {terminator_op}; + old_val.replaceAllUsesExcept(new_val, exceptions); +} + struct TensorArrayStats { // Whether a write op should accumulate with the old value. Set to true if // this is a gradient. @@ -195,7 +224,8 @@ LogicalResult HandleTensorArrayReadV3Op( auto index_reshape = cutil::ReshapeScalarToSizeType(builder, read.index(), read.getLoc()); auto elem = cutil::GetElement(index_reshape, buffer, builder, read.getLoc()); - read.value().replaceAllUsesWith(elem); + ReplaceAllUsesExceptTerminator(read.value(), elem); + ReplaceAllUsesWithCast(read.value(), elem); read.erase(); // The clear_after_read attribute does not mean setting the tensor to 0 after // read; instead it does not allow a second read before the next write. We @@ -260,7 +290,8 @@ LogicalResult HandleTensorArrayConcatV3Op( RankedTensorType::get(shape, buffer_type.getElementType())}, ArrayRef{buffer, cutil::GetR1Const(shape, builder, concat.getLoc())}); - concat.value().replaceAllUsesWith(buffer); + ReplaceAllUsesExceptTerminator(concat.value(), buffer); + ReplaceAllUsesWithCast(concat.value(), buffer); // Create the lengths as a list of the same value (element size). 
tensorflow::Tensor lengths_tensor(tensorflow::DT_INT64, @@ -389,7 +420,8 @@ LogicalResult HandleTensorArrayGatherV3Op( auto buffer = cutil::ReadLocalVariable(local_var, builder, gather.getLoc()); auto result = cutil::GatherElements(gather.indices(), buffer, builder, gather.getLoc()); - gather.value().replaceAllUsesWith(result); + ReplaceAllUsesExceptTerminator(gather.value(), result); + ReplaceAllUsesWithCast(gather.value(), result); gather.erase(); return success(); } @@ -443,12 +475,12 @@ llvm::SmallDenseMap> AccessedGradients( insert(grad.handle(), grad.source().str()); } else if (auto while_op = llvm::dyn_cast(&op)) { for (const auto& entry : AccessedGradients( - {while_op.body_func(), while_op.cond_func()}, module)) + {while_op.body_function(), while_op.cond_function()}, module)) for (const string& source : entry.getSecond()) insert(while_op.getOperand(entry.getFirst()), source); } else if (auto if_op = llvm::dyn_cast(&op)) { - for (const auto& entry : - AccessedGradients({if_op.then_func(), if_op.else_func()}, module)) + for (const auto& entry : AccessedGradients( + {if_op.then_function(), if_op.else_function()}, module)) for (const string& source : entry.getSecond()) insert(if_op.getOperand(entry.getFirst() + 1), source); } else if (auto call = llvm::dyn_cast(&op)) { @@ -509,8 +541,8 @@ LogicalResult HandleWhileOp(TF::WhileOp while_op, ModuleOp module, llvm::SmallDenseMap* stats, llvm::StringMap* decomposed_partitioned_call_callees) { - auto body = while_op.body_func(); - auto cond = while_op.cond_func(); + auto body = while_op.body_function(); + auto cond = while_op.cond_function(); auto grads = AccessedGradients({body, cond}, module); auto ta_arg_buffer_type = [&](int64_t index) -> Type { auto it = stats->find(while_op.getOperand(index)); @@ -570,6 +602,7 @@ LogicalResult HandleWhileOp(TF::WhileOp while_op, ModuleOp module, } stat.grads[source] = grad_var; operands.push_back(grad_var); + (*stats)[grad_var].accumulate_on_write = true; } } } @@ -592,8 +625,8 @@ LogicalResult HandleIfOp(TF::IfOp if_op, ModuleOp module, llvm::SmallDenseMap* stats, llvm::StringMap* decomposed_partitioned_call_callees) { - auto then_branch = if_op.then_func(); - auto else_branch = if_op.else_func(); + auto then_branch = if_op.then_function(); + auto else_branch = if_op.else_function(); auto grads = AccessedGradients({then_branch, else_branch}, module); auto ta_arg_buffer_type = [&](int64_t index) -> Type { auto it = stats->find(if_op.getOperand(index + 1)); @@ -636,6 +669,7 @@ LogicalResult HandleIfOp(TF::IfOp if_op, ModuleOp module, } stat.grads[source] = grad_var; operands.push_back(grad_var); + (*stats)[grad_var].accumulate_on_write = true; } } } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc index 9634e4a8be3..f7c0357a212 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc @@ -73,8 +73,7 @@ void UpdateFuncType(FuncOp func) { llvm::SmallVector arg_types; for (auto arg : func.getArguments()) arg_types.push_back(arg.getType()); func.setType(FunctionType::get( - arg_types, - llvm::to_vector<8>(func.front().getTerminator()->getOperandTypes()), + arg_types, func.front().getTerminator()->getOperandTypes(), func.getContext())); } @@ -125,26 +124,39 @@ LogicalResult DecomposeTensorListOpsInternal( Block*, ModuleOp, llvm::SmallDenseMap*, llvm::StringMap*); 
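The ReplaceAllUsesExceptTerminator / ReplaceAllUsesWithCast pair used above swaps the decomposed buffer in everywhere except the enclosing function's terminator, which keeps seeing the original type through an inserted cast, presumably so the function's return signature stays consistent until later passes update it. The following is a simplified, self-contained analog of that two-step rewrite over a toy use list; every name in it is made up for illustration and none of it is the MLIR API.

#include <string>
#include <vector>

// Toy use list: each entry says which op reads which value.
struct UseSketch {
  std::string owner;  // consuming op
  std::string value;  // value currently consumed
};

// Step 1: point every non-terminator user at the new value.
// Step 2: if the terminator still uses the old value, route it through a
// freshly created cast back to the old type instead of retyping it directly.
void ReplaceUsesKeepingTerminatorType(std::vector<UseSketch>& uses,
                                      const std::string& old_value,
                                      const std::string& new_value,
                                      const std::string& terminator,
                                      std::vector<std::string>& created_casts) {
  bool terminator_uses_old = false;
  for (UseSketch& use : uses) {
    if (use.value != old_value) continue;
    if (use.owner == terminator) {
      terminator_uses_old = true;
      continue;
    }
    use.value = new_value;
  }
  if (!terminator_uses_old) return;
  created_casts.push_back("cast_to_old_type(" + new_value + ")");
  for (UseSketch& use : uses)
    if (use.owner == terminator && use.value == old_value)
      use.value = created_casts.back();
}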
+// Adds the corresponding sizes of tensor list buffers in block's terminator +// to the list of return values. Returns the mapping from the buffer +// indices to the added size indices, which is a list of tuples +// (buffer_return_index, size_return_index, fixed_size). +template +llvm::SmallVector, 8> +AddTensorListSizesToTerminator( + Block& block, const llvm::SmallDenseMap& buffer_to_size) { + auto old_terminator = block.getTerminator(); + auto new_outputs = llvm::to_vector<8>(old_terminator->getOperands()); + llvm::SmallVector, 8> + output_buffer_to_size; + for (auto retval : llvm::enumerate(old_terminator->getOperands())) { + auto it = buffer_to_size.find(retval.value()); + if (it == buffer_to_size.end()) continue; + output_buffer_to_size.emplace_back(retval.index(), new_outputs.size(), + it->getSecond().fixed); + new_outputs.push_back(it->getSecond().size); + } + OpBuilder(old_terminator) + .create(old_terminator->getLoc(), new_outputs); + old_terminator->erase(); + return output_buffer_to_size; +} + // Adds the corresponding sizes of tensor list buffers in func's return values // to the list of return values. Returns the mapping from the buffer indices to // the added size indices, which is a list of tuples (buffer_return_index, // size_return_index, fixed_size). -llvm::SmallVector, 8> -AddTensorListSizesToReturn( +llvm::SmallVector, 8> ModifyFunctionReturn( FuncOp func, const llvm::SmallDenseMap& buffer_to_size) { - auto old_return = func.front().getTerminator(); - auto new_returns = llvm::to_vector<8>(old_return->getOperands()); - llvm::SmallVector, 8> - output_buffer_to_size; - for (auto retval : llvm::enumerate(old_return->getOperands())) { - auto it = buffer_to_size.find(retval.value()); - if (it == buffer_to_size.end()) continue; - output_buffer_to_size.emplace_back(retval.index(), new_returns.size(), - it->getSecond().fixed); - new_returns.push_back(it->getSecond().size); - } - OpBuilder(old_return).create(old_return->getLoc(), new_returns); - old_return->erase(); + auto output_buffer_to_size = + AddTensorListSizesToTerminator(func.front(), buffer_to_size); UpdateFuncType(func); return output_buffer_to_size; } @@ -155,7 +167,7 @@ LogicalResult HandleWhileOp( llvm::StringMap* decomposed_partitioned_call_callees) { // Rewrite body. - auto body = while_op.body_func(); + auto body = while_op.body_function(); llvm::SmallDenseMap body_map; auto find_arg_tensor_list_type = [&](int64_t index) -> llvm::Optional { auto it = buffer_to_size->find(while_op.getOperand(index)); @@ -173,10 +185,10 @@ LogicalResult HandleWhileOp( decomposed_partitioned_call_callees))) { return failure(); } - auto output_buffer_to_size = AddTensorListSizesToReturn(body, body_map); + auto output_buffer_to_size = ModifyFunctionReturn(body, body_map); // Rewrite cond. 
- auto cond = while_op.cond_func(); + auto cond = while_op.cond_function(); llvm::SmallDenseMap cond_map; ModifyFunctionSignature(cond, cutil::GetSizeType(builder), &cond_map, find_arg_tensor_list_type, arg_buffer_size_is_fixed); @@ -241,9 +253,9 @@ LogicalResult HandleCaseOrIfOp( const bool arg_no_changed = branch_maps.front().empty(); auto output_buffer_to_size = - AddTensorListSizesToReturn(branches.front(), branch_maps.front()); + ModifyFunctionReturn(branches.front(), branch_maps.front()); for (const auto& pair : llvm::drop_begin(llvm::zip(branches, branch_maps), 1)) - AddTensorListSizesToReturn(std::get<0>(pair), std::get<1>(pair)); + ModifyFunctionReturn(std::get<0>(pair), std::get<1>(pair)); if (output_buffer_to_size.empty() && arg_no_changed) return success(); @@ -267,6 +279,158 @@ LogicalResult HandleCaseOrIfOp( return success(); } +LogicalResult HandleWhileRegionOp( + TF::WhileRegionOp while_op, ModuleOp module, + llvm::SmallDenseMap* buffer_to_size, + llvm::StringMap* + decomposed_partitioned_call_callees) { + OpBuilder builder(while_op); + auto modify_region_arguments = [&](Region& region) { + int64_t original_arg_count = region.getNumArguments(); + for (int64_t i = 0; i < original_arg_count; ++i) { + auto operand = while_op.getOperand(i); + auto it = buffer_to_size->find(operand); + if (it == buffer_to_size->end()) continue; + auto buffer_type = it->getFirst().getType(); + region.getArgument(i).setType(buffer_type); + auto size_arg = region.addArgument(cutil::GetSizeType(builder)); + (*buffer_to_size)[region.getArgument(i)] = {size_arg, + it->getSecond().fixed}; + } + }; + + // Rewrite body. + Region& body_region = while_op.body(); + modify_region_arguments(body_region); + if (failed(DecomposeTensorListOpsInternal( + &body_region.front(), module, buffer_to_size, + decomposed_partitioned_call_callees))) { + return failure(); + } + auto output_buffer_to_size = AddTensorListSizesToTerminator( + body_region.front(), *buffer_to_size); + + // Rewrite cond. + Region& cond_region = while_op.cond(); + modify_region_arguments(cond_region); + if (failed(DecomposeTensorListOpsInternal( + &cond_region.front(), module, buffer_to_size, + decomposed_partitioned_call_callees))) { + return failure(); + } + + if (output_buffer_to_size.empty()) return success(); + + // Create the new while op. + auto new_while_operands = llvm::to_vector<8>(while_op.getOperands()); + for (int64_t i = 0; i < while_op.getNumResults(); ++i) { + auto it = buffer_to_size->find(while_op.getOperand(i)); + if (it == buffer_to_size->end()) continue; + new_while_operands.push_back(it->getSecond().size); + } + auto new_while = builder.create( + while_op.getLoc(), body_region.front().getTerminator()->getOperandTypes(), + new_while_operands, while_op.getAttrs()); + new_while.body().takeBody(body_region); + new_while.cond().takeBody(cond_region); + for (const auto& entry : output_buffer_to_size) { + (*buffer_to_size)[new_while.getResult(std::get<0>(entry))] = { + new_while.getResult(std::get<1>(entry)), std::get<2>(entry)}; + } + while_op.replaceAllUsesWith( + new_while.getResults().take_front(while_op.getNumResults())); + while_op.erase(); + return success(); +} + +LogicalResult HandleIfRegionOp( + TF::IfRegionOp if_op, ModuleOp module, + llvm::SmallDenseMap* buffer_to_size, + llvm::StringMap* + decomposed_partitioned_call_callees) { + // Rewrite the branches. 
+ Region& then_branch = if_op.then_branch(); + Region& else_branch = if_op.else_branch(); + if (failed(DecomposeTensorListOpsInternal( + &then_branch.front(), module, buffer_to_size, + decomposed_partitioned_call_callees))) + return failure(); + if (failed(DecomposeTensorListOpsInternal( + &else_branch.front(), module, buffer_to_size, + decomposed_partitioned_call_callees))) + return failure(); + + auto output_buffer_to_size = AddTensorListSizesToTerminator( + then_branch.front(), *buffer_to_size); + AddTensorListSizesToTerminator(else_branch.front(), + *buffer_to_size); + + if (output_buffer_to_size.empty()) return success(); + + // Recreate the op. + auto new_op = OpBuilder(if_op).create( + if_op.getLoc(), then_branch.front().getTerminator()->getOperandTypes(), + if_op.getOperand(), if_op.getAttrs()); + for (const auto& entry : output_buffer_to_size) { + (*buffer_to_size)[new_op.getResult(std::get<0>(entry))] = { + new_op.getResult(std::get<1>(entry)), std::get<2>(entry)}; + } + + new_op.then_branch().takeBody(if_op.then_branch()); + new_op.else_branch().takeBody(if_op.else_branch()); + + if_op.replaceAllUsesWith( + new_op.getResults().take_front(if_op.getNumResults())); + if_op.erase(); + return success(); +} + +LogicalResult HandleCaseRegionOp( + TF::CaseRegionOp case_op, ModuleOp module, + llvm::SmallDenseMap* buffer_to_size, + llvm::StringMap* + decomposed_partitioned_call_callees) { + // Rewrite the branches. + RegionRange branches = case_op.getRegions(); + + for (Region* branch : branches) { + if (failed(DecomposeTensorListOpsInternal( + &branch->front(), module, buffer_to_size, + decomposed_partitioned_call_callees))) + return failure(); + } + + // Get the output buffer index to size index mapping one of the branches. It + // should be same for all the branches so we only get it for the first branch. + Region* first_branch = branches.front(); + auto output_buffer_to_size = AddTensorListSizesToTerminator( + first_branch->front(), *buffer_to_size); + for (Region* branch : branches.drop_front()) { + AddTensorListSizesToTerminator(branch->front(), + *buffer_to_size); + } + + if (output_buffer_to_size.empty()) return success(); + + // Recreate the op. + auto new_op = OpBuilder(case_op).create( + case_op.getLoc(), + first_branch->front().getTerminator()->getOperandTypes(), + case_op.getOperand(), case_op.getAttrs(), case_op.getNumRegions()); + for (const auto& entry : output_buffer_to_size) { + (*buffer_to_size)[new_op.getResult(std::get<0>(entry))] = { + new_op.getResult(std::get<1>(entry)), std::get<2>(entry)}; + } + + for (auto pair : llvm::zip(new_op.getRegions(), case_op.getRegions())) { + std::get<0>(pair)->takeBody(*std::get<1>(pair)); + } + case_op.replaceAllUsesWith( + new_op.getResults().take_front(case_op.getNumResults())); + case_op.erase(); + return success(); +} + template LogicalResult HandlePartitionedCallOp( CallOp call, FuncOp callee, ModuleOp module, @@ -337,7 +501,7 @@ LogicalResult HandlePartitionedCallOp( return failure(); } info.buffer_ret_to_size_ret = - AddTensorListSizesToReturn(lowered_callee, callee_map); + ModifyFunctionReturn(lowered_callee, callee_map); info.decomposed_callee = lowered_callee; if (args_no_changed && info.buffer_ret_to_size_ret.empty()) { // Signature is not modified. We do not need to keep two copies. 
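ModifyFunctionReturn and the AddTensorListSizesToTerminator template above return (buffer_return_index, size_return_index, fixed) tuples, which the callers then use to associate each extra size result of the recreated While/If/Case region op with its buffer result. Here is a small self-contained sketch of that bookkeeping, with strings standing in for terminator operands; the names are illustrative only.

#include <cstdint>
#include <cstdio>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>

// A tracked tensor-list buffer: its size value and whether the size is fixed.
struct SizeInfo {
  std::string size;
  bool fixed = false;
};

// For every returned value that is a known buffer, append its size to the
// return list and record (buffer_return_index, size_return_index, fixed).
std::vector<std::tuple<int64_t, int64_t, bool>> AppendBufferSizes(
    std::vector<std::string>& returns,
    const std::unordered_map<std::string, SizeInfo>& buffer_to_size) {
  std::vector<std::tuple<int64_t, int64_t, bool>> mapping;
  const size_t original_count = returns.size();
  for (size_t i = 0; i < original_count; ++i) {
    auto it = buffer_to_size.find(returns[i]);
    if (it == buffer_to_size.end()) continue;
    mapping.emplace_back(static_cast<int64_t>(i),
                         static_cast<int64_t>(returns.size()),
                         it->second.fixed);
    returns.push_back(it->second.size);
  }
  return mapping;
}

int main() {
  std::vector<std::string> returns = {"%buffer", "%flag"};
  std::unordered_map<std::string, SizeInfo> buffer_to_size = {
      {"%buffer", {"%buffer_size", /*fixed=*/true}}};
  auto mapping = AppendBufferSizes(returns, buffer_to_size);
  for (const auto& [buf_idx, size_idx, fixed] : mapping)
    std::printf("result #%lld carries its size at result #%lld (fixed=%d)\n",
                static_cast<long long>(buf_idx),
                static_cast<long long>(size_idx), static_cast<int>(fixed));
  return 0;
}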
@@ -701,17 +865,14 @@ LogicalResult DecomposeTensorListOpsInternal( return failure(); } } else if (auto if_op = llvm::dyn_cast(&op)) { - if (failed(HandleCaseOrIfOp(if_op, {if_op.then_func(), if_op.else_func()}, - module, buffer_to_size, - decomposed_partitioned_call_callees))) { + if (failed(HandleCaseOrIfOp( + if_op, {if_op.then_function(), if_op.else_function()}, module, + buffer_to_size, decomposed_partitioned_call_callees))) { return failure(); } } else if (auto case_op = llvm::dyn_cast(&op)) { SmallVector branches; - for (auto branch_symbol : case_op.branches()) { - branches.push_back(module.lookupSymbol( - branch_symbol.cast())); - } + case_op.get_branch_functions(branches); if (failed(HandleCaseOrIfOp(case_op, branches, module, buffer_to_size, decomposed_partitioned_call_callees))) { return failure(); @@ -734,6 +895,21 @@ LogicalResult DecomposeTensorListOpsInternal( decomposed_partitioned_call_callees))) { return failure(); } + } else if (auto while_op = llvm::dyn_cast(&op)) { + if (failed(HandleWhileRegionOp(while_op, module, buffer_to_size, + decomposed_partitioned_call_callees))) { + return failure(); + } + } else if (auto if_op = llvm::dyn_cast(&op)) { + if (failed(HandleIfRegionOp(if_op, module, buffer_to_size, + decomposed_partitioned_call_callees))) { + return failure(); + } + } else if (auto case_op = llvm::dyn_cast(&op)) { + if (failed(HandleCaseRegionOp(case_op, module, buffer_to_size, + decomposed_partitioned_call_callees))) { + return failure(); + } } } return success(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.cc index 786c4b74b34..f2321df9823 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.cc @@ -58,7 +58,7 @@ struct FuseParallelMapAndBatch : public OpRewritePattern { void PopulateTFDataOptimizationPatterns(MLIRContext *context, OwningRewritePatternList *patterns) { patterns->insert(context); - populateWithGenerated(context, patterns); + populateWithGenerated(context, *patterns); } } // namespace TF diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc index 1e4caaf5dd6..52ac87ecf71 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "mlir/IR/Identifier.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" @@ -43,6 +44,10 @@ namespace tensorflow { class GraphOptPass : public mlir::PassWrapper> { + void getDependentDialects(mlir::DialectRegistry& registry) const override { + mlir::RegisterAllTensorFlowDialects(registry); + } + public: explicit GraphOptPass(std::vector passes) : passes_(std::move(passes)) {} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_cleanup_attributes.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_cleanup_attributes.cc new file mode 100644 index 00000000000..93098acdc9d --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_cleanup_attributes.cc @@ -0,0 +1,60 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" + +// This pass eliminate `_tpu_replicate` and `device` attribute on operations +// that are contained in a tf_device.cluster op. + +namespace mlir { +namespace TFTPU { + +namespace { + +constexpr char kTPUReplicateAttr[] = "_tpu_replicate"; +constexpr char kDeviceAttr[] = "device"; + +class TPUCleanupClusterAttributesPass + : public PassWrapper> { + public: + void runOnOperation() override { + getOperation().walk([](tf_device::ClusterOp cluster) { + cluster.walk([](Operation *op) { + if (isa(op)) return; + for (StringRef attr : {kTPUReplicateAttr, kDeviceAttr}) + op->removeAttr(attr); + }); + }); + } +}; + +PassRegistration pass( + "tf-tpu-cleanup-cluster-attributes", + "Eliminate _tpu_replicate and other attributes from ops in a cluster"); + +} // namespace + +std::unique_ptr> +CreateTPUClusterCleanupAttributesPass() { + return std::make_unique(); +} + +} // namespace TFTPU +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc index f5bdd08d980..46bc094e5ed 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc @@ -71,13 +71,19 @@ constexpr char kBadTPUReplicateAttrMsg[] = using MetadataMap = llvm::SmallDenseMap; +// A set of operations in a cluster. 
+using ClusterOps = llvm::SmallSetVector; + // Mapping for `_tpu_replicate` attribute to ops of a cluster. -using ClusterMap = llvm::SmallDenseMap, 8>; +using ClusterMap = llvm::SmallDenseMap; struct TPUClusterFormation : public TF::PerFunctionAggregateAnalysisConsumerPass< TPUClusterFormation, TF::ResourceAliasAnalysis> { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + void runOnFunction( FuncOp func, const TF::ResourceAliasAnalysis::Info& resource_alias_analysis); @@ -87,42 +93,40 @@ struct TPUClusterFormation // attribute to its attributes and removes the ops. If multiple // TPUReplicateMetadata ops have the same `_tpu_replicate` attribute, an error // will be returned. -LogicalResult CollectMetadata(Operation* op, MetadataMap* metadata_map) { - auto result = - op->walk([&](TF::TPUReplicateMetadataOp metadata_op) -> WalkResult { - MutableDictionaryAttr attrs = metadata_op.getAttrs(); +LogicalResult CollectMetadata(Block* block, MetadataMap* metadata_map) { + // Just look at top-level operations in the block (not nested ones) + for (Operation& op : llvm::make_early_inc_range(*block)) { + auto metadata_op = dyn_cast(op); + if (!metadata_op) continue; - // Missing or bad `_tpu_replicate` attribute. - auto tpu_replicate_attr = attrs.get(kTPUReplicateAttr); - if (!tpu_replicate_attr) - return metadata_op.emitError() << kBadTPUReplicateAttrMsg; + MutableDictionaryAttr attrs = metadata_op.getAttrs(); - auto tpu_replicate_attr_str = tpu_replicate_attr.dyn_cast(); - if (!tpu_replicate_attr_str || - tpu_replicate_attr_str.getValue().empty()) - return metadata_op.emitError() << kBadTPUReplicateAttrMsg; + // Missing or bad `_tpu_replicate` attribute. + auto tpu_replicate_attr = attrs.get(kTPUReplicateAttr); + if (!tpu_replicate_attr) + return metadata_op.emitError() << kBadTPUReplicateAttrMsg; - // Remove `name` attribute. - attrs.remove(Identifier::get(kNameAttr, metadata_op.getContext())); + auto tpu_replicate_attr_str = tpu_replicate_attr.dyn_cast(); + if (!tpu_replicate_attr_str || tpu_replicate_attr_str.getValue().empty()) + return metadata_op.emitError() << kBadTPUReplicateAttrMsg; - auto it = metadata_map->try_emplace(tpu_replicate_attr_str.getValue(), - std::move(attrs)); + // Remove `name` attribute. + attrs.remove(Identifier::get(kNameAttr, metadata_op.getContext())); - // There are multiple TPUReplicateMetadata ops with the same - // `_tpu_replicate` attribute. - if (!it.second) { - return metadata_op.emitError() - << "multiple TPUReplicateMetadata ops with the same '" - << kTPUReplicateAttr << "' attribute '" - << tpu_replicate_attr_str.getValue() << "' found"; - } + auto it = metadata_map->try_emplace(tpu_replicate_attr_str.getValue(), + std::move(attrs)); - metadata_op.erase(); - return WalkResult::advance(); - }); - - // Return failure if the walk was interrupted. - return failure(result.wasInterrupted()); + // There are multiple TPUReplicateMetadata ops with the same + // `_tpu_replicate` attribute. + if (!it.second) { + return metadata_op.emitError() + << "multiple TPUReplicateMetadata ops with the same '" + << kTPUReplicateAttr << "' attribute '" + << tpu_replicate_attr_str.getValue() << "' found"; + } + metadata_op.erase(); + } + return success(); } // Collects and clusters ops with the same `_tpu_replicate` attribute. 
This will @@ -150,12 +154,12 @@ void CollectResourceIdsFromOp( op.walk([&](Operation* inner_op) { for (Value operand : TF::filter_resources(inner_op->getOperands())) { if (resource_alias_analysis.IsUnknownResource(operand)) continue; - auto ids = resource_alias_analysis.GetResourceUniqueIds(operand); + const auto& ids = resource_alias_analysis.GetResourceUniqueIds(operand); observed_resource_ids.insert(ids.begin(), ids.end()); } for (Value result : TF::filter_resources(inner_op->getResults())) { if (resource_alias_analysis.IsUnknownResource(result)) continue; - auto ids = resource_alias_analysis.GetResourceUniqueIds(result); + const auto& ids = resource_alias_analysis.GetResourceUniqueIds(result); observed_resource_ids.insert(ids.begin(), ids.end()); } }); @@ -164,13 +168,13 @@ void CollectResourceIdsFromOp( // Checks if an op should be moved after a cluster. There may be users of a // cluster interleaved among the cluster ops. bool ShouldMoveOpAfterCluster( - Block* block, Operation* op, - const llvm::SmallSetVector& cluster_ops, + Block* block, Operation* op, const ClusterOps& cluster_ops, const llvm::SmallSetVector& preceding_users, const TF::ResourceAliasAnalysis::Info& resource_alias_analysis, const llvm::SmallDenseSet& observed_resource_ids) { - auto result = op->walk([&](Operation* op) { - for (Value operand : op->getOperands()) { + const bool is_replicate = llvm::isa(op); + auto result = op->walk([&](Operation* inner_op) { + for (Value operand : inner_op->getOperands()) { Operation* def = operand.getDefiningOp(); // Operands may not have a defining op (BlockArgument) or is from a // different block. @@ -183,8 +187,13 @@ bool ShouldMoveOpAfterCluster( } } + // Don't visit replicate op inner op operands as new resource + // values/arguments may have been created but are not known in + // `resource_alias_analysis`. + if (is_replicate && inner_op != op) return WalkResult::advance(); + // Check for uses of any resource in or after cluster. - for (Value operand : TF::filter_resources(op->getOperands())) { + for (Value operand : TF::filter_resources(inner_op->getOperands())) { if (resource_alias_analysis.IsUnknownResource(operand)) continue; auto ids = resource_alias_analysis.GetResourceUniqueIds(operand); for (const auto& id : ids) @@ -204,13 +213,14 @@ bool ShouldMoveOpAfterCluster( // TODO(lyandy): Extend this to handle all side effecting ops while handling // transitive data dependencies. llvm::SmallSetVector CollectClusterPrecedingUsers( - Block* block, const llvm::SmallSetVector& cluster_ops, + Block* block, const ClusterOps& cluster_ops, const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) { llvm::SmallSetVector preceding_users; llvm::SmallDenseSet observed_resource_ids; - for (Operation& op : llvm::make_range(Block::iterator(cluster_ops.front()), - Block::iterator(cluster_ops.back()))) { + auto front = Block::iterator(cluster_ops.front()); + auto back = Block::iterator(cluster_ops.back()); + for (Operation& op : llvm::make_range(front, back)) { if (cluster_ops.contains(&op)) { CollectResourceIdsFromOp(op, resource_alias_analysis, observed_resource_ids); @@ -232,7 +242,7 @@ llvm::SmallSetVector CollectClusterPrecedingUsers( // outside of the cluster (i.e. results of ops in the cluster are only consumed // by other ops in the cluster) are pruned. 
llvm::SmallVector CollectClusterResults( - Block* block, const llvm::SmallSetVector& cluster_ops) { + Block* block, const ClusterOps& cluster_ops) { llvm::SmallVector results; for (Operation* op : cluster_ops) { @@ -251,61 +261,52 @@ llvm::SmallVector CollectClusterResults( } // Creates a `tf_device.cluster` to wrap cluster ops. -tf_device::ClusterOp CreateOpForCluster(Operation* last_cluster_op, - llvm::ArrayRef results) { +tf_device::ClusterOp CreateClusterOp( + Block* block, const ClusterOps& cluster_ops, llvm::ArrayRef results, + llvm::ArrayRef preceding_users) { // `tf_device.cluster` will be placed at where the last op of the cluster is. + Operation* last_cluster_op = cluster_ops.back(); OpBuilder builder(last_cluster_op); llvm::SmallVector result_types; for (Value result : results) result_types.push_back(result.getType()); - auto cluster = builder.create(last_cluster_op->getLoc(), result_types); - cluster.body().push_back(new Block); + Block* body = new Block; + cluster.body().push_back(body); + + // Move cluster ops to the cluster body. Also remove `_tpu_replicate` and + // `device` attribute from ops in the cluster as that information will be + // present in the `tf_device.cluster`. Do this for all ops including nested + // ops. + for (Operation* cluster_op : cluster_ops) { + cluster_op->moveBefore(body, body->end()); + cluster_op->walk([&](Operation* inner_op) { + inner_op->removeAttr(kTPUReplicateAttr); + inner_op->removeAttr(kDeviceAttr); + }); + } // Add terminator. - builder.setInsertionPointToEnd(&cluster.GetBody()); + builder.setInsertionPointToEnd(body); builder.create(last_cluster_op->getLoc(), results); - return cluster; -} - -// Moves cluster ops to associated `tf_device.cluster` body. -void MoveClusterOpsToCluster( - tf_device::ClusterOp cluster, - const llvm::SmallSetVector& cluster_ops) { - MLIRContext* context = cluster.getContext(); - Operation* terminator = cluster.GetBody().getTerminator(); - - for (Operation* cluster_op : cluster_ops) { - // Remove `_tpu_replicate` and `device` attribute from ops in the cluster - // as that information will be present in the `tf_device.cluster`. - cluster_op->removeAttr(Identifier::get(kTPUReplicateAttr, context)); - cluster_op->removeAttr(Identifier::get(kDeviceAttr, context)); - cluster_op->moveBefore(terminator); - } -} - -// Replaces uses of cluster ops results outside of cluster with the associated -// `tf_device.cluster` results. -void UpdateClusterResultExternalUses(tf_device::ClusterOp cluster, - llvm::ArrayRef results) { - Block& cluster_block = cluster.GetBody(); + // Replaces uses of cluster ops results outside of cluster with the associated + // `tf_device.cluster` results. for (auto ret_vals : llvm::zip(results, cluster.getResults())) { Value old_ret = std::get<0>(ret_vals); Value new_ret = std::get<1>(ret_vals); - for (auto& use : llvm::make_early_inc_range(old_ret.getUses())) - if (!cluster_block.findAncestorOpInBlock(*use.getOwner())) - use.set(new_ret); + for (auto& use : llvm::make_early_inc_range(old_ret.getUses())) { + Operation* user = use.getOwner(); + if (!body->findAncestorOpInBlock(*user)) use.set(new_ret); + } } -} -// Moves users of cluster that are before the cluster to after the cluster. -void MovePrecedingClusterUsers(tf_device::ClusterOp cluster, - llvm::ArrayRef preceding_users) { + // Move users of cluster that are before the cluster to after the cluster. 
Operation* op_after_cluster = cluster.getOperation()->getNextNode(); for (Operation* user : preceding_users) user->moveBefore(op_after_cluster); + return cluster; } // Sorts `tf.TPUReplicatedInput` ops by `index` attribute. Ops with an `index` @@ -318,8 +319,7 @@ LogicalResult SortTPUReplicatedInputsByIndex( llvm::SmallVectorImpl* sorted_inputs) { llvm::SmallDenseSet unique_indices; for (Operation* input : inputs) { - int64_t index = - llvm::cast(input).index().getSExtValue(); + int64_t index = llvm::cast(input).index(); if (index < -1) return input->emitOpError() << "requires index to be at least -1, but got " << index; @@ -338,10 +338,8 @@ LogicalResult SortTPUReplicatedInputsByIndex( std::stable_sort( sorted_inputs->begin(), sorted_inputs->end(), [](Operation* l, Operation* r) { - int64_t l_index = - llvm::cast(l).index().getSExtValue(); - int64_t r_index = - llvm::cast(r).index().getSExtValue(); + int64_t l_index = llvm::cast(l).index(); + int64_t r_index = llvm::cast(r).index(); if (l_index == -1 && r_index != -1) return false; if (r_index == -1 && l_index != -1) return true; return l_index < r_index; @@ -385,8 +383,7 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) { // Check if number of operands of each used TPUReplicatedInput op matches // `num_replicas` or 1. Collect all their operands and associated type for // creating the replicate op. - llvm::SmallVector, 8> - replicated_inputs; + llvm::SmallVector, 8> replicated_inputs; llvm::SmallVector packed_inputs; for (auto& pos_and_input : llvm::enumerate(replicated_input_ops)) { auto input = pos_and_input.value(); @@ -397,8 +394,7 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) { return input->emitOpError() << "requires " << num_inputs << " operands"; auto tpu_replicated_input = llvm::cast(input); - int64_t tpu_replicated_input_index = - tpu_replicated_input.index().getSExtValue(); + int64_t tpu_replicated_input_index = tpu_replicated_input.index(); if (is_packed) { packed_inputs.push_back(input->getOperand(0)); packed_input_indices.push_back(tpu_replicated_input_index); @@ -434,20 +430,24 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) { for (auto result_and_idx : llvm::enumerate(cluster.getResults())) { Value result = result_and_idx.value(); int idx = result_and_idx.index(); - for (auto& use : result.getUses()) { - Operation* def = use.getOwner(); - if (!def || !llvm::isa(def)) - return cluster.emitError() - << "requires output of " << cluster.getOperationName() - << " to lead to a 'tf.TPUReplicatedOutput' op"; + auto replicate_outputs = llvm::make_range( + std::next(replicate_op.result_begin(), idx * num_replicas), + std::next(replicate_op.result_begin(), (idx + 1) * num_replicas)); - const int def_NumResults = def->getNumResults(); - if (def_NumResults != num_replicas) + for (auto& use : llvm::make_early_inc_range(result.getUses())) { + Operation* def = use.getOwner(); + if (!llvm::isa(def)) { + // If user is not a `tf.TPUReplicatedOutput`, simply forward the first + // replica output. Certain Graphs under V1 create `tf.Identity` users of + // replicated ops to pin the TPU computation for execution. 
+ use.set(*replicate_outputs.begin()); + continue; + } + + const int def_num_results = def->getNumResults(); + if (def_num_results != num_replicas) return def->emitOpError() << "requires " << num_replicas << " results"; - auto replicate_outputs = llvm::make_range( - std::next(replicate_op.result_begin(), idx * num_replicas), - std::next(replicate_op.result_begin(), (idx + 1) * num_replicas)); def->replaceAllUsesWith(replicate_outputs); } } @@ -490,10 +490,29 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) { // attribute `num_replicas` is greater than 1. // 9. Copy over TPUReplicateMetadata attributes to `tf_device.cluster`. LogicalResult FormClustersInBlock( - Block* block, const MetadataMap& metadata_map, + Block* block, const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) { + MetadataMap metadata_map; + LogicalResult result = CollectMetadata(block, &metadata_map); + if (failed(result)) return result; + + // If there is no TPUReplicateMetadata op in this block, process blocks in + // regions attached to the op's in the block. + if (metadata_map.empty()) { + for (Operation& op : *block) { + for (Region& region : op.getRegions()) { + if (!llvm::hasSingleElement(region)) + return op.emitOpError("Expected single block region"); + if (failed( + FormClustersInBlock(®ion.front(), resource_alias_analysis))) + return failure(); + } + } + return success(); + } + ClusterMap clusters; - LogicalResult result = CollectAndGroupClusterOps(block, &clusters); + result = CollectAndGroupClusterOps(block, &clusters); if (failed(result)) return result; for (const auto& cluster_metadata_and_ops : clusters) { @@ -518,14 +537,8 @@ LogicalResult FormClustersInBlock( llvm::SmallVector results = CollectClusterResults(block, cluster_ops); - tf_device::ClusterOp cluster = - CreateOpForCluster(cluster_ops.back(), results); - - MoveClusterOpsToCluster(cluster, cluster_ops); - - UpdateClusterResultExternalUses(cluster, results); - - MovePrecedingClusterUsers(cluster, preceding_users.getArrayRef()); + tf_device::ClusterOp cluster = CreateClusterOp( + block, cluster_ops, results, preceding_users.getArrayRef()); auto num_replicas = cluster_metadata->getSecond().get(kNumReplicasAttr); if (!num_replicas || !num_replicas.isa()) @@ -548,13 +561,13 @@ LogicalResult FormClustersInBlock( void TPUClusterFormation::runOnFunction( FuncOp func, const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) { - MetadataMap metadata_map; - if (failed(CollectMetadata(func, &metadata_map))) return signalPassFailure(); + if (!llvm::hasSingleElement(func)) { + func.emitOpError("Expecting a single block function"); + return signalPassFailure(); + } - for (Block& block : func) - if (failed( - FormClustersInBlock(&block, metadata_map, resource_alias_analysis))) - return signalPassFailure(); + if (failed(FormClustersInBlock(&func.front(), resource_alias_analysis))) + return signalPassFailure(); // Remove TPUReplicatedInput and TPUReplicatedOutput nodes. auto remove_result = func.walk([&](Operation* op) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_colocate_composite_resource_ops.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_colocate_composite_resource_ops.cc new file mode 100644 index 00000000000..b4889f6e52c --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_colocate_composite_resource_ops.cc @@ -0,0 +1,137 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h" + +namespace mlir { +namespace TFTPU { +namespace { + +// Pass that co-locates resource ops that use composite device resources +// (packed tensors) with the underlying physical TPU device. +struct TPUColocateCompositeResourceOps + : public PassWrapper { + void runOnFunction() override; +}; + +// Wraps single op in `tf_device.launch` for explicit device assignment. +void WrapOpInLaunch(OpBuilder* builder, Location loc, Operation* op, + llvm::StringRef device) { + builder->setInsertionPoint(op); + auto launch = builder->create( + loc, builder->getStringAttr(device), op->getResultTypes()); + launch.body().push_back(new Block); + op->replaceAllUsesWith(launch); + + builder->setInsertionPointToEnd(&launch.GetBody()); + builder->create(loc, op->getResults()); + + // Move op inside cluster. + op->moveBefore(launch.GetBody().getTerminator()); +} + +llvm::SmallVector GetResourceOpsUsingCompositeArgsInReplicate( + tf_device::ReplicateOp replicate) { + llvm::SmallVector resource_users; + const auto add_resource_op_to_list = [&resource_users](Operation* op) { + if (!llvm::isa(op)) return; + + resource_users.emplace_back(op); + }; + + llvm::SmallVector resource_users_to_visit; + for (auto composite_arguments : replicate.GetPackedBlockArguments()) { + for (auto resource_user : composite_arguments.getUsers()) + resource_users_to_visit.emplace_back(resource_user); + } + + while (!resource_users_to_visit.empty()) { + llvm::SmallVector new_resource_users; + + for (auto resource_user : resource_users_to_visit) { + add_resource_op_to_list(resource_user); + + // Account for pass-through identity ops. 
+ if (auto pass_through_identity = + llvm::dyn_cast(resource_user)) { + for (auto identity_user : pass_through_identity.output().getUsers()) { + new_resource_users.emplace_back(identity_user); + } + } + } + resource_users_to_visit.swap(new_resource_users); + } + + return resource_users; +} + +void ColocateCompositeResourceOpsInReplicate( + tf_device::ReplicateOp replicate_op, OpBuilder* builder) { + auto devices = replicate_op.devices(); + if (!devices) return; + if (!devices.getValue().get(tensorflow::GetDeviceAliasForLogicalCore(0))) + return; + + const auto composite_resource_users = + GetResourceOpsUsingCompositeArgsInReplicate(replicate_op); + for (auto resource_user : composite_resource_users) { + WrapOpInLaunch(builder, resource_user->getLoc(), resource_user, + tensorflow::GetDeviceAliasForLogicalCore(0)); + } +} + +void TPUColocateCompositeResourceOps::runOnFunction() { + // Find all the executes first, since we will mutate the nodes around each + // execute in the same tf_device.replicate op. + llvm::SmallVector execute_launches; + getFunction().walk([&](tf_device::LaunchOp op) { + if (op.WrapsSingleOp() && + llvm::isa( + op.GetBody().front())) + execute_launches.push_back(op); + }); + + OpBuilder builder(&getContext()); + for (auto execute_launch : execute_launches) { + auto replicate = execute_launch.getParentOfType(); + if (!replicate) continue; + + ColocateCompositeResourceOpsInReplicate(replicate, &builder); + } +} + +} // namespace + +std::unique_ptr> CreateTPUColocateCompositeResourceOps() { + return std::make_unique(); +} + +static PassRegistration pass( + "tf-tpu-colocate-composite-resource-ops", + "Colocate resource with composite device assignment to TPU device."); + +} // namespace TFTPU +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc index 41362465cd9..59f36e03fbb 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc @@ -185,7 +185,7 @@ bool HandleReplicatedInputs( const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) { // We need to know the devices to copy to. if (!replicate.devices()) return false; - int64_t num_replicas = replicate.n().getZExtValue(); + int64_t num_replicas = replicate.n(); auto inputs = replicate.getOperands() .drop_front(replicate_arg_index * num_replicas) .take_front(num_replicas); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc index fed4002bfcf..6e106b278fe 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc @@ -23,10 +23,10 @@ limitations under the License. 
#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" -#include "llvm/Support/FormatVariadic.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project @@ -34,7 +34,9 @@ limitations under the License. #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" @@ -113,12 +115,23 @@ tf_device::LaunchOp CreateLaunchForBlock(OpBuilder* builder, Operation* op, return launch; } +// Checks if an operation is a supported TPU embedding op. +bool IsEmbeddingOp(Operation* op) { + return isa(op); +} + // Returns a set of ops that are outside compiled and can be extracted to before // the TPU computation. These ops are either connected to the inputs of the TPU // computation or other ops that can be extracted, and have no operands from // other ops in the TPU computation that cannot be extracted. llvm::SmallVector FindOutsideCompiledOpsAtHead( + const TF::SideEffectAnalysis& side_effect_analysis, tf_device::ClusterOp cluster) { + const auto& analysis = side_effect_analysis.GetAnalysisForFunc( + cluster.getParentOfType()); Region* cluster_region = &cluster.body(); llvm::SmallSetVector head_outside_compiled_ops; @@ -127,6 +140,24 @@ llvm::SmallVector FindOutsideCompiledOpsAtHead( if (!HasOutsideCompilationAttribute(&cluster_op)) continue; // An outside compiled op can be extracted if its operands are not from // other ops in the cluster that cannot be extracted. + + // Check if the side effecting op right before this side effecting op, if + // it is side effecting, can be head extracted. Because of op ordering due + // to side effects, if this is not true, this op cannot be head extracted. + // TODO(lyandy): Remove special handling of embedding ops. Currently the IR + // is in a topological sort order and depending on that ordering, embedding + // ops may prevent other ops from being head extracted. + auto predecessors = analysis.DirectControlPredecessors(&cluster_op); + if (!predecessors.empty() && !IsEmbeddingOp(&cluster_op)) { + bool skip = false; + for (Operation* predecessor : llvm::reverse(predecessors)) { + if (IsEmbeddingOp(predecessor)) continue; + skip = !head_outside_compiled_ops.contains(predecessor); + break; + } + if (skip) continue; + } + auto walk_result = cluster_op.walk([&](Operation* op) { for (Value operand : op->getOperands()) { Operation* operand_op = GetOpOfValue(operand); @@ -168,11 +199,11 @@ void CreateHeadComputation(OpBuilder* builder, tf_device::ClusterOp cluster, // Extracts and move outside compiled ops that have no dependencies in the // cluster to before the cluster. 
mlir::LogicalResult LiftHeadOutsideCompiledOps( - OpBuilder* builder, const mlir::TF::RuntimeDevices& devices, - tf_device::ClusterOp cluster, std::string* host_device, - bool* cluster_updated) { + OpBuilder* builder, const TF::SideEffectAnalysis& side_effect_analysis, + const mlir::TF::RuntimeDevices& devices, tf_device::ClusterOp cluster, + std::string* host_device, bool* cluster_updated) { llvm::SmallVector head_outside_compiled_ops = - FindOutsideCompiledOpsAtHead(cluster); + FindOutsideCompiledOpsAtHead(side_effect_analysis, cluster); if (head_outside_compiled_ops.empty()) return success(); if (failed(tensorflow::GetHostDeviceOutsideComputation(devices, cluster, host_device))) @@ -191,9 +222,12 @@ mlir::LogicalResult LiftHeadOutsideCompiledOps( // TPU computation or other ops that can be extracted, and have no results used // by other ops in the TPU computation that cannot be extracted. void FindOutsideCompiledOpsAtTailAndClusterResults( + const TF::SideEffectAnalysis& side_effect_analysis, tf_device::ClusterOp cluster, llvm::SmallVectorImpl* tail_outside_compiled_ops, llvm::SmallVectorImpl* cluster_results) { + const auto& analysis = side_effect_analysis.GetAnalysisForFunc( + cluster.getParentOfType()); Region* cluster_region = &cluster.body(); llvm::SmallSetVector tail_outside_compiled_ops_set; Operation* terminator = cluster.GetBody().getTerminator(); @@ -205,6 +239,24 @@ void FindOutsideCompiledOpsAtTailAndClusterResults( for (Operation& cluster_op : cluster_ops) { if (!HasOutsideCompilationAttribute(&cluster_op)) continue; + // Check if the side effecting op right after this side effecting op, if + // it is side effecting, can be tail extracted. Because of op ordering due + // to side effects, if this is not true, this op cannot be tail extracted. + // TODO(lyandy): Remove special handling of embedding ops. Currently the IR + // is in a topological sort order and depending on that ordering, embedding + // ops may prevent other ops from being tail extracted. + auto successors = analysis.DirectControlSuccessors( + &cluster_op, [&terminator](Operation* op) { return op != terminator; }); + if (!successors.empty() && !IsEmbeddingOp(&cluster_op)) { + bool skip = false; + for (Operation* successor : successors) { + if (IsEmbeddingOp(successor)) continue; + skip = !tail_outside_compiled_ops_set.contains(successor); + break; + } + if (skip) continue; + } + llvm::SmallVector results_to_forward; bool can_be_extracted = llvm::all_of(cluster_op.getUsers(), [&](Operation* op) { @@ -293,13 +345,14 @@ tf_device::ClusterOp UpdateClusterResults( // Extracts and move outside compiled ops that do not create dependencies in the // cluster to after the cluster. 
mlir::LogicalResult LiftTailOutsideCompiledOps( - OpBuilder* builder, const mlir::TF::RuntimeDevices& devices, - std::string host_device, tf_device::ClusterOp* cluster, - bool* cluster_updated) { + OpBuilder* builder, const TF::SideEffectAnalysis& side_effect_analysis, + const mlir::TF::RuntimeDevices& devices, std::string host_device, + tf_device::ClusterOp* cluster, bool* cluster_updated) { llvm::SmallVector tail_outside_compiled_ops; llvm::SmallVector cluster_results; - FindOutsideCompiledOpsAtTailAndClusterResults( - *cluster, &tail_outside_compiled_ops, &cluster_results); + FindOutsideCompiledOpsAtTailAndClusterResults(side_effect_analysis, *cluster, + &tail_outside_compiled_ops, + &cluster_results); if (tail_outside_compiled_ops.empty()) return success(); if (host_device.empty()) @@ -365,6 +418,7 @@ struct TPUExtractHeadTailOutsideCompilation }; void TPUExtractHeadTailOutsideCompilation::runOnOperation() { + auto& side_effect_analysis = getAnalysis(); // Get runtime devices information from the closest parent module. auto module = getOperation(); mlir::TF::RuntimeDevices devices; @@ -379,10 +433,12 @@ void TPUExtractHeadTailOutsideCompilation::runOnOperation() { for (tf_device::ClusterOp cluster : clusters) { std::string host_device; bool cluster_updated = false; - if (failed(LiftHeadOutsideCompiledOps(&builder, devices, cluster, - &host_device, &cluster_updated)) || - failed(LiftTailOutsideCompiledOps(&builder, devices, host_device, - &cluster, &cluster_updated))) + if (failed(LiftHeadOutsideCompiledOps(&builder, side_effect_analysis, + devices, cluster, &host_device, + &cluster_updated)) || + failed(LiftTailOutsideCompiledOps(&builder, side_effect_analysis, + devices, host_device, &cluster, + &cluster_updated))) return signalPassFailure(); if (cluster_updated) RemoveClusterAliasedOutputs(&builder, cluster); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc index b141a7dc792..65490716cf0 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc @@ -88,22 +88,30 @@ struct TPUExtractOutsideCompilation }; // Holds information about control flow operations that wrap outside compiled -// op. Currently only tf.If op is supported. +// op. Currently only tf.IfRegion and tf.WhileRegion ops are supported. class ControlFlowStackInfo { public: - enum ControlFlowBranchType { kIfThen, kIfElse }; + enum ControlFlowBranchType { kIfThen, kIfElse, kWhileCond, kWhileBody }; explicit ControlFlowStackInfo(Operation* wrapping_op, Operation* nested_op) : callsite_op_(wrapping_op) { - // Only tf.IfRegion op is supported for now. 
- auto control_flow_op = llvm::cast(callsite_op_); - assert(control_flow_op); - - auto parent_region = nested_op->getParentRegion(); - if (&control_flow_op.then_branch() == parent_region) { - type_ = ControlFlowBranchType::kIfThen; + if (auto control_flow_op = llvm::dyn_cast(callsite_op_)) { + auto parent_region = nested_op->getParentRegion(); + if (&control_flow_op.then_branch() == parent_region) { + type_ = ControlFlowBranchType::kIfThen; + } else { + type_ = ControlFlowBranchType::kIfElse; + } + } else if (auto control_flow_op = + llvm::dyn_cast(callsite_op_)) { + auto parent_region = nested_op->getParentRegion(); + if (&control_flow_op.cond() == parent_region) { + type_ = ControlFlowBranchType::kWhileCond; + } else { + type_ = ControlFlowBranchType::kWhileBody; + } } else { - type_ = ControlFlowBranchType::kIfElse; + assert(false); } } @@ -116,6 +124,10 @@ class ControlFlowStackInfo { Operation* GetCallSiteOp() const { return callsite_op_; } + bool operator==(const ControlFlowStackInfo& other) const { + return type_ == other.type_ && callsite_op_ == other.callsite_op_; + } + private: ControlFlowBranchType type_; @@ -133,7 +145,7 @@ llvm::SmallVector GetControlFlowStackForOp( Operation* op_in_stack = op; while (op_in_stack != tpu_cluster.getOperation()) { auto parent_op = op_in_stack->getParentOp(); - if (llvm::isa(parent_op)) { + if (llvm::isa(parent_op)) { controlflow_stack.insert(controlflow_stack.begin(), ControlFlowStackInfo(parent_op, op_in_stack)); } @@ -166,7 +178,7 @@ TF::IfRegionOp CloneEmptyIfWithPredicate(Value predicate, bool is_stateless, // Replicates tf.IfRegion op to host side computation. Operation* ReplicateIf(const ControlFlowStackInfo& controlflow_info, - llvm::StringRef outside_cluster_name, ModuleOp module, + llvm::StringRef outside_cluster_name, Value compilation_key, OpBuilder* builder, int* send_recv_counter) { // Create XlaSendToHostOp to send predicate value from device to host. @@ -200,6 +212,64 @@ Operation* ReplicateIf(const ControlFlowStackInfo& controlflow_info, if_callsite_op.getLoc(), builder); } +// Creates a WhileRegionOp cond and body regions with yield op and +// an empty body. +TF::WhileRegionOp CloneEmptyWhile(bool is_stateless, + uint64_t parallel_iterations, Location loc, + OpBuilder* builder) { + auto host_side_while = builder->create( + loc, /*output=*/ArrayRef{}, /*input=*/ArrayRef{}, + is_stateless, parallel_iterations); + + // Create empty else branch region. + auto& body = host_side_while.body(); + body.push_back(new Block); + builder->setInsertionPointToEnd(&body.front()); + builder->create(loc, /*operands=*/ArrayRef{}); + return host_side_while; +} + +// Replicates tf.WhileRegion op to host side computation. +Operation* ReplicateWhile(const ControlFlowStackInfo& controlflow_info, + llvm::StringRef outside_cluster_name, + Value compilation_key, OpBuilder* builder, + int* send_recv_counter) { + // Create XlaSendToHostOp to send cond region output from device to host. 
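The new ControlFlowStackInfo constructor classifies the nested op by which region of the wrapping tf.IfRegion or tf.WhileRegion contains it. A small self-contained sketch of that classification, using toy structs (IfRegionOp/WhileRegionOp here are illustrative stand-ins, not the MLIR ops):

```cpp
#include <iostream>

enum class ControlFlowBranchType { kIfThen, kIfElse, kWhileCond, kWhileBody };

// Toy stand-ins: each region carries an id so we can tell which region
// contains the nested op.
struct Region { int id; };
struct IfRegionOp { Region then_branch, else_branch; };
struct WhileRegionOp { Region cond, body; };

ControlFlowBranchType Classify(const IfRegionOp& op, const Region& parent) {
  return parent.id == op.then_branch.id ? ControlFlowBranchType::kIfThen
                                        : ControlFlowBranchType::kIfElse;
}

ControlFlowBranchType Classify(const WhileRegionOp& op, const Region& parent) {
  return parent.id == op.cond.id ? ControlFlowBranchType::kWhileCond
                                 : ControlFlowBranchType::kWhileBody;
}

int main() {
  WhileRegionOp w{/*cond=*/{0}, /*body=*/{1}};
  IfRegionOp i{/*then_branch=*/{2}, /*else_branch=*/{3}};
  std::cout << (Classify(w, w.body) == ControlFlowBranchType::kWhileBody)   // 1
            << (Classify(i, i.else_branch) == ControlFlowBranchType::kIfElse)  // 1
            << "\n";
  return 0;
}
```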
+ OpBuilder::InsertPoint insert_point = builder->saveInsertionPoint(); + auto while_callsite_op = + llvm::cast(controlflow_info.GetCallSiteOp()); + builder->setInsertionPoint(while_callsite_op.cond().front().getTerminator()); + + const auto condition_send_recv_key = + llvm::formatv("while_condition_channel_{0}_{1}", outside_cluster_name, + *send_recv_counter) + .str(); + *send_recv_counter += 1; + auto condition = + while_callsite_op.cond().front().getTerminator()->getOperand(0); + builder->create(while_callsite_op.getLoc(), condition, + condition_send_recv_key); + builder->restoreInsertionPoint(insert_point); + + auto host_side_while = CloneEmptyWhile( + while_callsite_op.is_stateless(), while_callsite_op.parallel_iterations(), + while_callsite_op.getLoc(), builder); + + // Create cond region and yield the condition from the device. + auto& cond = host_side_while.cond(); + cond.push_back(new Block); + builder->setInsertionPointToEnd(&cond.front()); + auto recv_condition_at_host = builder->create( + while_callsite_op.getLoc(), llvm::ArrayRef{condition.getType()}, + /*dynamic_key=*/compilation_key, + builder->getStringAttr(condition_send_recv_key), + /*device_ordinal=*/builder->getI64IntegerAttr(0)); + builder->create(while_callsite_op.getLoc(), + recv_condition_at_host.getResults()); + + return host_side_while; +} + // TODO(b/157054714): Use a better abstraction instead of // _TPUCompileMlirOp and _XlaRecvAtHostOp and _XlaSendFromHostOp. // Creates a compilation key as placeholder. A placeholder compilation cache key @@ -214,45 +284,97 @@ Value CreateCompilationKeyPlaceholder(Location loc, OpBuilder* builder) { loc, /*program=*/result_type, llvm::ArrayRef{}); } +// Retrieves terminator of branch specified by `control_flow_info` of replicated +// control flow op. +Operation* GetControlFlowBranchRegionTerminator( + const ControlFlowStackInfo& controlflow_info, Operation* controlflow_op) { + if (auto inner_most_if = llvm::dyn_cast(controlflow_op)) { + if (controlflow_info.GetBranchType() == ControlFlowStackInfo::kIfThen) { + return inner_most_if.then_branch().front().getTerminator(); + } else { + return inner_most_if.else_branch().front().getTerminator(); + } + } else if (auto inner_most_while = + llvm::dyn_cast(controlflow_op)) { + if (controlflow_info.GetBranchType() == ControlFlowStackInfo::kWhileCond) { + return &inner_most_while.cond().front().front(); + } else { + return inner_most_while.body().front().getTerminator(); + } + } + assert(false); + return nullptr; +} + // Replicates the control flow operations that wraps outside compiled ops to // `destination_block`. 
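ReplicateWhile mirrors only the loop structure on the host: loop-carried state stays on the device, and the single value that crosses the boundary each iteration is the loop condition, sent over a dedicated channel and received by the host loop's cond region. The toy simulation below sketches that protocol with a std::queue in place of the XlaSendToHost/_XlaRecvAtHost pair; it illustrates why both loops run the same number of iterations and is not the generated IR.

```cpp
#include <iostream>
#include <queue>

// Toy "channel": the device pushes the loop condition, the host pops it.
std::queue<bool> while_condition_channel;

// Device-side loop: computes the next condition and sends it to the host
// before testing it (analogue of the XlaSendToHost in the cond region).
void DeviceLoop() {
  int i = 0;
  while (true) {
    bool keep_going = i < 3;                   // cond region result
    while_condition_channel.push(keep_going);  // send to host
    if (!keep_going) break;
    ++i;                                       // device body work
  }
}

// Host-side replicated loop: its cond region only receives the device's
// condition, so both loops iterate in lockstep.
void HostLoop() {
  int host_iterations = 0;
  while (true) {
    bool keep_going = while_condition_channel.front();  // recv at host
    while_condition_channel.pop();
    if (!keep_going) break;
    ++host_iterations;  // host body work (outside compiled ops go here)
  }
  std::cout << "host ran " << host_iterations << " iterations\n";  // 3
}

int main() {
  DeviceLoop();
  HostLoop();
  return 0;
}
```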
-Block* ReplicateControlFlowStack( +Operation* GetOrReplicateControlFlowStack( llvm::StringRef outside_cluster_name, const llvm::SmallVectorImpl& stack_info, tf_device::ClusterOp tpu_cluster, ModuleOp module, Value compilation_key, - Block* destination_block, int* send_recv_counter) { - assert(stack_info.size()); + Block* destination_block, int* send_recv_counter, + llvm::SmallDenseMap* replicated_controlflow_map) { + assert(!stack_info.empty()); + const auto& controlflow_info = stack_info.back(); + auto it = replicated_controlflow_map->find(controlflow_info.GetCallSiteOp()); + if (it != replicated_controlflow_map->end()) + return GetControlFlowBranchRegionTerminator(controlflow_info, it->second); + OpBuilder builder = OpBuilder::atBlockTerminator(destination_block); Operation* previous_replicated_controlflow_op = nullptr; for (const auto& controlflow_stack_info : stack_info) { + // If controlflow operation has already been created, reuse the cached + // controlflow operation. + auto it = replicated_controlflow_map->find( + controlflow_stack_info.GetCallSiteOp()); + if (it != replicated_controlflow_map->end()) { + previous_replicated_controlflow_op = it->second; + builder.setInsertionPoint(GetControlFlowBranchRegionTerminator( + controlflow_stack_info, previous_replicated_controlflow_op)); + continue; + } + // Create control flow op given provided insertion point and // ControlFlowStackInfo. - previous_replicated_controlflow_op = - ReplicateIf(controlflow_stack_info, outside_cluster_name, module, - compilation_key, &builder, send_recv_counter); - auto if_op = llvm::cast(previous_replicated_controlflow_op); - auto type = controlflow_stack_info.GetBranchType(); + if (auto control_flow_op = llvm::dyn_cast( + controlflow_stack_info.GetCallSiteOp())) { + previous_replicated_controlflow_op = + ReplicateIf(controlflow_stack_info, outside_cluster_name, + compilation_key, &builder, send_recv_counter); + auto if_op = + llvm::cast(previous_replicated_controlflow_op); + auto type = controlflow_stack_info.GetBranchType(); - // Update the insertion point to proper region inside the newly created - // control flow op. - if (type == ControlFlowStackInfo::kIfThen) { - builder.setInsertionPoint(&if_op.then_branch().front().front()); - } else { - builder.setInsertionPoint(&if_op.else_branch().front().front()); + // Update the insertion point to proper region inside the newly created + // control flow op. + if (type == ControlFlowStackInfo::kIfThen) { + builder.setInsertionPoint(&if_op.then_branch().front().front()); + } else { + builder.setInsertionPoint(&if_op.else_branch().front().front()); + } + } else if (auto control_flow_op = llvm::dyn_cast( + controlflow_stack_info.GetCallSiteOp())) { + previous_replicated_controlflow_op = + ReplicateWhile(controlflow_stack_info, outside_cluster_name, + compilation_key, &builder, send_recv_counter); + auto while_op = + llvm::cast(previous_replicated_controlflow_op); + auto type = controlflow_stack_info.GetBranchType(); + if (type == ControlFlowStackInfo::kWhileCond) { + builder.setInsertionPoint(&while_op.cond().front().front()); + } else { + builder.setInsertionPoint(&while_op.body().front().front()); + } } } - // Return the inner most branch at which outside compiled op is located. - // This block will later be used as insertion point to create send/recv ops. 
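GetOrReplicateControlFlowStack memoizes host-side clones keyed on the device-side callsite op, so two outside-compiled sections wrapped by the same control flow share a single replicated op on the host instead of producing duplicates. A simplified memoization sketch (toy structs in place of Operation*):

```cpp
#include <iostream>
#include <string>
#include <unordered_map>

// Stand-ins: a device-side control flow op and its host-side clone.
struct DeviceControlFlow { std::string name; };
struct HostControlFlow { std::string name; };

// Clones the device control flow on the host the first time it is seen and
// returns the cached clone on subsequent requests.
HostControlFlow* GetOrReplicate(
    const DeviceControlFlow* device_op,
    std::unordered_map<const DeviceControlFlow*, HostControlFlow*>* cache) {
  auto it = cache->find(device_op);
  if (it != cache->end()) return it->second;  // Reuse the existing host clone.
  auto* clone = new HostControlFlow{"host_" + device_op->name};
  cache->emplace(device_op, clone);
  return clone;
}

int main() {
  std::unordered_map<const DeviceControlFlow*, HostControlFlow*> cache;
  DeviceControlFlow if_op{"if0"};
  HostControlFlow* a = GetOrReplicate(&if_op, &cache);
  HostControlFlow* b = GetOrReplicate(&if_op, &cache);
  std::cout << (a == b ? "reused" : "duplicated") << "\n";  // "reused"
  for (auto& kv : cache) delete kv.second;  // cleanup for the sketch only
  return 0;
}
```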
- auto inner_most_controlflow_stack = stack_info.back(); - auto inner_most_if = - llvm::cast(previous_replicated_controlflow_op); - if (inner_most_controlflow_stack.GetBranchType() == - ControlFlowStackInfo::kIfThen) { - return &inner_most_if.then_branch().front(); - } else { - return &inner_most_if.else_branch().front(); - } + replicated_controlflow_map->try_emplace(stack_info.back().GetCallSiteOp(), + previous_replicated_controlflow_op); + + // Return operation which should be used to as the insertion point to create + // send/recv ops. + return GetControlFlowBranchRegionTerminator( + stack_info.back(), previous_replicated_controlflow_op); } // Collects and clusters ops in `block` with the same `_xla_outside_compilation` @@ -279,18 +401,17 @@ LogicalResult CollectAndGroupOutsideClusterOps(Block* block, return failure(walk_result.wasInterrupted()); } -// Moves `cluster_ops` to associated `block`. -void MoveOutsideClusterOpsToBlock(Block& block, - llvm::ArrayRef cluster_ops, - MLIRContext* context) { - Operation* terminator = block.getTerminator(); +// Moves `cluster_ops` before `op`. +void MoveOutsideClusterOpsBeforeOp(Operation* op, + llvm::ArrayRef cluster_ops, + MLIRContext* context) { for (Operation* cluster_op : cluster_ops) { // Remove `_xla_outside_compilation` and `device` attribute from ops in the // cluster as that information will be present in the `launch_op`. cluster_op->removeAttr( Identifier::get(kXlaOutsideCompilationAttr, context)); cluster_op->removeAttr(Identifier::get(kDeviceAttr, context)); - cluster_op->moveBefore(terminator); + cluster_op->moveBefore(op); } } @@ -330,19 +451,46 @@ llvm::SmallSetVector GetExternalOperands( // in `host_cluster_ops`. for (Value v : op->getOperands()) { Operation* defining_op = v.getDefiningOp(); - if (!defining_op) continue; - bool is_external = llvm::none_of( - host_cluster_ops, - [&](Operation* cluster_op) { return defining_op == cluster_op; }); + bool is_external = false; + if (defining_op) { + if (!tpu_cluster.getOperation()->isAncestor(defining_op)) continue; + is_external = + llvm::none_of(host_cluster_ops, [&](Operation* cluster_op) { + return defining_op == cluster_op; + }); + } else { + if (auto block_arg = v.dyn_cast()) { + if (block_arg.getParentRegion() == cluster_op_parent_region) + is_external = true; + } + } if (is_external) external_values.insert(v); } } else { llvm::SetVector external_captured_inputs; visitUsedValuesDefinedAbove(*region, *region, [&](OpOperand* operand) { - Region* parent_region = operand->get().getParentRegion(); - if (!tpu_cluster.body().isAncestor(parent_region)) return; + const bool captured_value_from_host = + llvm::find(host_cluster_ops, operand->get().getDefiningOp()) != + host_cluster_ops.end(); + if (captured_value_from_host) return; + Region* operand_defined_region = operand->get().getParentRegion(); + if (!tpu_cluster.body().isAncestor(operand_defined_region)) return; + // If the host_cluster_op is regional control flow (if, while), + // then check if the operand_defined_region is an ancestor of the + // control flow regions. 
+ if (auto if_op = llvm::dyn_cast(host_cluster_op)) { + if (if_op.then_branch().isAncestor(operand_defined_region) || + if_op.else_branch().isAncestor(operand_defined_region)) + return; + } + if (auto while_op = + llvm::dyn_cast(host_cluster_op)) { + if (while_op.cond().isAncestor(operand_defined_region) || + while_op.body().isAncestor(operand_defined_region)) + return; + } external_captured_inputs.insert(operand->get()); }); external_values.insert(external_captured_inputs.begin(), @@ -355,15 +503,21 @@ llvm::SmallSetVector GetExternalOperands( } // Extracts all externally used outputs of `cluster_ops`. -llvm::SmallVector GetExternalOutputs( +llvm::SmallSetVector GetExternalOutputs( llvm::ArrayRef cluster_ops) { llvm::SmallSetVector external_outputs; + llvm::SmallPtrSet host_cluster_ops_set; + for (auto op : cluster_ops) { + op->walk([&](Operation* host_cluster_op) { + host_cluster_ops_set.insert(host_cluster_op); + }); + } for (Operation* op : cluster_ops) { for (Operation* user : op->getUsers()) { - bool is_external = llvm::none_of(cluster_ops, [&](Operation* cluster_op) { - return user == cluster_op; - }); + bool is_external = llvm::none_of( + host_cluster_ops_set, + [&](Operation* cluster_op) { return user == cluster_op; }); if (!is_external) continue; for (Value v : user->getOperands()) { if (v.getDefiningOp() == op) external_outputs.insert(v); @@ -371,7 +525,7 @@ llvm::SmallVector GetExternalOutputs( } } - return external_outputs.takeVector(); + return external_outputs; } // Sets the insertion point on `builder` for HostCompute op. Sets insertion @@ -390,6 +544,12 @@ void SetHostComputeInsertion( } } } + + // If no operand usage can be found, this means that external input is + // implicitly captured inputs for ops inside internal regions of one of the + // `cluster_ops`. In that case, set the insertion point to the last op of the + // `cluster_ops` in the IR. + builder->setInsertionPoint(cluster_ops.back()); } // Creates the HostCompute with `inputs` and `outputs` @@ -412,53 +572,62 @@ TF::_XlaHostComputeMlirOp CreateHostCompute( return host_compute; } -void MoveOutsideCompiledOps( +// Represents a set of ops inside host computation that is wrapped inside the +// same control flow. +struct HostCluster { + // List of control flow that wraps host computation operations. May be empty. + llvm::SmallVector controlflow_stack; + + // Set of operations that will run on host wrapped around same stack of + // control flow. + llvm::SmallVector section_ops; +}; + +HostCluster* FindHostCluster( + llvm::SmallVectorImpl& host_cluster_sections, + const llvm::SmallVector& control_flows) { + for (auto& section : host_cluster_sections) + if (control_flows == section.controlflow_stack) return §ion; + return nullptr; +} + +void MoveOutsideCompiledOpsInsideControlFlow( ModuleOp module, tf_device::ClusterOp tpu_cluster, - llvm::StringRef outside_cluster_name, tf_device::LaunchOp host_launch_op, - llvm::ArrayRef cluster_ops, - const llvm::SmallSetVector& external_inputs, - llvm::ArrayRef external_outputs) { - // Since ops in `cluster_ops` do not cross function/control flow boundary, it - // is sufficient to identify the control flow that wraps `cluster_ops` by - // looking at any arbitary op inside `cluster_ops`. 
- auto controlflow_stack = - GetControlFlowStackForOp(tpu_cluster, cluster_ops.front()); - - Value compilation_key; - if (!controlflow_stack.empty() || !external_inputs.empty() || - !external_outputs.empty()) { - OpBuilder builder(&host_launch_op.GetBody().front()); - compilation_key = - CreateCompilationKeyPlaceholder(tpu_cluster.getLoc(), &builder); - } - - Block* block_to_move_host_cluster = nullptr; + llvm::StringRef host_cluster_section_name, + tf_device::LaunchOp host_launch_op, Value compilation_key, + llvm::ArrayRef cluster_section_ops, + const llvm::SmallVectorImpl& controlflow_stack, + const llvm::SmallSetVector& section_external_inputs, + llvm::ArrayRef section_external_outputs, + llvm::SmallDenseMap* replicated_controlflow_map) { + Operation* insertion_op = nullptr; if (controlflow_stack.empty()) { - block_to_move_host_cluster = &host_launch_op.GetBody(); + insertion_op = host_launch_op.GetBody().getTerminator(); } else { int send_recv_counter = 0; - block_to_move_host_cluster = ReplicateControlFlowStack( - outside_cluster_name, controlflow_stack, tpu_cluster, module, - compilation_key, &host_launch_op.GetBody(), &send_recv_counter); + insertion_op = GetOrReplicateControlFlowStack( + host_cluster_section_name, controlflow_stack, tpu_cluster, module, + compilation_key, &host_launch_op.GetBody(), &send_recv_counter, + replicated_controlflow_map); } MLIRContext* context = host_launch_op.getContext(); - if (external_inputs.empty() && external_outputs.empty()) { - MoveOutsideClusterOpsToBlock(*block_to_move_host_cluster, cluster_ops, - context); + if (section_external_inputs.empty() && section_external_outputs.empty()) { + MoveOutsideClusterOpsBeforeOp(insertion_op, cluster_section_ops, context); return; } - OpBuilder builder(block_to_move_host_cluster->getTerminator()); + OpBuilder builder(insertion_op); llvm::SmallVector host_output_types; - for (const auto& external_input : external_inputs) + for (const auto& external_input : section_external_inputs) host_output_types.push_back(external_input.getType()); std::string args_communication_key = - llvm::formatv("host_compute_channel_{0}_args", outside_cluster_name) + llvm::formatv("host_compute_channel_{0}_args", host_cluster_section_name) .str(); std::string retvals_communication_key = - llvm::formatv("host_compute_channel_{0}_retvals", outside_cluster_name) + llvm::formatv("host_compute_channel_{0}_retvals", + host_cluster_section_name) .str(); auto recv_at_host = builder.create( @@ -467,26 +636,105 @@ void MoveOutsideCompiledOps( builder.getStringAttr(args_communication_key), /*device_ordinal=*/builder.getI64IntegerAttr(0)); - auto host_compute = CreateHostCompute( - &builder, tpu_cluster, cluster_ops, external_inputs, external_outputs, - args_communication_key, retvals_communication_key); - MoveOutsideClusterOpsToBlock(*block_to_move_host_cluster, cluster_ops, - context); + auto host_compute = + CreateHostCompute(&builder, tpu_cluster, cluster_section_ops, + section_external_inputs, section_external_outputs, + args_communication_key, retvals_communication_key); + MoveOutsideClusterOpsBeforeOp(insertion_op, cluster_section_ops, context); - builder.setInsertionPoint(block_to_move_host_cluster->getTerminator()); + builder.setInsertionPoint(insertion_op); builder.create( - tpu_cluster.getLoc(), external_outputs, + tpu_cluster.getLoc(), section_external_outputs, /*dynamic_key=*/compilation_key, builder.getStringAttr(retvals_communication_key), /*device_ordinal=*/builder.getI64IntegerAttr(0)); - for (auto result : 
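Each host cluster section gets its own pair of communication keys, host_compute_channel_&lt;section&gt;_args and host_compute_channel_&lt;section&gt;_retvals, and replicated control flow uses counter-suffixed keys such as while_condition_channel_&lt;section&gt;_&lt;n&gt;. A small sketch of that naming scheme (plain string concatenation here; the pass itself uses llvm::formatv):

```cpp
#include <iostream>
#include <string>

std::string ArgsKey(const std::string& section) {
  return "host_compute_channel_" + section + "_args";
}

std::string RetvalsKey(const std::string& section) {
  return "host_compute_channel_" + section + "_retvals";
}

// Counter-based key for condition/predicate sends; the counter keeps keys
// unique when one section replicates several control flow ops.
std::string ConditionKey(const std::string& section, int* counter) {
  return "while_condition_channel_" + section + "_" +
         std::to_string((*counter)++);
}

int main() {
  int send_recv_counter = 0;
  std::cout << ArgsKey("cluster0_0") << "\n"
            << RetvalsKey("cluster0_0") << "\n"
            << ConditionKey("cluster0_0", &send_recv_counter) << "\n";
  return 0;
}
```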
llvm::zip(external_inputs, recv_at_host.getResults())) + for (auto result : + llvm::zip(section_external_inputs, recv_at_host.getResults())) { mlir::replaceAllUsesInRegionWith(std::get<0>(result), std::get<1>(result), - host_launch_op.body()); + *insertion_op->getParentRegion()); + } - for (auto result : llvm::zip(external_outputs, host_compute.getResults())) - mlir::replaceAllUsesInRegionWith(std::get<0>(result), std::get<1>(result), - tpu_cluster.body()); + for (auto result : + llvm::zip(section_external_outputs, host_compute.getResults())) { + for (auto& result_use : std::get<0>(result).getUses()) { + Operation* result_using_op = result_use.getOwner(); + const bool inside_device_cluster = + tpu_cluster.body().isAncestor(result_using_op->getParentRegion()); + if (inside_device_cluster) result_use.set(std::get<1>(result)); + } + } +} + +void MoveOutsideCompiledOps( + ModuleOp module, tf_device::ClusterOp tpu_cluster, + llvm::StringRef outside_cluster_name, tf_device::LaunchOp host_launch_op, + llvm::ArrayRef cluster_ops, + const llvm::SmallSetVector& external_inputs, + const llvm::SmallSetVector& external_outputs) { + // Identify and groups ops in `cluster_ops` by ops wrapped inside the same + // control flows. + llvm::SmallVector host_cluster_sections; + for (Operation* host_cluster_op : cluster_ops) { + auto controlflow_stack = + GetControlFlowStackForOp(tpu_cluster, host_cluster_op); + auto host_cluster_section = + FindHostCluster(host_cluster_sections, controlflow_stack); + if (!host_cluster_section) { + host_cluster_sections.emplace_back( + HostCluster{controlflow_stack, {host_cluster_op}}); + } else { + host_cluster_section->section_ops.emplace_back(host_cluster_op); + } + } + + const bool has_control_flow = + llvm::any_of(host_cluster_sections, [](const auto host_cluster_section) { + return !host_cluster_section.controlflow_stack.empty(); + }); + + Value compilation_key; + if (has_control_flow || !external_inputs.empty() || + !external_outputs.empty()) { + OpBuilder builder(&host_launch_op.GetBody().front()); + compilation_key = + CreateCompilationKeyPlaceholder(tpu_cluster.getLoc(), &builder); + } + + // Maintains a map of control flow callsite operation in TPU device side + // and an replicated control flow operation on host cluster. + llvm::SmallDenseMap replicated_controlflows; + + // Move `cluster_op` to host cluster, replicating control flow if ops are + // wrapped inside a control flow. 
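MoveOutsideCompiledOps now splits one outside-compilation cluster into HostCluster sections keyed by the control-flow stack of each op, with FindHostCluster doing a linear find-or-create over the existing sections. A standalone sketch of that grouping, with strings standing in for ControlFlowStackInfo entries:

```cpp
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// A section of host ops that share the same wrapping control flow.
struct HostCluster {
  std::vector<std::string> controlflow_stack;  // e.g. {"while0", "if1"}
  std::vector<std::string> section_ops;
};

HostCluster* FindHostCluster(std::vector<HostCluster>& sections,
                             const std::vector<std::string>& stack) {
  for (auto& section : sections)
    if (section.controlflow_stack == stack) return &section;
  return nullptr;
}

int main() {
  // (op, its control-flow stack inside the TPU cluster)
  std::vector<std::pair<std::string, std::vector<std::string>>> cluster_ops = {
      {"opA", {}}, {"opB", {"while0"}}, {"opC", {"while0"}}, {"opD", {"if1"}}};

  std::vector<HostCluster> sections;
  for (const auto& [op, stack] : cluster_ops) {
    if (HostCluster* section = FindHostCluster(sections, stack))
      section->section_ops.push_back(op);
    else
      sections.push_back(HostCluster{stack, {op}});
  }
  std::cout << sections.size() << " sections\n";  // 3
  return 0;
}
```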
+ for (const auto& host_cluster_section_and_index : + llvm::enumerate(host_cluster_sections)) { + const auto& host_cluster_section = host_cluster_section_and_index.value(); + const int index = host_cluster_section_and_index.index(); + + const auto& controlflow_stack = host_cluster_section.controlflow_stack; + const auto& cluster_section_ops = host_cluster_section.section_ops; + auto section_external_inputs = + GetExternalOperands(tpu_cluster, cluster_section_ops); + for (auto input : section_external_inputs) { + if (!external_inputs.contains(input)) + section_external_inputs.remove(input); + } + auto section_external_outputs = GetExternalOutputs(cluster_section_ops); + for (auto output : section_external_outputs) { + if (!external_outputs.contains(output)) + section_external_outputs.remove(output); + } + + const std::string host_cluster_section_name = + llvm::formatv("{0}_{1}", outside_cluster_name, index).str(); + + MoveOutsideCompiledOpsInsideControlFlow( + module, tpu_cluster, host_cluster_section_name, host_launch_op, + compilation_key, cluster_section_ops, controlflow_stack, + section_external_inputs, section_external_outputs.takeVector(), + &replicated_controlflows); + } } // Creates a `parallel_execute` op in place of launch with 'clusters` and diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_outside_compilation_cluster.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_outside_compilation_cluster.cc index be01b7644ea..63bb53f52b5 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_outside_compilation_cluster.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_outside_compilation_cluster.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" @@ -22,6 +23,7 @@ limitations under the License. #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" @@ -33,10 +35,26 @@ namespace { constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation"; struct TPUOutsideCompilationCluster - : public PassWrapper { - void runOnFunction() override; + : public TF::PerFunctionAggregateAnalysisConsumerPass< + TPUOutsideCompilationCluster, TF::SideEffectAnalysis> { + void runOnFunction(FuncOp func, + const TF::SideEffectAnalysis::Info& side_effect_analysis); }; +bool IsVariant(Value value) { + return getElementTypeOrSelf(value.getType()).isa(); +} + +bool HasOutsideCompiledAncestor(Operation* op) { + Operation* parent = op->getParentOp(); + while (parent) { + if (parent->getAttrOfType(kXlaOutsideCompilationAttr)) + return true; + parent = parent->getParentOp(); + } + return false; +} + // Represents an outside compiled cluster. All ops that are added to the same // cluster will be extracted together in a later pass. class OutsideCompiledCluster { @@ -44,81 +62,141 @@ class OutsideCompiledCluster { explicit OutsideCompiledCluster(int number) : cluster_name_(llvm::formatv("cluster{0}", number).str()) {} - // Attempts to add an op to this cluster. 
- // This function requires all ops to be added before their uses. - bool AddOp(Operation* op) { + // Attempts to add an op to this cluster. Ops can be grouped to the same + // cluster if they have data dependency and are inside the same block. + bool AddOp(Operation* op, + const TF::SideEffectAnalysis::Info& side_effect_analysis) { // Check if the op is safe to add before adding it. - bool add = IsSafeToAdd(op); - if (add) { - // Set the ops kXlaOutsideCompilationAttr to the cluster name. + if (IsSafeToAdd(op, side_effect_analysis)) { op->setAttr(kXlaOutsideCompilationAttr, StringAttr::get(cluster_name_, op->getContext())); - - // Since we are adding the op to the cluster, the op is no longer - // considered a user of this cluster. - users_.erase(op); + host_cluster_ops_.insert(op); + return true; } + return false; + } - // Add this op's users to the cluster users. - users_.insert(op->user_begin(), op->user_end()); - return add; + // If any tf.variants are inputs/outputs to the cluster, add them to the + // cluster unless they are already marks with outside compilation attribute. + bool AddVariantInputsOutputs() { + bool added_op = false; + llvm::SmallPtrSet expanded_cluster_ops(host_cluster_ops_); + for (Operation* cluster_op : host_cluster_ops_) { + // Walk the clustered operations to handle nested ops. + cluster_op->walk([&](Operation* op) { + // Add any operations that provide variant inputs to the cluster. + for (auto value : op->getOperands()) { + auto input_defining_op = value.getDefiningOp(); + if (IsVariant(value) && input_defining_op && + !HasOutsideCompiledAncestor(input_defining_op) && + !input_defining_op->getAttrOfType( + kXlaOutsideCompilationAttr)) { + expanded_cluster_ops.insert(input_defining_op); + input_defining_op->setAttr( + kXlaOutsideCompilationAttr, + StringAttr::get(cluster_name_, + input_defining_op->getContext())); + added_op = true; + } + } + // Add any operations that consume variant outputs to the cluster. + for (auto value : op->getResults()) { + if (IsVariant(value)) { + for (auto user : value.getUsers()) { + if (!host_cluster_ops_.contains(user) && + !HasOutsideCompiledAncestor(user) && + !user->getAttrOfType( + kXlaOutsideCompilationAttr)) { + expanded_cluster_ops.insert(user); + user->setAttr( + kXlaOutsideCompilationAttr, + StringAttr::get(cluster_name_, user->getContext())); + added_op = true; + } + } + } + } + }); + } + host_cluster_ops_.swap(expanded_cluster_ops); + + return added_op; } private: // Checks if it is safe for an op to be merged into this cluster. - bool IsSafeToAdd(Operation* op) { + bool IsSafeToAdd(Operation* op, + const TF::SideEffectAnalysis::Info& side_effect_analysis) { + if (closed_) return false; // If the op is not marked for outside compilation it doesn't belong in a // cluster. - if (!op->getAttrOfType(kXlaOutsideCompilationAttr)) + if (!op->getAttrOfType(kXlaOutsideCompilationAttr)) { + auto successors = side_effect_analysis.DirectControlSuccessors(op); + // If non outside compiled op with side effect successors is encountered, + // close this cluster to additions so that no cluster cyclic dependencies + // can be created. + if (!successors.empty()) { + closed_ = true; + } return false; - - // Checks to see if the op's operands are related to this - // clusters users. If they are related, then there is an op between this - // op and the cluster. 
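AddVariantInputsOutputs grows a host cluster until it is closed under tf.variant producers and consumers, since variant values cannot be passed through the device boundary. The fixed-point sketch below models that expansion over a toy value graph; it omits the checks for existing outside-compilation attributes and outside-compiled ancestors that the real pass performs.

```cpp
#include <iostream>
#include <set>
#include <string>
#include <vector>

// Toy op: the variant-typed values it consumes and produces.
struct ToyOp {
  std::string name;
  std::vector<std::string> variant_inputs;
  std::vector<std::string> variant_outputs;
};

bool SharesVariant(const std::vector<std::string>& a,
                   const std::vector<std::string>& b) {
  for (const auto& x : a)
    for (const auto& y : b)
      if (x == y) return true;
  return false;
}

// Pulls variant producers/consumers into the cluster until nothing changes.
void AddVariantInputsOutputs(const std::vector<ToyOp>& all_ops,
                             std::set<std::string>* cluster) {
  bool changed = true;
  while (changed) {
    changed = false;
    for (const ToyOp& candidate : all_ops) {
      if (cluster->count(candidate.name)) continue;
      for (const ToyOp& member : all_ops) {
        if (!cluster->count(member.name)) continue;
        // candidate feeds a variant into the cluster, or consumes one from it.
        if (SharesVariant(candidate.variant_outputs, member.variant_inputs) ||
            SharesVariant(candidate.variant_inputs, member.variant_outputs)) {
          cluster->insert(candidate.name);
          changed = true;
          break;
        }
      }
    }
  }
}

int main() {
  // list_new -> list_push (clustered) -> list_pop
  std::vector<ToyOp> ops = {{"list_new", {}, {"v0"}},
                            {"list_push", {"v0"}, {"v1"}},
                            {"list_pop", {"v1"}, {}}};
  std::set<std::string> cluster = {"list_push"};
  AddVariantInputsOutputs(ops, &cluster);
  std::cout << cluster.size() << "\n";  // 3: both neighbors are pulled in
  return 0;
}
```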
Since ops are added before their uses, there - // is no way for the op in-between to ever be added to this cluster - // therefore there is no way this op can ever be added to the cluster. - for (const Value& value : op->getOperands()) { - Operation* op_operand = value.getDefiningOp(); - if (op_operand && users_.find(op_operand) != users_.end()) return false; } - return true; + + if (host_cluster_ops_.empty()) return true; + + // Checks to see if there is data dependency between ops in + // `host_cluster_ops_` and `op`. + const bool contains_data_dependency = llvm::any_of( + op->getUsers(), + [&](Operation* user) { return host_cluster_ops_.contains(user); }); + + return contains_data_dependency; } - // users_ stores the direct and indirect users of the outside compiled ops in - // this cluster. It does NOT store the outside compiled ops that are a part - // of this cluster that will be collectively extracted and run on the cpu. - // users_ is consulted when attempting to add a new outside compiled to the - // cluster. If the new op's operand(s) are already in users_, it means that - // the operand(s) were not added to the cluster so it is not safe to add the - // new op to the cluster either. - llvm::SmallPtrSet users_; + // `host_cluster_op_` stores a set of ops that will be grouped and computed + // on host as single XlaHostCompute op. An outside compiled op can be grouped + // to a single cluster if it has data dependency to another op already in the + // cluster. + llvm::SmallPtrSet host_cluster_ops_; std::string cluster_name_; + bool closed_ = false; // Cluster is closed to further additions. }; -void TPUOutsideCompilationCluster::runOnFunction() { +void TPUOutsideCompilationCluster::runOnFunction( + FuncOp func, const TF::SideEffectAnalysis::Info& side_effect_analysis) { llvm::SmallVector clusters; int cluster_counter = 0; - getFunction().walk([&](tf_device::ClusterOp tpu_cluster) { - for (Operation& op : tpu_cluster.GetBody()) { + func.walk([&](tf_device::ClusterOp tpu_cluster) { + llvm::SmallVector tpu_cluster_ops; + tpu_cluster_ops.reserve(tpu_cluster.getBody()->getOperations().size()); + + tpu_cluster.walk([&](Operation* op) { tpu_cluster_ops.emplace_back(op); }); + + // In order to cluster ops feeding results to the same operation, traverse + // the ops in reverse order. + for (Operation* op : llvm::reverse(tpu_cluster_ops)) { // Try to add the op to existing clusters. bool added = false; for (auto& cluster : clusters) - if ((added = cluster.AddOp(&op))) break; + if ((added = cluster.AddOp(op, side_effect_analysis))) break; // If the op cannot be added to existing clusters, create a new cluster. if (!added) { OutsideCompiledCluster new_cluster(cluster_counter++); - new_cluster.AddOp(&op); + new_cluster.AddOp(op, side_effect_analysis); clusters.push_back(new_cluster); } } }); + for (auto& cluster : clusters) { + bool variants_to_add = true; + while (variants_to_add) variants_to_add = cluster.AddVariantInputsOutputs(); + } } } // anonymous namespace -std::unique_ptr> +std::unique_ptr> CreateTPUOutsideCompilationClusterPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_parallel_execute_sink_resource_write.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_parallel_execute_sink_resource_write.cc new file mode 100644 index 00000000000..45773a128fd --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_parallel_execute_sink_resource_write.cc @@ -0,0 +1,166 @@ +/* Copyright 2020 The TensorFlow Authors. 
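The reworked clustering pass visits the TPU cluster's ops in reverse order, adds an outside-compiled op to an existing cluster only when one of its users is already in that cluster (a data dependency), and closes a cluster once it crosses a non-extractable op with side-effecting successors so that no cyclic dependency can be created. A simplified standalone sketch (it only creates clusters for outside-compiled ops, a minor deviation from the pass):

```cpp
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

struct ToyOp {
  std::string name;
  bool outside_compiled;            // has _xla_outside_compilation
  bool has_side_effect_successors;  // per side-effect analysis
  std::vector<std::string> users;   // ops using this op's results
};

struct Cluster {
  std::unordered_set<std::string> ops;
  bool closed = false;

  bool AddOp(const ToyOp& op) {
    if (closed) return false;
    if (!op.outside_compiled) {
      // A side-effecting op we cannot extract closes the cluster, so later
      // additions cannot create a cyclic dependency around it.
      if (op.has_side_effect_successors) closed = true;
      return false;
    }
    if (ops.empty()) { ops.insert(op.name); return true; }
    for (const std::string& user : op.users)
      if (ops.count(user)) { ops.insert(op.name); return true; }
    return false;  // No data dependency with this cluster.
  }
};

int main() {
  // Ops listed in IR order; clustering visits them in reverse.
  std::vector<ToyOp> tpu_cluster_ops = {
      {"a", true, false, {"b"}},
      {"barrier", false, true, {}},
      {"b", true, false, {"c"}},
      {"c", true, false, {}},
  };

  std::vector<Cluster> clusters;
  for (auto it = tpu_cluster_ops.rbegin(); it != tpu_cluster_ops.rend(); ++it) {
    bool added = false;
    for (Cluster& cluster : clusters)
      if ((added = cluster.AddOp(*it))) break;
    if (!added && it->outside_compiled) {
      Cluster new_cluster;
      new_cluster.AddOp(*it);
      clusters.push_back(new_cluster);
    }
  }
  std::cout << clusters.size() << " clusters\n";  // 2: {c, b} and {a}
  return 0;
}
```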
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace TFTPU { + +namespace { + +// A pass that moves `tf.AssignVariableOp` into a `tf_device.parallel_execute` +// region if the `tf.AssignVariableOp` is the only consumer of a +// `tf_device.parallel_execute` result. This will allow +// TPUMergeVariablesWithExecute to merge resource writes without special +// handling for `tf_device.parallel_execute`. +struct TPUParallelExecuteSinkResourceWrite + : public PassWrapper { + void runOnFunction() override; +}; + +// Finds an AssignVariableOp that can be moved into the parallel_execute region. +// These AssignVariableOps must be the only consumer of the respective +// parallel_execute result, and the resource handle producer must be from an op +// before or above the parallel_execute. +TF::AssignVariableOp GetSingleUseResourceWrite( + tf_device::ParallelExecuteOp parallel_execute, Value result) { + if (!result.hasOneUse()) return nullptr; + + OpOperand& use = *result.getUses().begin(); + auto assign_var = dyn_cast(use.getOwner()); + if (!assign_var) return nullptr; + + if (use.get() != assign_var.value()) return nullptr; + + auto* resource_handle_op = assign_var.resource().getDefiningOp(); + if (resource_handle_op == parallel_execute) return nullptr; + + if (resource_handle_op && + resource_handle_op->getBlock() == + parallel_execute.getOperation()->getBlock() && + parallel_execute.getOperation()->isBeforeInBlock(resource_handle_op)) + return nullptr; + + return assign_var; +} + +// Finds AssignVariableOps that can be moved into a parallel_execute region and +// moves them. Leftover parallel_execute results that were used by the +// such AssignVariableOp are also pruned. +void SinkResourceWritesIntoParallelExecute( + tf_device::ParallelExecuteOp parallel_execute) { + bool rewrite = false; + const int num_regions = parallel_execute.getNumRegions(); + llvm::SmallVector results_to_remap; + + // Go through each region and find AssignVariableOps that can be moved into + // the parallel_execute region. Result indices by region index are collected, + // so they can be removed afterwards. 
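GetSingleUseResourceWrite accepts a tf.AssignVariableOp only when it is the sole user of the parallel_execute result, the result feeds the op's value operand rather than its resource operand, and the resource handle is produced before the parallel_execute, so sinking the write into the region cannot place it ahead of its handle. A predicate sketch over toy block positions (the real check compares blocks and uses isBeforeInBlock, and also accepts handles with no defining op, such as function arguments):

```cpp
#include <iostream>

// Toy positions in a block; smaller means earlier.
struct ToyResult {
  int num_uses;
  bool use_is_assign_value_operand;  // result feeds AssignVariableOp's value
  int resource_def_position;         // where the resource handle is defined
};

// Decides whether the write can be sunk into the parallel_execute region.
bool CanSinkResourceWrite(const ToyResult& result,
                          int parallel_execute_position) {
  if (result.num_uses != 1) return false;
  if (!result.use_is_assign_value_operand) return false;
  // The resource handle must be defined before the parallel_execute;
  // otherwise the moved write would use a value defined after it.
  return result.resource_def_position < parallel_execute_position;
}

int main() {
  ToyResult ok{1, true, /*resource_def_position=*/2};
  ToyResult bad{1, true, /*resource_def_position=*/7};
  std::cout << CanSinkResourceWrite(ok, /*parallel_execute_position=*/5)   // 1
            << CanSinkResourceWrite(bad, /*parallel_execute_position=*/5)  // 0
            << "\n";
  return 0;
}
```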
+ llvm::SmallVector, 4> results_to_remove_by_region; + results_to_remove_by_region.resize(num_regions); + for (int i = 0; i < num_regions; ++i) { + Block& block = parallel_execute.GetRegionBlockWithIndex(i); + auto results = parallel_execute.GetRegionOutputs(i); + auto& results_to_remove = results_to_remove_by_region[i]; + results_to_remove.reserve(results.size()); + Operation* terminator = block.getTerminator(); + for (auto result : llvm::enumerate(results)) { + TF::AssignVariableOp assign_var = + GetSingleUseResourceWrite(parallel_execute, result.value()); + if (!assign_var) { + results_to_remap.push_back(result.value()); + continue; + } + + // Move AssignVariableOp and update the value to be written to the + // resource variable to be the non forwarded value from within the + // parallel_execute region. + assign_var.getOperation()->moveBefore(terminator); + assign_var.valueMutable().assign(terminator->getOperand(result.index())); + results_to_remove.push_back(result.index()); + } + + rewrite |= !results_to_remove.empty(); + } + + if (!rewrite) return; + + // Remove leftover unused results (terminator operands) from moving + // AssignVariabeOps into the parallel_execute region. + for (auto results_to_remove : llvm::enumerate(results_to_remove_by_region)) { + Block& block = + parallel_execute.GetRegionBlockWithIndex(results_to_remove.index()); + Operation* terminator = block.getTerminator(); + for (int index_to_remove : llvm::reverse(results_to_remove.value())) + terminator->eraseOperand(index_to_remove); + } + + // Replace old parallel_execute with new parallel_execute by moving the + // regions to a new parallel_execute and remapping the results. + llvm::SmallVector new_result_types; + new_result_types.reserve(results_to_remap.size()); + for (Value old_result : results_to_remap) + new_result_types.push_back(old_result.getType()); + + OpBuilder builder(parallel_execute); + auto new_parallel_execute = builder.create( + parallel_execute.getLoc(), num_regions, new_result_types); + + for (auto region : llvm::zip(new_parallel_execute.getRegions(), + parallel_execute.getRegions())) + std::get<0>(region)->takeBody(*std::get<1>(region)); + + for (auto result : + llvm::zip(results_to_remap, new_parallel_execute.getResults())) + std::get<0>(result).replaceAllUsesWith(std::get<1>(result)); + + parallel_execute.erase(); +} + +void TPUParallelExecuteSinkResourceWrite::runOnFunction() { + llvm::SmallVector parallel_executes; + getFunction().walk([&](tf_device::ParallelExecuteOp parallel_execute) { + parallel_executes.push_back(parallel_execute); + }); + + for (tf_device::ParallelExecuteOp parallel_execute : parallel_executes) + SinkResourceWritesIntoParallelExecute(parallel_execute); +} + +} // anonymous namespace + +std::unique_ptr> +CreateTPUParallelExecuteSinkResourceWritePass() { + return std::make_unique(); +} + +static PassRegistration pass( + "tf-tpu-parallel-execute-sink-resource-write", + "Moves tf.AssignVariableOp consumers of tf_device.parallel_execute into " + "tf_device.parallel_execute regions"); + +} // namespace TFTPU +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_read_for_write.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_read_for_write.cc new file mode 100644 index 00000000000..cccd528da1d --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_read_for_write.cc @@ -0,0 +1,140 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +namespace mlir { +namespace TFTPU { + +// A pass that finds TPU clusters with write only resource access and adds an +// associated resource read, so the resource can later be fused into TPUExecute. +namespace { +struct TPUResourceReadForWrite + : public PassWrapper> { + void runOnOperation() override; +}; + +// Helper struct holding a resource value and its associated type. +struct ResourceValueAndSubtype { + Value resource; + Type subtype; +}; + +// Finds resource handle and type for result if result writes to a resource. +ResourceValueAndSubtype GetResourceWriteResult( + tf_device::ClusterFuncOp cluster_func, Value result) { + ResourceValueAndSubtype resource; + if (!result.hasOneUse()) return resource; + Operation* result_user = *result.getUsers().begin(); + auto assign_var = dyn_cast(result_user); + if (!assign_var) return resource; + + auto handle = assign_var.resource(); + // Skip result if cluster writes to the same variable via multiple results. + for (Operation* handle_user : handle.getUsers()) { + if (handle_user == assign_var) continue; + auto assign_var_user = dyn_cast(handle_user); + if (!assign_var_user) continue; + if (assign_var_user.value().getDefiningOp() == cluster_func) + return resource; + } + + resource.resource = assign_var.resource(); + resource.subtype = assign_var.value().getType(); + return resource; +} + +// Checks if resource is read by TPU cluster. +bool ClusterFuncHasResourceRead(tf_device::ClusterFuncOp cluster_func, + Value resource) { + for (Operation* resource_user : resource.getUsers()) + if (auto read = dyn_cast(resource_user)) + for (Operation* read_user : read.value().getUsers()) + if (read_user == cluster_func) return true; + + return false; +} + +void TPUResourceReadForWrite::runOnOperation() { + SmallVector cluster_funcs; + getOperation().walk([&](tf_device::ClusterFuncOp cluster_func) { + cluster_funcs.push_back(cluster_func); + }); + + OpBuilder builder(&getContext()); + // Add resource reads for resource writes from TPU cluster where for such + // resources the TPU cluster does not read from. + for (tf_device::ClusterFuncOp cluster_func : cluster_funcs) { + builder.setInsertionPoint(cluster_func); + + SmallVector read_operands; + for (Value result : cluster_func.getResults()) { + // TODO(lyandy): Update pass to use resource alias analysis. 
+ auto resource_and_type = GetResourceWriteResult(cluster_func, result); + if (!resource_and_type.resource) continue; + if (ClusterFuncHasResourceRead(cluster_func, resource_and_type.resource)) + continue; + auto new_read = builder.create( + resource_and_type.resource.getLoc(), resource_and_type.subtype, + resource_and_type.resource); + read_operands.push_back(new_read.value()); + } + + if (read_operands.empty()) continue; + + // Update caller and function types with new read operands. + auto operands = llvm::to_vector<4>(cluster_func.getOperands()); + operands.append(read_operands.begin(), read_operands.end()); + + auto new_cluster_func = builder.create( + cluster_func.getLoc(), cluster_func.getResultTypes(), operands, + cluster_func.getAttrs()); + cluster_func.replaceAllUsesWith(new_cluster_func); + FuncOp func = cluster_func.getFunc(); + Block& block = func.front(); + for (Value read_operand : read_operands) + block.addArgument(read_operand.getType()); + + func.setType(FunctionType::get(block.getArgumentTypes(), + func.getCallableResults(), &getContext())); + cluster_func.erase(); + } +} + +} // namespace + +std::unique_ptr> CreateTPUResourceReadForWritePass() { + return std::make_unique(); +} + +static PassRegistration pass( + "tf-tpu-resource-read-for-write", + "Inserts tf.ReadVariableOp inputs to a TPU cluster for resource writes " + "with no reads"); + +} // namespace TFTPU +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc index 21ad457a7a6..86aeec81150 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc @@ -25,7 +25,6 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/raw_ostream.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project @@ -42,6 +41,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h" #include "tensorflow/compiler/xla/xla.pb.h" @@ -154,11 +154,8 @@ LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func, symbol_table.insert(clone); } - // Serialize module and return. 
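The read-for-write pass adds a tf.ReadVariableOp operand for every resource that the cluster writes (through a single-use AssignVariableOp) but never reads, so TPUMergeVariablesWithExecute can later fuse the variable access. A small decision sketch over toy resource sets:

```cpp
#include <iostream>
#include <set>
#include <string>
#include <vector>

struct ToyCluster {
  std::set<std::string> resources_written;  // via single-use AssignVariableOp
  std::set<std::string> resources_read;     // via ReadVariableOp feeding cluster
};

// Returns the resources for which a read should be inserted and appended as a
// new cluster operand.
std::vector<std::string> ResourcesNeedingRead(const ToyCluster& cluster) {
  std::vector<std::string> result;
  for (const std::string& resource : cluster.resources_written)
    if (!cluster.resources_read.count(resource)) result.push_back(resource);
  return result;
}

int main() {
  ToyCluster cluster{{"var_a", "var_b"}, {"var_b"}};
  for (const auto& r : ResourcesNeedingRead(cluster))
    std::cout << "insert ReadVariableOp for " << r << "\n";  // var_a only
  return 0;
}
```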
- { - llvm::raw_string_ostream os(*serialized_func_module); - module_for_func.get().print(os); - } + *serialized_func_module = + tensorflow::SerializeMlirModule(module_for_func.get()); return success(); } @@ -647,7 +644,7 @@ LogicalResult Rewrite( int num_replicas = 1; tf_device::ReplicateOp replicate = cluster_func.getParentOfType(); - if (replicate) num_replicas = replicate.n().getLimitedValue(); + if (replicate) num_replicas = replicate.n(); auto num_cores_per_replica_attr = cluster_func.getAttrOfType( tensorflow::kNumCoresPerReplicaAttr); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc index 0b9eaba8c97..35ad3d21b30 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc @@ -28,6 +28,7 @@ limitations under the License. #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" @@ -47,133 +48,47 @@ struct TPUShardingIdentificationPass void runOnOperation() override; }; -// Sets `sharding_op` if `op` is XlaShardingOp or if XlaSharding op is adjacent -// to `op`. XlaSharding op may be direct user of inputs but it may also be -// followed by an Identity op and, in the case where bfloat16 type is used, Cast -// op may be added right after the input. As so, parse the users of the -// operation to access connected XlaSharding op. +// Finds XlaSharding op connected to an argument value. If value is a resource +// type then XlaSharding op will be connected to a ReadVariable op. XlaSharding +// op may be direct user of inputs but it may also be followed by an Identity op +// and, in the case where bfloat16 type is used, Cast op may be added right +// after the input. // +// TODO(hongjunchoi): Add logic to parse XlaSharding op inside control flow (If, +// Case, While) ops and Caller return values. // TODO(hongjunchoi): Consider explicitly checking op patterns to detect sharded // inputs. -void GetAdjacentXlaShardingOp(Operation* op, - llvm::Optional* sharding_op) { - // TODO(hongjunchoi): Detect the case when sharding configuration is ambiguous - // for a single input (i.e. multiple different XlaSharding ops with different - // configuration policies are connected). 
- if (sharding_op->hasValue()) return; +llvm::Optional GetXlaShardingFromArg(const Value& value) { + llvm::SmallPtrSet visited_values; + llvm::SmallVector values_to_visit{value}; + while (!values_to_visit.empty()) { + llvm::SmallVector next_values_to_visit; + for (Value value_to_visit : values_to_visit) { + if (!visited_values.insert(value_to_visit).second) continue; - if (auto sharding = llvm::dyn_cast(op)) { - sharding_op->emplace(sharding); - return; - } + for (auto& use : value_to_visit.getUses()) { + Operation* owner = use.getOwner(); + if (auto sharding = llvm::dyn_cast(owner)) + return sharding._XlaSharding(); - if (llvm::isa(op)) { - for (auto user : op->getUsers()) - GetAdjacentXlaShardingOp(user, sharding_op); - } -} + if (llvm::isa(owner)) { + next_values_to_visit.push_back(use.getOwner()->getResult(0)); + continue; + } -// Parses XlaSharding op connected to input args. If Input to -// tf_device.ClusterFunc op is of resource type, then XlaSharding op will be -// connected to following ReadVariable op. -// -// TODO(hongjunchoi): Add logic to parse XlaSharding op inside a Call op or -// If/While op. -llvm::Optional ParseInputSharding(const Value& arg) { - llvm::Optional parsed_sharding_op; - for (auto user : arg.getUsers()) { - if (parsed_sharding_op) continue; - - GetAdjacentXlaShardingOp(user, &parsed_sharding_op); - if (parsed_sharding_op) continue; - - if (llvm::isa(user)) - for (auto read_variable_user : user->getUsers()) - GetAdjacentXlaShardingOp(read_variable_user, &parsed_sharding_op); - } - - if (!parsed_sharding_op) return llvm::Optional(); - return parsed_sharding_op.getValue()._XlaSharding(); -} - -// Returns the provided sharding configuration if operand of return value of -// tf_device.ClusterFunc op is directly from XlaSharding op, -llvm::Optional ParseReturnValueSharding(FuncOp func, - const int output_index, - const OpOperand& operand) { - if (auto sharding_op = llvm::dyn_cast_or_null( - operand.get().getDefiningOp())) - return sharding_op._XlaSharding(); - - return llvm::Optional(); -} - -// Includes information on Func op and argument index of the input value. This -// is used to trace Value that is fed into function call ops. -struct FunctionAndArgumentInfo { - FuncOp func; - int argument_index; -}; - -// Adds tf.PartitionedCall op or tf.StatefulPartitionedCall op to `list`. If -// `op` is a function call op, then find the func op from provided `module` and -// add the func op with `arg_index` to `list`. `list` will later be used to -// trace mlir::Value that is fed into (potentially nested) function call ops. -void AddFunctionalOpsToList( - const int arg_index, ModuleOp module, Operation* op, - llvm::SmallVectorImpl* list) { - if (auto pcall_op = llvm::dyn_cast(op)) { - if (!pcall_op.f().isa()) return; - - auto pcall_func = llvm::cast( - module.lookupSymbol(pcall_op.f().getRootReference())); - assert(pcall_func); - list->emplace_back(FunctionAndArgumentInfo{pcall_func, arg_index}); - - } else if (auto spcall_op = - llvm::dyn_cast(op)) { - auto sp_call_func = llvm::cast(module.lookupSymbol(spcall_op.f())); - assert(sp_call_func); - list->emplace_back(FunctionAndArgumentInfo{sp_call_func, arg_index}); - } -} - -// Walks the MLIR graph from `arg` and return a list of all function call ops to -// which the `arg` op is directly connected. -// -// For example: -// argument0 -> PartitionedCallOp -> StatefulPartitionedCallOp -> AddOp -// -// For above case, PartitionedCall op and StatefulPartitionedCallOp will be -// returned. 
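GetXlaShardingFromArg replaces the old per-function traversal with a forward breadth-first search over uses: it follows Identity, Cast, and ReadVariableOp results, steps into callee arguments at call sites, and stops at the first XlaSharding op. The worklist sketch below models that search over a toy use graph; value names and the sharding string are illustrative placeholders.

```cpp
#include <iostream>
#include <map>
#include <optional>
#include <set>
#include <string>
#include <vector>

enum class Kind { kSharding, kPassThrough /*Identity, Cast, ReadVariable*/, kCall, kOther };

struct ToyUse {
  Kind kind;
  std::string sharding;                  // set when kind == kSharding
  std::vector<std::string> next_values;  // values to keep following
};

// value name -> its uses
using UseGraph = std::map<std::string, std::vector<ToyUse>>;

std::optional<std::string> GetShardingFromArg(const UseGraph& graph,
                                              const std::string& arg) {
  std::set<std::string> visited;
  std::vector<std::string> worklist{arg};
  while (!worklist.empty()) {
    std::vector<std::string> next;
    for (const std::string& value : worklist) {
      if (!visited.insert(value).second) continue;
      auto it = graph.find(value);
      if (it == graph.end()) continue;
      for (const ToyUse& use : it->second) {
        if (use.kind == Kind::kSharding) return use.sharding;
        if (use.kind == Kind::kPassThrough || use.kind == Kind::kCall)
          next.insert(next.end(), use.next_values.begin(),
                      use.next_values.end());
        // kOther: stop following this edge.
      }
    }
    worklist.swap(next);
  }
  return std::nullopt;
}

int main() {
  // arg0 -> ReadVariableOp -> %read -> (call into callee arg) -> XlaSharding
  UseGraph graph = {
      {"arg0", {{Kind::kPassThrough, "", {"%read"}}}},
      {"%read", {{Kind::kCall, "", {"%callee_arg"}}}},
      {"%callee_arg", {{Kind::kSharding, "sharding_proto_bytes", {}}}},
  };
  auto sharding = GetShardingFromArg(graph, "arg0");
  std::cout << (sharding ? *sharding : std::string("<none>")) << "\n";
  return 0;
}
```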
-llvm::SmallVector ExtractFunctionsConnectedToArg( - BlockArgument arg, ModuleOp module) { - llvm::SmallVector functions_connected_to_arg; - for (auto& arg_use : arg.getUses()) - AddFunctionalOpsToList(arg_use.getOperandNumber(), module, - arg_use.getOwner(), &functions_connected_to_arg); - - llvm::SmallVector functions_to_parse{ - functions_connected_to_arg.begin(), functions_connected_to_arg.end()}; - - while (!functions_to_parse.empty()) { - llvm::SmallVector newly_discovered_functions; - for (auto function_info : functions_to_parse) { - Block& func_entry_block = function_info.func.front(); - auto argument = - func_entry_block.getArgument(function_info.argument_index); - - for (auto& arg_use : argument.getUses()) - AddFunctionalOpsToList(arg_use.getOperandNumber(), module, - arg_use.getOwner(), &newly_discovered_functions); + if (auto call_op = llvm::dyn_cast(owner)) { + FuncOp func = llvm::dyn_cast(call_op.resolveCallable()); + if (!func) continue; + next_values_to_visit.push_back( + func.getArgument(use.getOperandNumber())); + } + } } - functions_connected_to_arg.append(newly_discovered_functions.begin(), - newly_discovered_functions.end()); - std::swap(functions_to_parse, newly_discovered_functions); + values_to_visit.swap(next_values_to_visit); } - return functions_connected_to_arg; + return llvm::None; } // Walks the graph from the arguments of the `cluster_func_op` and extracts @@ -186,7 +101,6 @@ void IdentifyXlaShardingForComputationInputs( FuncOp cluster_function, Builder* builder) { // Look up function definition from module. Block& cluster_function_block = cluster_function.front(); - ModuleOp module = cluster_func_op.getParentOfType(); llvm::SmallVector sharding_for_args( cluster_function_block.getNumArguments(), logical_core_0_sharding); @@ -202,31 +116,17 @@ void IdentifyXlaShardingForComputationInputs( // Sharding configurations are added to the tf_device.ClusterFunc as an // attribute and the function as an argument attribute. for (auto& arg : cluster_function_block.getArguments()) { - auto arg_sharding = ParseInputSharding(arg); - const int arg_index_to_tpu_computation = arg.getArgNumber(); - - if (!arg_sharding.hasValue()) { - auto connected_functions_to_arg = - ExtractFunctionsConnectedToArg(arg, module); - for (auto& function_arg_info : connected_functions_to_arg) { - if (arg_sharding.hasValue()) break; - - const int function_argument_index = function_arg_info.argument_index; - auto& parsed_function = function_arg_info.func; - Block& parsed_function_block = parsed_function.front(); - arg_sharding = ParseInputSharding( - parsed_function_block.getArgument(function_argument_index)); - } - } + auto arg_sharding = GetXlaShardingFromArg(arg); + const int index = arg.getArgNumber(); if (arg_sharding) { - sharding_for_args[arg_index_to_tpu_computation] = arg_sharding.getValue(); + sharding_for_args[index] = arg_sharding.getValue(); cluster_function.setArgAttr( - arg_index_to_tpu_computation, kShardingAttr, + index, kShardingAttr, builder->getStringAttr(arg_sharding.getValue())); } else { cluster_function.setArgAttr( - arg_index_to_tpu_computation, kShardingAttr, + index, kShardingAttr, builder->getStringAttr(logical_core_0_sharding)); } } @@ -235,6 +135,44 @@ void IdentifyXlaShardingForComputationInputs( builder->getStrArrayAttr(sharding_for_args)); } +// Finds XlaSharding op connected to a result value. 
XlaSharding op may be +// direct user of inputs but it may also be followed by an Identity op and, in +// the case where bfloat16 type is used, Cast op may be added right after the +// input. +// +// TODO(hongjunchoi): Add logic to parse XlaSharding op inside control flow (If, +// Case, While) ops and Caller argument values. +// TODO(hongjunchoi): Consider explicitly checking op patterns to detect sharded +// inputs. +llvm::Optional GetXlaShardingFromRetval(const Value& value) { + llvm::SmallPtrSet visited_values; + Value value_to_visit = value; + while (value_to_visit) { + if (!visited_values.insert(value_to_visit).second) return llvm::None; + + Operation* def = value_to_visit.getDefiningOp(); + if (auto sharding = llvm::dyn_cast_or_null(def)) + return sharding._XlaSharding(); + + if (llvm::isa_and_nonnull(def)) { + value_to_visit = def->getOperand(0); + continue; + } + + if (auto call_op = llvm::dyn_cast_or_null(def)) { + FuncOp func = llvm::dyn_cast(call_op.resolveCallable()); + if (!func) continue; + value_to_visit = func.front().getTerminator()->getOperand( + value_to_visit.cast().getResultNumber()); + continue; + } + + break; + } + + return llvm::None; +} + // Parses XlaSharding op directly connected from the outputs of the // `cluster_func` and extract sharding configurations for outputs. void IdentifyXlaShardingForComputationOutputs( @@ -252,8 +190,8 @@ void IdentifyXlaShardingForComputationOutputs( // tf_device.ClusterFunc as an attribute and the function as a result // attribute. for (auto& ret : terminator->getOpOperands()) { + auto ret_sharding = GetXlaShardingFromRetval(ret.get()); const int index = ret.getOperandNumber(); - auto ret_sharding = ParseReturnValueSharding(func, index, ret); if (ret_sharding) { sharding_for_rets[index] = ret_sharding.getValue(); @@ -264,6 +202,7 @@ void IdentifyXlaShardingForComputationOutputs( builder->getStringAttr(logical_core_0_sharding)); } } + cluster_func.setAttr(tensorflow::kOutputShardingAttr, builder->getStrArrayAttr(sharding_for_rets)); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc index 2f1db0899f7..ed4c411aae8 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc @@ -115,9 +115,8 @@ struct TPUSpaceToDepthPass // Updates func argument type to have the updated input shape. void UpdateFuncType(FuncOp func) { - auto arg_types = llvm::to_vector<8>(func.front().getArgumentTypes()); - auto result_types = - llvm::to_vector<4>(func.front().getTerminator()->getOperandTypes()); + auto arg_types = func.front().getArgumentTypes(); + auto result_types = func.front().getTerminator()->getOperandTypes(); func.setType(FunctionType::get(arg_types, result_types, func.getContext())); } @@ -432,9 +431,8 @@ TF::SpaceToDepthOp BuildSpaceToDepth(tf_device::ClusterFuncOp cluster_func, input_shape[3] * block_size * block_size}; auto transform_result_type = RankedTensorType::get(transform_shape, getElementTypeOrSelf(input)); - return builder.create(cluster_func.getLoc(), - transform_result_type, input, - APInt(64, block_size)); + return builder.create( + cluster_func.getLoc(), transform_result_type, input, block_size); } // Performs transformation for a non-replicated input. 
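GetXlaShardingFromRetval is the mirror image: a backward walk from a returned value through Identity/Cast defining ops and, at call sites, into the callee terminator's corresponding operand, with a visited set as a cycle guard. A backward-walk sketch over a toy def graph (again with placeholder names):

```cpp
#include <iostream>
#include <map>
#include <optional>
#include <set>
#include <string>

enum class DefKind { kSharding, kPassThrough /*Identity, Cast*/, kCall, kOther };

struct ToyDef {
  DefKind kind;
  std::string sharding;    // set when kind == kSharding
  std::string prev_value;  // value to keep walking from (producer side)
};

// value name -> its defining op
using DefGraph = std::map<std::string, ToyDef>;

std::optional<std::string> GetShardingFromRetval(const DefGraph& graph,
                                                 std::string value) {
  std::set<std::string> visited;
  while (!value.empty()) {
    if (!visited.insert(value).second) return std::nullopt;  // cycle guard
    auto it = graph.find(value);
    if (it == graph.end()) return std::nullopt;
    const ToyDef& def = it->second;
    switch (def.kind) {
      case DefKind::kSharding:
        return def.sharding;
      case DefKind::kPassThrough:
      case DefKind::kCall:
        value = def.prev_value;  // keep walking toward the producer
        break;
      default:
        return std::nullopt;
    }
  }
  return std::nullopt;
}

int main() {
  // retval <- Identity <- call result <- callee return operand (XlaSharding)
  DefGraph graph = {
      {"%ret", {DefKind::kPassThrough, "", "%call_result"}},
      {"%call_result", {DefKind::kCall, "", "%callee_ret"}},
      {"%callee_ret", {DefKind::kSharding, "sharding_proto_bytes", ""}},
  };
  auto sharding = GetShardingFromRetval(graph, "%ret");
  std::cout << (sharding ? *sharding : std::string("<none>")) << "\n";
  return 0;
}
```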
@@ -458,7 +456,7 @@ bool HandleHostReplicatedInputs(int64_t index, int64_t replicate_arg_index = block_arg.getArgNumber(); // We need to know the devices to copy to. if (!replicate.devices()) return false; - int64_t num_replicas = replicate.n().getZExtValue(); + int64_t num_replicas = replicate.n(); // Gets inputs at replicate_arg_index for each replica. auto inputs = replicate.getOperands() .drop_front(replicate_arg_index * num_replicas) @@ -669,7 +667,6 @@ void TPUSpaceToDepthPass::runOnOperation() { if (!device_func) return; TF::Conv2DOp first_conv; - Optional> input_shape; // A map maps block argument id to the convolutions consumes them. llvm::SmallDenseMap> argnum_and_convolutions; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc index 3262b83fc94..6e5b07526d1 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc @@ -138,14 +138,17 @@ Value SkipIdentity(Value v, bool allow_other_use, // Finds the formattable arguments of `execute` and annotates the metadata of // `compile` to record these arguments. In addition, it returns a mapping from -// the formattable arguments of `execute` to the corresponding arguments of -// `while_op` (which should be passed through to `execute` via `replicate`). The +// the formattable arguments of `execute` to the corresponding operand of +// `replicate`. The // entries in the mapping are sorted in the order of operands of `execute`. llvm::SmallVector>, 4> AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping( - TF::WhileOp while_op, tf_device::ReplicateOp replicate, + TF::WhileRegionOp while_op, tf_device::ReplicateOp replicate, TF::TPUExecuteAndUpdateVariablesOp execute, - tf_device::LaunchOp compile_launch, FuncOp body, FuncOp cond) { + tf_device::LaunchOp compile_launch) { + Region& body = while_op.body(); + Region& cond = while_op.cond(); + llvm::SmallVector>, 4> mapping; auto mirrored_variable_indices_attr = replicate.getAttrOfType(kMirroredVariableIndicesAttr); @@ -174,7 +177,7 @@ AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping( assert(metadata_str && "Missing compilation metadata"); tensorflow::tpu::TPUCompileMetadataProto metadata; metadata.ParseFromString(std::string(metadata_str.getValue())); - int64_t num_replicas = replicate.n().getLimitedValue(); + int64_t num_replicas = replicate.n(); // Find the formattable operands of `execute`, which must be mirrored // variables (arguments of `replicate`), and must be pass-throughs from while // operands. @@ -204,39 +207,43 @@ AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping( // We have found a mirrored variable which is an input to the replicated // `execute`. Now find if this mirrored variable is a pass-through of while // arguments. 
- llvm::SmallVector while_args; + llvm::SmallVector replicate_args; for (int64_t i = 0; i < num_inputs; ++i) { llvm::SmallPtrSet skipped_identities; auto replicate_operand = SkipIdentity( replicate.GetReplicaOperandForBlockArgument(block_arg, i), /*allow_other_use=*/false, &skipped_identities); - auto block_arg = replicate_operand.dyn_cast(); - // To qualify for a valid pass-through mirrored variable, it must satisfy - // 1) it is the body's argument; - // 2) it has no other uses than `replicate`, the skipped identitiy ops, - // or the return; - // 3) the corresponding argument in the cond function has no uses. - if (!block_arg || block_arg.getOwner() != &body.front() || - llvm::any_of(replicate_operand.getUsers(), - [&](Operation* user) { - return user != body.front().getTerminator() && - skipped_identities.count(user) == 0 && - user != replicate; - }) || - !cond.getArgument(block_arg.getArgNumber()).use_empty()) { - while_args.clear(); + // For region based control flow, the resource operand for the replicate + // should be a region capture. If this has any use other than the + // replicate op (within the body of the while) or the skipped identities, + // then do not apply the transformation to this variable. + bool is_region_capture = + replicate_operand.getParentRegion()->isProperAncestor(&body); + bool has_other_use_in_body = + llvm::any_of(replicate_operand.getUsers(), [&](Operation* user) { + // Ignore uses that are not in the while body or condition. + if (!body.isAncestor(user->getParentRegion()) && + !cond.isAncestor(user->getParentRegion())) + return false; + // Within the body or cond, only uses in replicate and the skipped + // identities is allowed. + return user != replicate && skipped_identities.count(user) == 0; + }); + + if (!is_region_capture || has_other_use_in_body) { + replicate_args.clear(); break; } - while_args.push_back(while_op.getOperand(block_arg.getArgNumber())); + replicate_args.push_back(replicate_operand); } - if (while_args.empty()) continue; + if (replicate_args.empty()) continue; // Now set the enable_xla_sharding field in the metadata to inform the // compile op. auto metadata_arg = metadata.mutable_args(it->second); metadata_arg->set_enable_xla_sharding( ::tensorflow::tpu::TPUCompileMetadataProto_Arg::ALLOWED); - mapping.emplace_back(it->second, std::move(while_args)); + mapping.emplace_back(it->second, std::move(replicate_args)); } // Sort the mapping according to execute operand order. llvm::sort(mapping, llvm::less_first()); @@ -261,10 +268,11 @@ AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping( // Adds a new replicated input to the replicate op. 
tf_device::ReplicateOp AddInputsToReplicateOp( - tf_device::ReplicateOp replicate, ArrayRef new_inputs, + tf_device::ReplicateOp replicate, + MutableArrayRef new_inputs, const llvm::SmallDenseMap>& devices) { - int64_t num_replicas = replicate.n().getLimitedValue(); + int64_t num_replicas = replicate.n(); assert(new_inputs.size() == num_replicas); // As model parallelism is not yet supported, we assume that all ops are @@ -275,8 +283,7 @@ tf_device::ReplicateOp AddInputsToReplicateOp( ->getSecond() .size() == num_replicas); - llvm::SmallVector, Type>, 8> - new_replicated_inputs; + llvm::SmallVector, 8> new_replicated_inputs; llvm::SmallVector new_packed_inputs; llvm::SmallVector, 8> replicated_inputs; replicated_inputs.reserve(replicate.GetNumReplicatedBlockArguments()); @@ -293,13 +300,16 @@ tf_device::ReplicateOp AddInputsToReplicateOp( new_packed_inputs.emplace_back( replicate.GetReplicaOperandForBlockArgument(arg, /*replica=*/0)); } - new_replicated_inputs.emplace_back(new_inputs, new_inputs.front().getType()); + SmallVector new_input_values; + new_input_values.reserve(new_inputs.size()); + for (auto var : new_inputs) new_input_values.push_back(var.resource()); + new_replicated_inputs.emplace_back(new_input_values, + new_input_values.front().getType()); OpBuilder builder(replicate); auto new_replicate = builder.create( replicate.getLoc(), num_replicas, devices, new_replicated_inputs, new_packed_inputs, - llvm::to_vector<8>( - replicate.GetBody().getTerminator()->getOperandTypes())); + replicate.GetBody().getTerminator()->getOperandTypes()); for (auto arg : replicate.GetBody().getArguments()) { if (replicate.IsReplicatedBlockArgument(arg)) { arg.replaceAllUsesWith( @@ -319,58 +329,6 @@ tf_device::ReplicateOp AddInputsToReplicateOp( return new_replicate; } -// Adds the per-device state variables to the while-loop's inputs/outputs. -TF::WhileOp AddStateVarsToWhileOp(TF::WhileOp while_op, FuncOp body, - FuncOp cond, - ArrayRef state_vars) { - auto body_return = llvm::cast(body.front().back()); - auto new_body_return_vals = llvm::to_vector<4>(body_return.getOperands()); - auto new_while_operands = llvm::to_vector<4>(while_op.getOperands()); - auto append_types = [&](ArrayRef types) { - auto new_types = llvm::to_vector<4>(types); - for (auto state_var : state_vars) { - new_types.push_back(state_var.resource().getType()); - } - return new_types; - }; - for (auto state_var : state_vars) { - body.front().addArgument(state_var.resource().getType()); - cond.front().addArgument(state_var.resource().getType()); - auto inner_arg = body.getArgument(body.front().getNumArguments() - 1); - new_body_return_vals.push_back(inner_arg); - new_while_operands.push_back(state_var.resource()); - } - OpBuilder builder = OpBuilder::atBlockEnd(&body.front()); - // Update return values. 
- builder.create(body_return.getLoc(), new_body_return_vals); - body_return.erase(); - - body.setType(FunctionType::get(append_types(body.getType().getInputs()), - append_types(body.getType().getResults()), - body.getContext())); - cond.setType(FunctionType::get(append_types(cond.getType().getInputs()), - cond.getType().getResults(), - cond.getContext())); - for (int64_t i = 0, end = state_vars.size(); i < end; ++i) { - int64_t arg_index = body.getNumArguments() - state_vars.size() + i; - TF::VarHandleOp state_var = state_vars[i]; - auto device_attr = state_var.getAttr(kDeviceAttr); - if (device_attr) { - body.setArgAttr(arg_index, kFuncDeviceAttr, device_attr); - cond.setArgAttr(arg_index, kFuncDeviceAttr, device_attr); - } - } - builder.setInsertionPoint(while_op); - auto new_while_op = builder.create( - while_op.getLoc(), - append_types(llvm::to_vector<4>(while_op.getResultTypes())), - new_while_operands, while_op.getAttrs()); - while_op.replaceAllUsesWith( - new_while_op.getResults().take_front(while_op.getNumResults())); - while_op.erase(); - return new_while_op; -} - // Creates the per-device variables that represent the formatting state of each // device. llvm::SmallVector CreateStateVars( @@ -421,9 +379,9 @@ void WrapOpInLaunch(OpBuilder* builder, Location loc, Operation* op, } // Performs the transformation for a replicate op inside a while loop. -void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate, - MLIRContext* context) { - int64_t num_replicas = replicate.n().getLimitedValue(); +void HandleReplicateOp(TF::WhileRegionOp while_op, + tf_device::ReplicateOp replicate) { + int64_t num_replicas = replicate.n(); if (num_replicas == 1) return; tf_device::LaunchOp execute_launch; for (auto execute_launch_op : @@ -452,13 +410,10 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate, !llvm::isa(compile_launch.GetBody().front())) return; - FuncOp body = while_op.body_func(); - FuncOp cond = while_op.cond_func(); - // Analyze the formattable inputs. auto execute_arg_to_outer_args = AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping( - while_op, replicate, execute, compile_launch, body, cond); + while_op, replicate, execute, compile_launch); if (execute_arg_to_outer_args.empty()) return; // Extract the replicated devices. @@ -489,16 +444,7 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate, RankedTensorType::get({2}, TF::StringType::get(builder.getContext())); auto state_vars = CreateStateVars(devices, while_op.getLoc(), key_type, &builder); - while_op = AddStateVarsToWhileOp(while_op, body, cond, state_vars); - // Add the new while loop inputs to the replicate op inside the body. - int64_t new_while_operand_count = while_op.getNumOperands(); - llvm::SmallVector inner_state_vars; - for (int64_t i = new_while_operand_count - num_replicas; - i < new_while_operand_count; ++i) { - inner_state_vars.push_back(body.front().getArgument(i)); - } - - replicate = AddInputsToReplicateOp(replicate, inner_state_vars, devices); + replicate = AddInputsToReplicateOp(replicate, state_vars, devices); // Build the reformat according to the compilation. Build it inside // `replicate`. llvm::SmallVector reformat_operands; @@ -516,8 +462,7 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate, // Build the replicated unformat op after the loop. First prepare building the // replicate op. 
- llvm::SmallVector, Type>, 8> - unformat_replicate_operands; + llvm::SmallVector, 8> unformat_replicate_operands; llvm::SmallVector unformat_packed_operands; for (const auto& entry : execute_arg_to_outer_args) { if (entry.second.size() > 1) { @@ -537,16 +482,17 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate, // Build a constant default key to specify that the unformatting should // transform the variables to the original format. builder.setInsertionPointAfter(while_op); - tensorflow::Tensor default_key_tensor(tensorflow::DT_STRING, {2}); + tensorflow::Tensor default_key_tensor(tensorflow::DT_STRING, {3}); default_key_tensor.vec()(0) = kDefaultShardingValue; default_key_tensor.vec()(1) = kDefaultShardingValue; + default_key_tensor.vec()(2) = kDefaultShardingValue; auto default_state_key = builder.create( while_op.getLoc(), tensorflow::ConvertTensor(default_key_tensor, &builder).ValueOrDie()); // With all replicated inputs, now build the replicate op. auto unformat_replicate = builder.create( while_op.getLoc(), num_replicas, devices, unformat_replicate_operands, - unformat_packed_operands, ArrayRef{}); + unformat_packed_operands, TypeRange{}); // Then build the unformat op in the replicate op. builder.setInsertionPointToEnd(&unformat_replicate.GetBody()); llvm::SmallVector unformat_operands; @@ -575,10 +521,9 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate, void TPUVariableRuntimeReformattingPass::runOnOperation() { auto module = getOperation(); - module.walk([&](TF::WhileOp while_op) { - auto body = llvm::cast(module.lookupSymbol(while_op.body())); + module.walk([&](TF::WhileRegionOp while_op) { tf_device::ReplicateOp replicate; - body.walk([&](tf_device::ReplicateOp replicate_op) { + while_op.body().walk([&](tf_device::ReplicateOp replicate_op) { if (replicate == nullptr) { replicate = replicate_op; return WalkResult::advance(); @@ -591,7 +536,7 @@ void TPUVariableRuntimeReformattingPass::runOnOperation() { // `tf_device.parallel_execute` op in the `tf_device.replicate` is present. 
if (replicate && replicate.GetBody().getOps().empty()) - HandleReplicateOp(while_op, replicate, &getContext()); + HandleReplicateOp(while_op, replicate); }); } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc index 0a69987deb0..e9cea13f550 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc @@ -43,6 +43,10 @@ namespace { class BreakUpIslands : public TF::PerFunctionAggregateAnalysisConsumerPass< BreakUpIslands, TF::SideEffectAnalysis> { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + public: void runOnFunction(FuncOp func, const TF::SideEffectAnalysis::Info& side_effect_analysis); @@ -126,18 +130,15 @@ void PopulateEmptyIsland(tf_executor::IslandOp island) { OpBuilder builder(&island.GetBody(), island.GetBody().begin()); tf_executor::YieldOp yield = island.GetYield(); if (yield.getNumOperands() == 0) { - builder.create(island.getLoc(), llvm::ArrayRef{}, - llvm::ArrayRef{}, - llvm::ArrayRef{}); + builder.create(island.getLoc(), TypeRange{}, ValueRange{}); } else if (yield.getNumOperands() == 1) { Value operand = yield.getOperand(0); auto identity = builder.create(island.getLoc(), operand.getType(), operand); yield.setOperand(0, identity.output()); } else { - auto types = llvm::to_vector<4>(yield.getOperandTypes()); - auto identity_n = builder.create(island.getLoc(), types, - yield.getOperands()); + auto identity_n = builder.create( + island.getLoc(), yield.getOperandTypes(), yield.getOperands()); for (auto it : llvm::enumerate(identity_n.getResults())) yield.setOperand(it.index(), it.value()); } @@ -145,8 +146,8 @@ void PopulateEmptyIsland(tf_executor::IslandOp island) { // Helper that creates an island. If `sub_op` is not nullptr, it will be moved // to the island. Otherwise a NoOp will be added to the island. -tf_executor::IslandOp CreateIsland(ArrayRef result_types, - ArrayRef control_inputs, +tf_executor::IslandOp CreateIsland(TypeRange result_types, + ValueRange control_inputs, const tf_executor::ControlType& control_type, const Location& loc, Operation* sub_op, tf_executor::IslandOp original_island) { @@ -162,10 +163,8 @@ tf_executor::IslandOp CreateIsland(ArrayRef result_types, sub_op->moveBefore(block, block->begin()); island_builder.create(loc, sub_op->getResults()); } else { - island_builder.create( - island.getLoc(), llvm::ArrayRef{}, - llvm::ArrayRef{}, llvm::ArrayRef{}); - island_builder.create(loc, ArrayRef{}); + island_builder.create(island.getLoc(), TypeRange{}, ValueRange{}); + island_builder.create(loc, ValueRange{}); } return island; } @@ -278,8 +277,8 @@ void BreakUpIslands::BreakUpIsland( ? 
island_control_inputs : predecessor_controls; auto new_island = - CreateIsland(llvm::to_vector<4>(sub_op.getResultTypes()), control, - control_type, sub_op.getLoc(), &sub_op, island_op); + CreateIsland(sub_op.getResultTypes(), control, control_type, + sub_op.getLoc(), &sub_op, island_op); new_control_for_sub_ops[&sub_op] = new_island.control(); if (sources_and_sinks.sinks.count(&sub_op)) { sink_island_controls.push_back(new_island.control()); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/derived_attr_populator_gen.cc b/tensorflow/compiler/mlir/tensorflow/translate/derived_attr_populator_gen.cc deleted file mode 100644 index e4c965b6cb1..00000000000 --- a/tensorflow/compiler/mlir/tensorflow/translate/derived_attr_populator_gen.cc +++ /dev/null @@ -1,196 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "llvm/ADT/StringExtras.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/InitLLVM.h" -#include "llvm/Support/Signals.h" -#include "llvm/Support/raw_ostream.h" -#include "llvm/TableGen/Error.h" -#include "llvm/TableGen/Main.h" -#include "llvm/TableGen/Record.h" -#include "llvm/TableGen/TableGenBackend.h" -#include "mlir/TableGen/Operator.h" // from @llvm-project - -using llvm::LessRecord; -using llvm::raw_ostream; -using llvm::Record; -using llvm::RecordKeeper; -using mlir::tblgen::Operator; - -// Helper macro that returns indented os. -#define OUT(X) os.indent((X)) - -// Emits TensorFlow derived attribute populator functions for each of the ops. -static void EmitOpAttrPopulators(const std::vector &ops, - raw_ostream *ostream) { - raw_ostream &os = *ostream; - - for (const auto &op : ops) { - // TODO(hinsu): Introduce a derived attribute property for ops with no - // type attributes. That way an error can be generated if no derived type - // attribute or the property is set. This will make sure derived type - // attributes are not omitted by mistake. - - // Emit function signature. - auto op_name = op.getCppClassName(); - OUT(0) << "static Status Populate" << op_name - << "DerivedAttrs(mlir::TF::" << op_name - << "& op, AttrValueMap *values) {\n"; - - for (const auto &named_attr : op.getAttributes()) { - auto attr_name = named_attr.name; - const auto &attr = named_attr.attr; - if (!attr.isDerivedAttr()) continue; - auto retType = attr.getReturnType(); - if (retType == "ShapedType" || retType == "mlir::TF::OperandShapeRange" || - retType == "mlir::TF::ResultShapeRange") { - OUT(2) << "TF_RETURN_IF_ERROR(SetShapeAttribute(\"" << attr_name - << "\", op." << attr_name << "(), values));\n"; - } else if (retType == "Type" || - retType == "mlir::OperandElementTypeRange" || - retType == "mlir::ResultElementTypeRange") { - OUT(2) << "TF_RETURN_IF_ERROR(SetTypeAttribute(\"" << attr_name - << "\", op." 
<< attr_name << "(), values));\n"; - } else if (attr.isSubClassOf("TF_DerivedOperandSizeAttr") || - attr.isSubClassOf("TF_DerivedResultSizeAttr")) { - OUT(2) << "TF_RETURN_IF_ERROR(SetSizeAttribute(\"" << attr_name - << "\", op." << attr_name << "(), values));\n"; - } else { - PrintFatalError(op.getLoc(), - "NYI: attribute populator for derived attributes"); - } - } - - OUT(2) << "return Status::OK();\n"; - OUT(0) << "}\n\n"; - } -} - -// Emits TensorFlow derived attribute populator function taking an Operation -// as argument. -static void EmitInstAttrPopulator(const std::vector &ops, - raw_ostream *ostream) { - raw_ostream &os = *ostream; - - // Emit function signature. - OUT(0) << "static Status PopulateDerivedAttrs(mlir::Operation* op, " - "AttrValueMap* values) {\n"; - - for (const auto &op : ops) { - auto op_name = op.getCppClassName(); - - // Emit conditional for the op and then call populator for the op on match. - OUT(2) << "if (auto tfOp = llvm::dyn_cast(op)) {\n"; - OUT(4) << "TF_RETURN_IF_ERROR(Populate" << op_name - << "DerivedAttrs(tfOp, values));\n"; - OUT(2) << "}\n"; - } - OUT(2) << "return Status::OK();\n"; - OUT(0) << "}\n\n"; -} - -// Emits TensorFlow derived attribute name collector functions for each of the -// ops. -static void EmitOpAttrNameCollector(const std::vector &ops, - raw_ostream *ostream) { - raw_ostream &os = *ostream; - - for (const auto &op : ops) { - // Emit function signature. - auto op_name = op.getCppClassName(); - OUT(0) << "static void Collect" << op_name - << "DerivedAttrsName(mlir::TF::" << op_name - << "& op, llvm::SmallDenseSet* values) {\n"; - - // Insert the name for each derived attribute in the set. - for (const auto &named_attr : op.getAttributes()) { - auto attr_name = named_attr.name; - const auto &attr = named_attr.attr; - if (!attr.isDerivedAttr()) continue; - OUT(2) << "values->insert(\"" << attr_name << "\");\n"; - } - - OUT(2) << "return;\n"; - OUT(0) << "}\n\n"; - } -} - -// Emits TensorFlow derived attribute name collector function taking an -// Operation as argument. -static void EmitInstAttrNameCollector(const std::vector &ops, - raw_ostream *ostream) { - raw_ostream &os = *ostream; - - // Emit function signature. - OUT(0) << "static void CollectDerivedAttrsName(mlir::Operation* op, " - "llvm::SmallDenseSet* values) {\n"; - - for (const auto &op : ops) { - auto op_name = op.getCppClassName(); - - // Emit conditional for the op and then call collect for the op on match. - OUT(2) << "if (auto tf_op = llvm::dyn_cast(op)) {\n"; - OUT(4) << "Collect" << op_name << "DerivedAttrsName(tf_op, values);\n"; - OUT(2) << "}\n"; - } - OUT(2) << "return;\n"; - OUT(0) << "}\n\n"; -} - -// The function below has a non-constant reference as that is required by LLVM's -// TableGenMain. -// NOLINTNEXTLINE -static bool DerivedAttrWritersMain(raw_ostream &os, RecordKeeper &records) { - emitSourceFileHeader("MLIR Derived TensorFlow Attributes Populators", os); - - // Retrieve all the definitions derived from TF_Op and sort by record name. - std::vector defs = records.getAllDerivedDefinitions("TF_Op"); - llvm::sort(defs, LessRecord()); - - std::vector ops; - ops.reserve(defs.size()); - - // Wrap TensorFlow op definitions into tblgen Operator wrapper and verify - // them. 
- for (const auto *def : defs) { - ops.emplace_back(Operator(def)); - - const Operator &op = ops.back(); - if (op.getDialectName() != "tf") - PrintFatalError(op.getLoc(), - "unexpected op name format: 'TF_' prefix missing"); - if (!op.getCppClassName().endswith("Op")) - PrintFatalError(op.getLoc(), - "unexpected op name format: 'Op' suffix missing"); - } - - EmitOpAttrPopulators(ops, &os); - EmitInstAttrPopulator(ops, &os); - - EmitOpAttrNameCollector(ops, &os); - EmitInstAttrNameCollector(ops, &os); - - return false; -} - -int main(int argc, char **argv) { - llvm::InitLLVM y(argc, argv); - llvm::cl::ParseCommandLineOptions(argc, argv); - return TableGenMain(argv[0], &DerivedAttrWritersMain); -} diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc index 631553b381e..c69e802994d 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc @@ -49,6 +49,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" +#include "tensorflow/compiler/mlir/utils/name_utils.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_to_functiondef.h" @@ -80,46 +81,14 @@ constexpr char kInvalidExecutorGraphMsg[] = constexpr char kDeviceAttr[] = "tf.device"; constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id"; -bool IsLegalChar(char c, bool first_char) { - if (isalpha(c)) return true; - if (isdigit(c)) return true; - if (c == '.') return true; - if (c == '_') return true; - - // First character of a node name can only be a letter, digit, dot or - // underscore. - if (first_char) return false; - - if (c == '/') return true; - if (c == '-') return true; - - return false; -} - -// Convert characters in name that are considered illegal in TensorFlow Node -// name to '.'. -std::string LegalizeNodeName(llvm::StringRef name) { - assert(!name.empty() && "expected non-empty name"); - - std::string legalized_name; - bool first = true; - for (auto c : name) { - if (IsLegalChar(c, first)) { - legalized_name += c; - } else { - legalized_name += '.'; - } - first = false; - } - - return legalized_name; -} - // OpOrArgLocNameMapper that legalizes the returned name. class LegalizedOpOrValLocNameMapper : public OpOrArgLocNameMapper { private: std::string GetName(OpOrVal op_or_val) override { - return LegalizeNodeName(OpOrArgLocNameMapper::GetName(op_or_val)); + std::string name = OpOrArgLocNameMapper::GetName(op_or_val); + assert(!name.empty() && "expected non-empty name"); + mlir::LegalizeNodeName(name); + return name; } }; @@ -275,6 +244,7 @@ StatusOr> Exporter::GetArgumentNode( func.getArgAttrs(index); absl::flat_hash_set attrs_to_ignore = {kDeviceAttr}; TF_RETURN_IF_ERROR(ConvertAttributes(func_arg_i_attrs, attrs_to_ignore, + /*remove_ref_type=*/false, node_def->mutable_attr())); return node_def; @@ -523,13 +493,14 @@ StatusOr> Exporter::Convert( if (index >= num_data_results) break; // TODO(jpienaar): If there is a result index specified, ensure only one // and that it matches the result index of the op. 
- std::string orig_name(output_names[index]); - auto tensor_id = ParseTensorName(orig_name); - auto name = LegalizeNodeName( - llvm::StringRef(tensor_id.node().data(), tensor_id.node().size())); + std::string name(output_names[index]); + auto tensor_id = ParseTensorName(name); + std::string tensor_id_node(tensor_id.node()); + assert(!tensor_id_node.empty() && "expected non-empty name"); + mlir::LegalizeNodeName(tensor_id_node); // Ensure name does not get reused. - (void)exporter.op_to_name_.GetUniqueName(name); + (void)exporter.op_to_name_.GetUniqueName(tensor_id_node); } } @@ -537,8 +508,9 @@ StatusOr> Exporter::Convert( TF_RET_CHECK(input_names.size() == block.getNumArguments()); for (const auto& it : llvm::enumerate(function.getArguments())) { // TODO(lyandy): Update when changing feed/fetch import. - std::string orig_name(input_names[it.index()]); - std::string name = LegalizeNodeName(orig_name); + std::string name(input_names[it.index()]); + assert(!name.empty() && "expected non-empty name"); + mlir::LegalizeNodeName(name); auto tensor_id = ParseTensorName(name); TF_RET_CHECK(tensor_id.index() == 0) << "input port designation not supported"; @@ -690,8 +662,9 @@ Status Exporter::ConvertLibFunction(const GraphExportConfig& configs, grad_string.data(), stateful_string.data()}; llvm::SmallVector funcAttrs( function.getDialectAttrs()); - TF_RETURN_IF_ERROR( - ConvertAttributes(funcAttrs, attrs_to_ignore, func_def.mutable_attr())); + TF_RETURN_IF_ERROR(ConvertAttributes(funcAttrs, attrs_to_ignore, + /*remove_ref_type=*/false, + func_def.mutable_attr())); for (int i = 0, e = function.getNumArguments(); i < e; ++i) { if (auto resource_arg_unique_id_attr = @@ -708,6 +681,7 @@ Status Exporter::ConvertLibFunction(const GraphExportConfig& configs, kDeviceAttr, kResourceArgUniqueIdAttr}; FunctionDef::ArgAttrs func_def_arg_i_attrs; TF_RETURN_IF_ERROR(ConvertAttributes(func_arg_i_attrs, attrs_to_ignore, + /*remove_ref_type=*/false, func_def_arg_i_attrs.mutable_attr())); if (func_def_arg_i_attrs.attr().empty()) continue; (*func_def.mutable_arg_attr())[i] = std::move(func_def_arg_i_attrs); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc index 3ca06e5efa9..0057e498cea 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc @@ -21,12 +21,15 @@ limitations under the License. #include "absl/strings/string_view.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/StringRef.h" -#include "llvm/ADT/StringSet.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h" +#include "tensorflow/compiler/mlir/utils/string_container_utils.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" @@ -34,7 +37,6 @@ limitations under the License. 
namespace tensorflow { namespace { -using stream_executor::port::StatusOr; // Sets type list attribute with the given `name` to the given `types`. If the // attribute already exists with a different value, returns an error. @@ -85,22 +87,12 @@ Status SetShapeAttribute(absl::string_view name, ContainerT shapes, return Status::OK(); } -// Include the auto generated derived attribute populator function taking -// TensorFlow dialect operation as an argument. This file contains the function -// definitions and isn't a header file. -#include "tensorflow/compiler/mlir/tensorflow/translate/derived_attr_populator.inc" - -// Collect all the unregistered attributes for an TF dialect operation. +// Collects all the unregistered attributes for an TF dialect operation. // Attributes "name" and "device" are not included because they are not part // of an TF op attributes. Status GetUnregisteredAttrs( - mlir::Operation* inst, + mlir::Operation* inst, const tensorflow::OpRegistrationData* op_reg_data, absl::flat_hash_set* attrs_to_ignore) { - TF_ASSIGN_OR_RETURN(auto op_name, - GetTensorFlowOpName(inst->getName().getStringRef())); - - const tensorflow::OpRegistrationData* op_reg_data = - tensorflow::OpRegistry::Global()->LookUp(std::string(op_name)); if (!op_reg_data) { // This is likely a function call node, so we should continue. return Status::OK(); @@ -123,29 +115,27 @@ Status GetUnregisteredAttrs( return Status::OK(); } -} // namespace - -StatusOr> ConvertTFDialectOpToNodeDef( - mlir::Operation* inst, llvm::StringRef name, +// Collects all attribute names to ignore in an MLIR operation when exporting to +// a TensorFlow NodeDef. +StatusOr> GetAttributesToIgnore( + mlir::Operation* inst, mlir::DictionaryAttr derived_attrs, + const tensorflow::OpRegistrationData* op_reg_data, bool ignore_unregistered_attrs) { - // Use auto generated function to populate derived attribute. - // - // Note: This only populates derived attributes for TensorFlow ops that are - // generated using the TableGen. Manually defined ops should have all the - // attributes present as native MLIR op attributes. - // The elements are owned by the MLIRContext. absl::flat_hash_set attrs_to_ignore; - if (inst->isRegistered()) { - // We ignore attributes attached to the operation when there is already a - // derived attribute defined in ODS. - llvm::SmallDenseSet derived_attrs; - CollectDerivedAttrsName(inst, &derived_attrs); - for (auto name : derived_attrs) attrs_to_ignore.insert(name.data()); + + // We ignore attributes attached to the operation when there is already a + // derived attribute defined in ODS. + if (derived_attrs) { + for (auto derived_attr : derived_attrs) { + attrs_to_ignore.insert( + mlir::StringRefToView(derived_attr.first.strref())); + } } if (ignore_unregistered_attrs) { - TF_RETURN_IF_ERROR(GetUnregisteredAttrs(inst, &attrs_to_ignore)); + TF_RETURN_IF_ERROR( + GetUnregisteredAttrs(inst, op_reg_data, &attrs_to_ignore)); } if (inst->hasTrait()) { @@ -162,15 +152,24 @@ StatusOr> ConvertTFDialectOpToNodeDef( attrs_to_ignore.insert(attr_name.data()); } - TF_ASSIGN_OR_RETURN(auto node_def, - GetOperationNodeDef(attrs_to_ignore, inst, name)); + if (llvm::isa(inst)) + attrs_to_ignore.insert("is_stateless"); - // If the operation is not registered, we won't be able to infer any attribute - if (inst->isRegistered()) { + return attrs_to_ignore; +} + +// Populates all derived attributes of a MLIR operation in a proto +// map. 
+Status PopulateDerivedAttributes(mlir::Operation* inst, llvm::StringRef name, + mlir::DictionaryAttr derived_attrs, + bool ignore_unregistered_attrs, + AttrValueMap* attributes) { + if (derived_attrs) { TF_RETURN_WITH_CONTEXT_IF_ERROR( - PopulateDerivedAttrs(inst, node_def->mutable_attr()), - "When populating derived attrs for ", - inst->getName().getStringRef().str()); + ConvertAttributes(derived_attrs.getValue(), /*attrs_to_ignore=*/{}, + /*remove_ref_type=*/true, attributes), + "while converting derived attributes for node: ", + mlir::StringRefToView(name)); } // Here we only add the shapes for the leading values with ShapedType, @@ -185,10 +184,46 @@ StatusOr> ConvertTFDialectOpToNodeDef( mlir::TF::ResultShapeRange output_shapes = { mlir::TF::ResultShapeIterator(begin), mlir::TF::ResultShapeIterator(end)}; - TF_RETURN_IF_ERROR(SetShapeAttribute("_output_shapes", output_shapes, - node_def->mutable_attr())); + TF_RETURN_IF_ERROR( + SetShapeAttribute("_output_shapes", output_shapes, attributes)); } } + + return Status::OK(); +} + +} // namespace + +Status GetAttrValuesFromOperation( + mlir::Operation* inst, llvm::StringRef name, + const tensorflow::OpRegistrationData* op_reg_data, + bool ignore_unregistered_attrs, AttrValueMap* attributes) { + mlir::DictionaryAttr derived_attrs = nullptr; + if (auto interface = llvm::dyn_cast(inst)) + derived_attrs = interface.materializeDerivedAttributes(); + TF_ASSIGN_OR_RETURN(auto attrs_to_ignore, + GetAttributesToIgnore(inst, derived_attrs, op_reg_data, + ignore_unregistered_attrs)); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + ConvertAttributes(inst->getAttrs(), attrs_to_ignore, + /*remove_ref_type=*/false, attributes), + "while converting attributes for node: ", mlir::StringRefToView(name)); + TF_RETURN_IF_ERROR(PopulateDerivedAttributes( + inst, name, derived_attrs, ignore_unregistered_attrs, attributes)); + return Status::OK(); +} + +StatusOr> ConvertTFDialectOpToNodeDef( + mlir::Operation* inst, llvm::StringRef name, + bool ignore_unregistered_attrs) { + TF_ASSIGN_OR_RETURN(auto node_def, GetOperationNodeDef(inst, name)); + TF_ASSIGN_OR_RETURN(auto op_name, + GetTensorFlowOpName(inst->getName().getStringRef())); + const tensorflow::OpRegistrationData* op_reg_data = + tensorflow::OpRegistry::Global()->LookUp(op_name.str()); + TF_RETURN_IF_ERROR(GetAttrValuesFromOperation(inst, name, op_reg_data, + ignore_unregistered_attrs, + node_def->mutable_attr())); return node_def; } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h index a19ad1f2940..6341b14fe7b 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h @@ -18,12 +18,24 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "mlir/IR/Operation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { -// Converts an MLIR operation to TensorFlow NodeDef with given node name. This +// Extracts the attributes of a MLIR operation and populates the converted +// attributes in a proto map. 
+Status GetAttrValuesFromOperation( + mlir::Operation* inst, llvm::StringRef name, + const tensorflow::OpRegistrationData* op_reg_data, + bool ignore_unregistered_attrs, AttrValueMap* attributes); + +// Converts a MLIR operation to TensorFlow NodeDef with given node name. This // name should be unique to the graph it is being inserted to. If the // `ignore_unregistered_attrs` argument is set to true, the attributes which are // not in the op registry will be ignored. If the `ignore_unregistered_attrs` @@ -31,9 +43,9 @@ namespace tensorflow { // ShapedType for the leading values with ShapedType in the results of the // nodes. Set it to true if the returned NodeDef will be executed by the linked // TF Eager runtime. -stream_executor::port::StatusOr> -ConvertTFDialectOpToNodeDef(mlir::Operation* inst, llvm::StringRef name, - bool ignore_unregistered_attrs); +StatusOr> ConvertTFDialectOpToNodeDef( + mlir::Operation* inst, llvm::StringRef name, + bool ignore_unregistered_attrs); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 692d0eaf962..42ce5c533a2 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -64,6 +64,7 @@ limitations under the License. #include "tensorflow/cc/saved_model/loader_util.h" #include "tensorflow/compiler/jit/shape_inference_helpers.h" #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -72,8 +73,11 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" @@ -141,6 +145,13 @@ bool IsResourceOutputShapesAttribute(const AttrValue& attr_value, return false; } +void LoadImporterDialects(mlir::MLIRContext& context) { + // Load dialects involved in the conversion + mlir::DialectRegistry registry; + mlir::RegisterAllTensorFlowDialects(registry); + registry.loadAll(&context); +} + // This class is used to generate new MLIR function name strings that are both // unique in the TF function library `flib_` and unique among the name strings // generated by the class object during its lifetime. 
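Editor's sketch (not part of the patch): the comment above states the NameUniquifier contract, i.e. a generated function name must be unique both against the TF function library and against names the helper has already handed out. The toy class below illustrates that contract with a hypothetical suffix scheme; it is not the actual OpOrArgNameMapper implementation.

// Illustrative uniquifier: `reserved` stands in for the names already present
// in the function library; issued names are remembered so later requests
// cannot collide with them either.
#include <set>
#include <string>

class SimpleNameUniquifier {
 public:
  explicit SimpleNameUniquifier(std::set<std::string> reserved)
      : used_(std::move(reserved)) {}

  // Returns `candidate` if it is free, otherwise appends an increasing
  // numeric suffix until a free name is found, and records the result.
  std::string GetUniqueName(const std::string& candidate) {
    std::string name = candidate;
    int suffix = 0;
    while (!used_.insert(name).second)
      name = candidate + "_" + std::to_string(++suffix);
    return name;
  }

 private:
  std::set<std::string> used_;  // Reserved plus previously issued names.
};

Under this sketch's scheme, a uniquifier seeded with {"main"} returns "main_1" when asked for "main" a second time.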
@@ -171,6 +182,8 @@ class NameUniquifier : public OpOrArgNameMapper { Status UpgradeLegacyGraph(Graph* graph, FunctionLibraryDefinition* flib_def, bool restrict_functionalization_to_tpu_nodes) { + TF_RETURN_IF_ERROR(GenerateResourceSharedNameIfEmpty(*graph, *flib_def)); + // If `restrict_functionalization_to_tpu_nodes` is true let filter function // return true for `_tpu_replicate` nodes, otherwise don't set filter. NodeFilter node_filter = @@ -298,21 +311,6 @@ class ImporterBase { return ::tensorflow::ConvertTensorProto(value, &builder_); } - // Converts the tensor shape proto into an MLIR shape attribute. - StatusOr ConvertTensorShapeProto( - const TensorShapeProto& shape) { - if (shape.unknown_rank()) - return mlir::TF::ShapeAttr::get(builder_.getContext(), llvm::None); - - llvm::SmallVector dims; - dims.reserve(shape.dim().size()); - for (const auto& dim : shape.dim()) { - dims.push_back(dim.size()); - } - return mlir::TF::ShapeAttr::get(builder_.getContext(), - llvm::makeArrayRef(dims)); - } - // Converts func name in graphdef to mlir::SymbolRefAttribute. StatusOr ConvertFunctionCallName( const std::string& func_name); @@ -1130,74 +1128,36 @@ StatusOr ImporterBase::ConvertFunctionCallName( StatusOr ImporterBase::ConvertAttributeValue( const AttrValue& value) { switch (value.value_case()) { - case AttrValue::kI: - return builder_.getI64IntegerAttr(value.i()); - case AttrValue::kS: - return builder_.getStringAttr(value.s()); - case AttrValue::kF: - return builder_.getFloatAttr(builder_.getF32Type(), value.f()); - case AttrValue::kB: - return builder_.getBoolAttr(value.b()); - case AttrValue::kType: { - mlir::Type type; - TF_RETURN_IF_ERROR(ConvertDataType(value.type(), builder_, &type)); - return mlir::TypeAttr::get(type); - } - case AttrValue::kShape: - return ConvertTensorShapeProto(value.shape()); - case AttrValue::kTensor: - return ConvertTensorProto(value.tensor()); - case AttrValue::kList: { - absl::InlinedVector attrs; - for (const auto& item : value.list().i()) - attrs.push_back(builder_.getI64IntegerAttr(item)); - for (const auto& item : value.list().s()) - attrs.push_back(builder_.getStringAttr(item)); - for (const auto& item : value.list().f()) - attrs.push_back(builder_.getFloatAttr(builder_.getF32Type(), item)); - for (const auto& item : value.list().b()) - attrs.push_back(builder_.getBoolAttr(item)); - for (const auto& item : value.list().type()) { - mlir::Type type; - TF_RETURN_IF_ERROR(ConvertDataType(DataType(item), builder_, &type)); - attrs.push_back(mlir::TypeAttr::get(type)); - } - for (const auto& item : value.list().shape()) { - TF_ASSIGN_OR_RETURN(auto attr, ConvertTensorShapeProto(item)); - attrs.push_back(attr); - } - for (const auto& item : value.list().tensor()) { - TF_ASSIGN_OR_RETURN(auto attr, ConvertTensorProto(item)); - attrs.push_back(attr); - } - for (const auto& item : value.list().func()) { - TF_ASSIGN_OR_RETURN(auto attr, ConvertFunctionCallName(item.name())); - if (item.attr_size() != 0) - return errors::Unimplemented( - "func attributes with non-zero attr.size()"); - attrs.push_back(attr); - } - return builder_.getArrayAttr( - llvm::makeArrayRef(attrs.begin(), attrs.end())); - } case AttrValue::kFunc: { // TODO(b/156546237): Unify kFunc/NameAttrList attribute representation. // Currently kFunc/NameAttrList attributes in a kList/repeated AttrValue // will not use this representation. 
NamedAttrList attrs; for (const auto& func_attr : value.func().attr()) { - TF_ASSIGN_OR_RETURN(auto attr, ConvertAttributeValue(func_attr.second)); + TF_ASSIGN_OR_RETURN( + auto attr, ImporterBase::ConvertAttributeValue(func_attr.second)); attrs.push_back(builder_.getNamedAttr(func_attr.first, attr)); } auto func_attrs = builder_.getDictionaryAttr(attrs); return mlir::TF::FuncAttr::get(context_, value.func().name(), func_attrs); } - case AttrValue::VALUE_NOT_SET: - return builder_.getUnitAttr(); - // kPlaceholder is not implemented. + case AttrValue::kList: { + if (!value.list().func().empty()) { + absl::InlinedVector attrs; + for (const auto& item : value.list().func()) { + TF_ASSIGN_OR_RETURN(auto attr, ConvertFunctionCallName(item.name())); + if (item.attr_size() != 0) + return errors::Unimplemented( + "func attributes with non-zero attr.size()"); + attrs.push_back(attr); + } + return builder_.getArrayAttr( + llvm::makeArrayRef(attrs.begin(), attrs.end())); + } + return ConvertNonFuncAttributeValue(value, &builder_); + } default: - return errors::Unimplemented( - absl::StrCat("Attribute ", value.DebugString())); + return ConvertNonFuncAttributeValue(value, &builder_); } } @@ -2136,11 +2096,7 @@ StatusOr GraphDefImporter::Convert( mlir::MLIRContext* context, const Graph& graph, const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs, llvm::StringRef func_name) { - // Load dialects involved in the conversion - context->loadDialect(); - context->loadDialect(); - context->loadDialect(); - + LoadImporterDialects(*context); mlir::OwningModuleRef module = mlir::ModuleOp::create(mlir::UnknownLoc::get(context)); std::unordered_map tf_name_to_mlir_name; @@ -3197,6 +3153,7 @@ Status CreateSavedModelIR( StatusOr SavedModelObjectGraphImporter::Convert( SavedModelV2Bundle* saved_model, absl::Span exported_names, mlir::MLIRContext* context, bool add_default_attributes) { + LoadImporterDialects(*context); GraphDebugInfo dummy_debug_info; const GraphDebugInfo& debug_info = saved_model->debug_info() ? 
*saved_model->debug_info() : dummy_debug_info; @@ -3276,6 +3233,7 @@ class SavedModelSignatureDefImporter { static StatusOr Convert( const SavedModelBundle& bundle, absl::Span exported_names, mlir::MLIRContext* context, bool upgrade_legacy) { + LoadImporterDialects(*context); SavedModelSignatureDefImporter importer(bundle, exported_names, context); TF_RETURN_IF_ERROR(importer.InitializeGraph(upgrade_legacy)); return importer.ConvertSignatures(); @@ -3562,6 +3520,7 @@ Status SavedModelSignatureDefImporter::LiftVariables() { mlir::StatusScopedDiagnosticHandler diag_handler(module_->getContext()); mlir::PassManager pm(module_->getContext()); + SetCrashReproducer(pm); pm.addPass(mlir::tf_executor::CreateTFExecutorGraphPruningPass()); pm.addPass(mlir::CreateExecutorDialectToFunctionalConversionPass()); pm.addPass( @@ -3648,6 +3607,8 @@ stream_executor::port::StatusOr ConvertFunctionToMlir( tensorflow::GraphDebugInfo dummy_debug_info; tensorflow::GraphImportConfig specs; specs.graph_as_function = true; + for (const auto* control_ret_node : fbody->control_ret_nodes) + specs.control_outputs.push_back(control_ret_node->name()); return GraphDefImporter::Convert(context, *fbody->graph, dummy_debug_info, flib_def, specs, name); } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc index b646e14b71d..f63cb091a09 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc @@ -23,6 +23,7 @@ limitations under the License. #include "llvm/Support/MemoryBuffer.h" #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/Translation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" @@ -86,6 +87,9 @@ static LogicalResult MlirToGraphdefTranslateFunction( } static TranslateFromMLIRRegistration mlir_to_graphdef_translate( - "mlir-to-graphdef", MlirToGraphdefTranslateFunction); + "mlir-to-graphdef", MlirToGraphdefTranslateFunction, + [](DialectRegistry& registry) { + mlir::RegisterAllTensorFlowDialects(registry); + }); } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc b/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc index 5236bdeffbf..22e6559a0f2 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc @@ -20,6 +20,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/Translation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" namespace mlir { @@ -67,6 +68,7 @@ static LogicalResult MlirToTfNodeDef(ModuleOp module, // Test only translation to convert a simple MLIR module with a single TF // dialect op to NodeDef. 
static TranslateFromMLIRRegistration translate_from_mlir_registration( - "test-only-mlir-to-tf-nodedef", MlirToTfNodeDef); + "test-only-mlir-to-tf-nodedef", MlirToTfNodeDef, + mlir::RegisterAllTensorFlowDialects); } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.cc b/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.cc new file mode 100644 index 00000000000..4792e220b17 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.cc @@ -0,0 +1,80 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h" + +namespace tensorflow { + +Status GenerateResourceSharedNameIfEmpty(Graph& graph, + FunctionLibraryDefinition& flib_def) { + auto is_resource_op_with_empty_shared_name = [](const NodeDef& node_def, + const OpDef& op_def) { + // Only upgrade when it is a resource handle op. + if (op_def.output_arg().size() != 1 || + op_def.output_arg(0).type() != tensorflow::DT_RESOURCE) + return false; + + // If the OpDef has "use_node_name_sharing" field, then it is valid to use + // node names as shared names. + if (!std::any_of(op_def.attr().begin(), op_def.attr().end(), + [](const auto& attr_def) { + return attr_def.name() == "use_node_name_sharing" && + attr_def.type() == "bool"; + })) + return false; + + if (!std::any_of(op_def.attr().begin(), op_def.attr().end(), + [](const auto& attr_def) { + return attr_def.name() == "shared_name" && + attr_def.type() == "string"; + })) + return false; + + auto iter = node_def.attr().find("shared_name"); + if (iter == node_def.attr().end()) return true; + return iter->second.s().empty(); + }; + + // Upgrade nodes in the graph. + for (auto* node : graph.nodes()) { + if (is_resource_op_with_empty_shared_name(node->def(), node->op_def())) { + node->AddAttr("shared_name", node->name()); + } + } + + // Upgrade nodes in the functions. + auto func_names = flib_def.ListFunctionNames(); + for (const auto& func_name : func_names) { + const FunctionDef* orig = flib_def.Find(func_name); + DCHECK(orig); + auto copy = *orig; + for (auto& node_def : *copy.mutable_node_def()) { + const OpDef* op_def = nullptr; + TF_RETURN_IF_ERROR(flib_def.LookUpOpDef(node_def.op(), &op_def)); + if (is_resource_op_with_empty_shared_name(node_def, *op_def)) { + // Use the concat of function name and node name for such ops in a + // function as the shared_name. "@" is used as the separator because it + // is not allowed in the function name or the node name. 
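// Editor's note (illustration, not part of the patch): with the concatenation
// below, a resource op named "table" inside a function "init_fn" receives
// shared_name "table@init_fn", while the graph-level upgrade earlier in this
// function simply reuses the node name (shared_name "table"). The names are
// hypothetical.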
+ (*node_def.mutable_attr())["shared_name"].set_s( + absl::StrCat(node_def.name(), "@", func_name)); + } + } + TF_RETURN_IF_ERROR(flib_def.ReplaceFunction(func_name, copy)); + } + + return tensorflow::Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h b/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h new file mode 100644 index 00000000000..3502572c410 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h @@ -0,0 +1,32 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_UPGRADE_GRAPH_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_UPGRADE_GRAPH_H_ + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Generate the shared_name for resource handle ops in the graph and functions +// if their shared_names are empty. Resource handle ops with empty shared_name +// may have undesired semantics. +Status GenerateResourceSharedNameIfEmpty(Graph& graph, + FunctionLibraryDefinition& flib_def); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_UPGRADE_GRAPH_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index 0dbda2e4f9c..b55a5aa5243 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" #include "absl/types/optional.h" +#include "absl/types/variant.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" @@ -31,9 +32,6 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/Parser.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h" @@ -51,12 +49,14 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" #include "tensorflow/compiler/mlir/xla/type_to_shape.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/platform/logging.h" @@ -64,34 +64,19 @@ limitations under the License. namespace tensorflow { namespace { -// Parses the MLIR module from the mlir_module_string. -Status ParseMlirModule(llvm::StringRef mlir_module_string, - mlir::MLIRContext* mlir_context, - mlir::OwningModuleRef* mlir_module) { - TF_RET_CHECK(!mlir_module_string.empty()) - << "unexpected empty serialized MLIR module string"; - TF_RET_CHECK(mlir_module) << "unexpected null MLIR module pointer"; - - // Make sure we catch any error reported by MLIR and forward it to the TF - // error reporting system. - mlir::StatusScopedDiagnosticHandler error_handler(mlir_context); - - // Parse the module. - *mlir_module = mlir::parseSourceString(mlir_module_string, mlir_context); - if (!*mlir_module) { - return error_handler.Combine( - errors::InvalidArgument("could not parse MLIR module")); +// Extracts shape from XlaArgument as TensorShape. If shape is a xla::Shape, +// that is converted to a TensorShape. +StatusOr GetTensorShapeFromXlaArgument(const XlaArgument& arg) { + if (absl::holds_alternative(arg.shape)) { + TensorShape arg_shape; + TF_RETURN_IF_ERROR( + XLAShapeToTensorShape(absl::get(arg.shape), &arg_shape)); + return arg_shape; + } else { + return absl::get(arg.shape); } - - return Status::OK(); } -// Arguments to a computation can be either a tensor or resource. -struct TensorOrResourceShape { - TensorShape shape; - bool is_resource = false; -}; - // Converts arg_shapes to xla::Shape's and store into xla_input_shapes. Status GetXlaInputShapes( mlir::ModuleOp module, llvm::ArrayRef arg_shapes, @@ -285,52 +270,67 @@ static void RegisterDialects(mlir::DialectRegistry& registry) { } // namespace +void CreateConvertMlirToXlaHloPipeline( + mlir::OpPassManager& pm, llvm::StringRef device_type, + llvm::MutableArrayRef> + custom_legalization_passes) { + pm.addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions()); + pm.addNestedPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::TF::CreateTensorListOpsDecompositionPass()); + + // TODO(b/159127949): Stack and TensorArray decomposition passes do not handle + // region based control flow yet. So convert back to functional control flow. + pm.addPass(mlir::TF::CreateTFRegionControlFlowToFunctional()); + pm.addPass(mlir::TF::CreateStackOpsDecompositionPass()); + pm.addPass(mlir::TF::CreateTensorArrayOpsDecompositionPass()); + pm.addPass(mlir::TFDevice::CreateDecomposeResourceOpsPass()); + pm.addPass(mlir::TF::CreatePromoteResourcesToArgsPass()); + pm.addPass(mlir::createSymbolDCEPass()); + // Guarantee all functions have one use, which enables shape inference. 
+ pm.addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass()); + pm.addPass(mlir::TF::CreateTFShapeInferencePass()); + // LegalizeTFControlFlow encapsulates arguments for control flow operations + // with a tuple argument which break the assumption of resource lifting + // inside PromoteResourcesToArgs. + pm.addPass(mlir::mhlo::createLegalizeTFControlFlowPass()); + + pm.addNestedPass(mlir::mhlo::createLegalizeTFPass( + /*allow_partial_conversion=*/true, /*legalize_chlo=*/true, + /*tf2xla_fallback_device_type=*/device_type)); + for (auto& target_pass : custom_legalization_passes) { + pm.addNestedPass(std::move(target_pass)); + } + pm.addPass(mlir::mhlo::CreateLegalizeTFCommunicationPass()); + pm.addNestedPass(mlir::createCanonicalizerPass()); + // Run shape inference pass to propagate shapes through tensor_cast operations + // from static to dynamic shapes. This could be generated if the shape + // inference was originally missing in a TF op but the corresponding HLO op + // had static shape after lowering. + pm.addPass(mlir::TF::CreateTFShapeInferencePass()); + // Run LegalizeTFPass again because the previous legalization passes can + // expose more graph pruning and canonicalization opportunities that are + // necessary for the second LegalizeTFPass(allow_partial_conversion=false) + // invocation. + pm.addNestedPass(mlir::mhlo::createLegalizeTFPass( + /*allow_partial_conversion=*/false, /*legalize_chlo=*/true, + /*tf2xla_fallback_device_type=*/device_type)); + // In order to export to XLA, we must sink constants to control flow regions, + // since XLA uses functional control flow. + pm.addNestedPass( + mlir::mhlo::createSinkConstantsToControlFlowPass()); +} + Status ConvertMLIRToXlaComputation( mlir::ModuleOp module_op, llvm::StringRef device_type, xla::XlaComputation* xla_computation, bool use_tuple_args, bool return_tuple, const XlaHelpers::ShapeRepresentationFn shape_representation_fn, - std::vector> custom_legalization_passes) { + llvm::MutableArrayRef> + custom_legalization_passes) { mlir::PassManager tf2xla(module_op.getContext()); - tf2xla.addNestedPass(mlir::createCanonicalizerPass()); - tf2xla.addPass(mlir::TF::CreateTensorListOpsDecompositionPass()); - tf2xla.addPass(mlir::TF::CreateStackOpsDecompositionPass()); - tf2xla.addPass(mlir::TF::CreateTensorArrayOpsDecompositionPass()); - tf2xla.addPass(mlir::TFDevice::CreateDecomposeResourceOpsPass()); - tf2xla.addPass(mlir::TF::CreatePromoteResourcesToArgsPass()); - tf2xla.addPass(mlir::createSymbolDCEPass()); - // Guarantee all functions have one use, which enables shape inference. - tf2xla.addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass()); - tf2xla.addPass(mlir::TF::CreateTFShapeInferencePass()); - // LegalizeTFControlFlow encapsulates arguments for control flow operations - // with a tuple argument which break the assumption of resource lifting - // inside PromoteResourcesToArgs. - tf2xla.addPass(mlir::mhlo::createLegalizeTFControlFlowPass()); - - tf2xla.addNestedPass(mlir::mhlo::createLegalizeTFPass( - /*allow_partial_conversion=*/true, /*legalize_chlo=*/true, - /*tf2xla_fallback_device_type=*/device_type)); - for (auto& target_pass : custom_legalization_passes) { - tf2xla.addNestedPass(std::move(target_pass)); - } - tf2xla.addPass(mlir::mhlo::CreateLegalizeTFCommunicationPass()); - tf2xla.addNestedPass(mlir::createCanonicalizerPass()); - // Run shape inference pass to propagate shapes through tensor_cast operations - // from static to dynamic shapes. 
This could be generated if the shape - // inference was originally missing in a TF op but the corresponding HLO op - // had static shape after lowering. - tf2xla.addPass(mlir::TF::CreateTFShapeInferencePass()); - // Run LegalizeTFPass again because the previous legalization passes can - // expose more graph pruning and canonicalization opportunities that are - // necessary for the second LegalizeTFPass(allow_partial_conversion=false) - // invocation. - tf2xla.addNestedPass(mlir::mhlo::createLegalizeTFPass( - /*allow_partial_conversion=*/false, /*legalize_chlo=*/true, - /*tf2xla_fallback_device_type=*/device_type)); - // In order to export to XLA, we must sink constants to control flow regions, - // since XLA uses functional control flow. - tf2xla.addNestedPass( - mlir::mhlo::createSinkConstantsToControlFlowPass()); + applyTensorflowAndCLOptions(tf2xla); + CreateConvertMlirToXlaHloPipeline(tf2xla, device_type, + custom_legalization_passes); if (VLOG_IS_ON(1)) { // Print the whole module after each pass which requires disabling @@ -361,12 +361,13 @@ Status ConvertMLIRToXlaComputation( return Status::OK(); } -static Status CompileMlirToXlaHlo( +Status CompileMlirToXlaHlo( mlir::ModuleOp module_op, llvm::ArrayRef arg_shapes, - llvm::StringRef device_type, bool use_tuple_args, + llvm::StringRef device_type, bool use_tuple_args, bool use_return_tuple, XlaHelpers::ShapeRepresentationFn shape_representation_fn, XlaCompilationResult* compilation_result, - std::vector> custom_legalization_passes) { + llvm::MutableArrayRef> + custom_legalization_passes) { if (VLOG_IS_ON(1)) tensorflow::DumpMlirOpToFile("mlir_compile_before", module_op); @@ -383,9 +384,8 @@ static Status CompileMlirToXlaHlo( compilation_result->computation = std::make_shared(); TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation( module_op, device_type, compilation_result->computation.get(), - use_tuple_args, - /*return_tuple=*/true, shape_representation_fn, - std::move(custom_legalization_passes))); + use_tuple_args, use_return_tuple, shape_representation_fn, + custom_legalization_passes)); // Construct mapping from XlaComputation's arg to input edges of execute // node. @@ -412,21 +412,22 @@ Status CompileSerializedMlirToXlaHlo( llvm::StringRef device_type, bool use_tuple_args, const XlaHelpers::ShapeRepresentationFn shape_representation_fn, XlaCompilationResult* compilation_result, - std::vector> custom_legalization_passes) { + llvm::MutableArrayRef> + custom_legalization_passes) { mlir::MLIRContext mlir_context; RegisterDialects(mlir_context.getDialectRegistry()); mlir::OwningModuleRef mlir_module; TF_RETURN_IF_ERROR( - ParseMlirModule(mlir_module_string, &mlir_context, &mlir_module)); + DeserializeMlirModule(mlir_module_string, &mlir_context, &mlir_module)); llvm::SmallVector tensor_or_resource_shapes; tensor_or_resource_shapes.reserve(arg_shapes.size()); for (const auto& arg_shape : arg_shapes) tensor_or_resource_shapes.push_back({arg_shape}); return CompileMlirToXlaHlo(mlir_module.get(), tensor_or_resource_shapes, device_type, use_tuple_args, - shape_representation_fn, compilation_result, - std::move(custom_legalization_passes)); + /*use_return_tuple=*/true, shape_representation_fn, + compilation_result, custom_legalization_passes); } // Rewrites the given module with specified args. For each of the constant args, @@ -434,8 +435,8 @@ Status CompileSerializedMlirToXlaHlo( // removed from the signature. For resource args, their subtypes are populated. // Returns the original indices for the other arguments on success. 
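// Illustrative example (hypothetical argument list, for documentation only):
// given args {0: constant, 1: parameter, 2: resource}, the constant is
// materialized as a tf.Const in the module and its argument is erased from the
// signature, the resource argument's subtype is populated from its dtype and
// shape, and the returned original indices would be {1, 2}.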
static StatusOr> RewriteWithArgs( - mlir::ModuleOp module, llvm::ArrayRef args) { - mlir::FuncOp main_fn = module.lookupSymbol("main"); + mlir::ModuleOp module_op, llvm::ArrayRef args) { + mlir::FuncOp main_fn = module_op.lookupSymbol("main"); std::vector params; bool has_resource_args = false; @@ -447,7 +448,9 @@ static StatusOr> RewriteWithArgs( if (xla_arg.kind == XlaArgument::kResource) { mlir::Type element_type; TF_RETURN_IF_ERROR(ConvertDataType(xla_arg.type, builder, &element_type)); - auto resource_shape = absl::get(xla_arg.shape).dim_sizes(); + TF_ASSIGN_OR_RETURN(TensorShape arg_shape, + GetTensorShapeFromXlaArgument(xla_arg)); + auto resource_shape = arg_shape.dim_sizes(); llvm::SmallVector resource_subtype_shape( resource_shape.begin(), resource_shape.end()); auto resource_subtype = @@ -473,7 +476,7 @@ static StatusOr> RewriteWithArgs( ConvertTensor(xla_arg.constant_value, &builder)); // TODO(hinsu): Use the actual location of the constant. auto constant = builder.create( - mlir::UnknownLoc::get(module.getContext()), value_attr); + mlir::UnknownLoc::get(module_op.getContext()), value_attr); mlir_arg.replaceAllUsesWith(constant); args_to_erase.push_back(idx); } @@ -495,16 +498,54 @@ static StatusOr> RewriteWithArgs( } Status CompileGraphToXlaHlo( - const Graph& graph, llvm::ArrayRef args, - llvm::StringRef device_type, bool use_tuple_args, - const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info, + mlir::ModuleOp module_op, llvm::ArrayRef args, + llvm::StringRef device_type, bool use_tuple_args, bool use_return_tuple, const XlaHelpers::ShapeRepresentationFn shape_representation_fn, XlaCompilationResult* compilation_result, - std::vector> custom_legalization_passes) { + llvm::MutableArrayRef> + custom_legalization_passes) { + TF_ASSIGN_OR_RETURN(std::vector remaining_params, + RewriteWithArgs(module_op, args)); + llvm::SmallVector arg_shapes; + arg_shapes.reserve(remaining_params.size()); + for (unsigned idx : remaining_params) { + const auto& arg = args[idx]; + TF_ASSIGN_OR_RETURN(TensorShape arg_shape, + GetTensorShapeFromXlaArgument(arg)); + arg_shapes.push_back({arg_shape, + /*is_resource=*/arg.kind == XlaArgument::kResource}); + } + + mlir::PassManager pm(module_op.getContext()); + applyTensorflowAndCLOptions(pm); + mlir::TF::StandardPipelineOptions tf_options; + mlir::TF::CreateTFStandardPipeline(pm, tf_options); + { + mlir::StatusScopedDiagnosticHandler diag_handler(module_op.getContext()); + if (failed(pm.run(module_op))) return diag_handler.ConsumeStatus(); + } + + auto status = CompileMlirToXlaHlo( + module_op, arg_shapes, device_type, use_tuple_args, use_return_tuple, + shape_representation_fn, compilation_result, custom_legalization_passes); + compilation_result->input_mapping = remaining_params; + return status; +} + +Status CompileGraphToXlaHlo( + const Graph& graph, llvm::ArrayRef args, + llvm::ArrayRef control_rets, llvm::StringRef device_type, + bool use_tuple_args, const FunctionLibraryDefinition& flib_def, + const GraphDebugInfo& debug_info, + const XlaHelpers::ShapeRepresentationFn shape_representation_fn, + XlaCompilationResult* compilation_result, + llvm::MutableArrayRef> + custom_legalization_passes) { mlir::MLIRContext context; RegisterDialects(context.getDialectRegistry()); GraphImportConfig config; config.graph_as_function = true; + config.control_outputs = control_rets; // Disable shape inference during import as some TensorFlow op fails during // shape inference with dynamic shaped operands. 
This in turn causes the // import to fail. Shape inference during import is going to be removed and @@ -515,30 +556,11 @@ Status CompileGraphToXlaHlo( ConvertGraphToMlir(graph, debug_info, flib_def, config, &context); if (!module_or.ok()) return module_or.status(); - mlir::ModuleOp module = module_or.ValueOrDie().get(); - TF_ASSIGN_OR_RETURN(std::vector remaining_params, - RewriteWithArgs(module, {args.data(), args.size()})); - llvm::SmallVector arg_shapes; - arg_shapes.reserve(remaining_params.size()); - for (unsigned idx : remaining_params) { - const auto& arg = args[idx]; - arg_shapes.push_back({absl::get(arg.shape), - /*is_resource=*/arg.kind == XlaArgument::kResource}); - } - - mlir::PassManager pm(&context); - mlir::TF::StandardPipelineOptions tf_options; - mlir::TF::CreateTFStandardPipeline(pm, tf_options); - { - mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext()); - if (failed(pm.run(module))) return diag_handler.ConsumeStatus(); - } - - auto status = CompileMlirToXlaHlo( - module, arg_shapes, device_type, use_tuple_args, shape_representation_fn, - compilation_result, std::move(custom_legalization_passes)); - compilation_result->input_mapping = remaining_params; - return status; + mlir::ModuleOp module_op = module_or.ValueOrDie().get(); + return CompileGraphToXlaHlo(module_op, args, device_type, use_tuple_args, + /*use_return_tuple=*/true, + shape_representation_fn, compilation_result, + custom_legalization_passes); } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h index 5c64a65ecbd..40230de406b 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h @@ -16,10 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_ +#include + #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/tf2xla/xla_argument.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/xla/client/xla_computation.h" @@ -30,6 +33,14 @@ limitations under the License. namespace tensorflow { +// Populates the supplied passmanager with the passes required to run the +// TF MLIR to XLA HLO MLIR conversion/legalization. Custom legalization passes +// can be populated in `custom_legalization_passes`. +void CreateConvertMlirToXlaHloPipeline( + mlir::OpPassManager& pm, llvm::StringRef device_type, + llvm::MutableArrayRef> + custom_legalization_passes); + // Lowers MLIR module to XLA HLO inside an XlaComputation. The input module // should only contain operations in tf dialect. If the input module contains // operation in the tf_executor dialect, for example, returns an error. @@ -61,7 +72,24 @@ Status ConvertMLIRToXlaComputation( xla::XlaComputation* xla_computation, bool use_tuple_args, bool return_tuple, const XlaHelpers::ShapeRepresentationFn shape_representation_fn = nullptr, - std::vector> custom_legalization_passes = {}); + llvm::MutableArrayRef> + custom_legalization_passes = {}); + +// Helper struct representing argument tensor or resource handle shapes. 
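+// For example (illustrative only): a DT_FLOAT resource argument with shape [2]
+// would be carried as {shape: [2], is_resource: true}, while a plain tensor
+// argument of the same shape would be {shape: [2], is_resource: false}.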
+struct TensorOrResourceShape { + TensorShape shape; + bool is_resource = false; +}; + +// Compiles a MLIR module into XLA HLO, generates all accompanying metadata and +// stores them in CompilationResult. +Status CompileMlirToXlaHlo( + mlir::ModuleOp module_op, llvm::ArrayRef arg_shapes, + llvm::StringRef device_type, bool use_tuple_args, bool use_return_tuple, + XlaHelpers::ShapeRepresentationFn shape_representation_fn, + XlaCompilationResult* compilation_result, + llvm::MutableArrayRef> + custom_legalization_passes); // Compiles a serialized MLIR module into XLA HLO, generates all accompanying // metadata and stores them in CompilationResult. @@ -70,17 +98,33 @@ Status CompileSerializedMlirToXlaHlo( llvm::StringRef device_type, bool use_tuple_args, const XlaHelpers::ShapeRepresentationFn shape_representation_fn, XlaCompilationResult* compilation_result, - std::vector> custom_legalization_passes = {}); + llvm::MutableArrayRef> + custom_legalization_passes = {}); -// Same as the above but takes input as TensorFlow Graph. -// TODO(lyandy): Allow populating of targets/control outputs. +// Compiles a TensorFlow Graph (already converted to MLIR, imported with +// tf_executor dialect still present) into XLA HLO, generates all accompanying +// metadata and stores them in CompilationResult. This will rewrite arguments +// and run the TensorFlow standard pipeline prior to invoking +// `CompileMlirToXlaHlo`. Status CompileGraphToXlaHlo( - const Graph& graph, llvm::ArrayRef args, - llvm::StringRef device_type, bool use_tuple_args, - const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info, + mlir::ModuleOp module_op, llvm::ArrayRef args, + llvm::StringRef device_type, bool use_tuple_args, bool use_return_tuple, const XlaHelpers::ShapeRepresentationFn shape_representation_fn, XlaCompilationResult* compilation_result, - std::vector> custom_legalization_passes = {}); + llvm::MutableArrayRef> + custom_legalization_passes); + +// Compiles a TensorFlow Graph into XLA HLO, generates all accompanying metadata +// and stores them in CompilationResult. +Status CompileGraphToXlaHlo( + const Graph& graph, llvm::ArrayRef args, + llvm::ArrayRef control_rets, llvm::StringRef device_type, + bool use_tuple_args, const FunctionLibraryDefinition& flib_def, + const GraphDebugInfo& debug_info, + const XlaHelpers::ShapeRepresentationFn shape_representation_fn, + XlaCompilationResult* compilation_result, + llvm::MutableArrayRef> + custom_legalization_passes = {}); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_pass.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_pass.cc new file mode 100644 index 00000000000..57267ff027f --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_pass.cc @@ -0,0 +1,31 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" + +namespace { +void CreateConvertMlirToXlaHloPipelineWithDefaults(mlir::OpPassManager& pm) { + tensorflow::CreateConvertMlirToXlaHloPipeline( + pm, /*device_type=*/"XLA_CPU_JIT", + /*custom_legalization_passes=*/{}); +} + +mlir::PassPipelineRegistration<> pipeline( + "tf-to-hlo-pipeline", + "Convert TF dialect to HLO dialect (used for compilation in bridge).", + CreateConvertMlirToXlaHloPipelineWithDefaults); +} // anonymous namespace diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc deleted file mode 100644 index 80e2c1132fd..00000000000 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc +++ /dev/null @@ -1,542 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" - -#include "tensorflow/cc/framework/scope.h" -#include "tensorflow/cc/ops/function_ops.h" -#include "tensorflow/cc/ops/resource_variable_ops.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/framework/tensor_testutil.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/graph/testlib.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/stream_executor/lib/statusor.h" - -namespace tensorflow { -namespace { - -// A dummy shape representation function that simply converts given shape into -// an xla::Shape without assigning any layouts. 
-xla::StatusOr TestShapeRepresentation(const TensorShape& shape, - DataType type, - bool use_fast_memory) { - xla::Shape xla_shape; - TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape)); - return xla_shape; -} - -TEST(CompileSerializedMlirToXlaHloTest, InvalidSerializedMlirModule) { - constexpr char invalid_mlir_module[] = - "totally @invalid MLIR module {here} <-"; - std::vector arg_shapes; - XlaCompiler::CompilationResult compilation_result; - - Status s = CompileSerializedMlirToXlaHlo( - invalid_mlir_module, arg_shapes, "XLA_CPU_JIT", - /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); - EXPECT_EQ(s.code(), tensorflow::errors::Code::INVALID_ARGUMENT); - EXPECT_EQ(s.ToString(), - "Invalid argument: could not parse MLIR module-:1:1: error: " - "custom op 'totally' is unknown\n"); -} - -constexpr llvm::StringRef kBinaryAddModule = R"( - module attributes {tf.versions = {producer = 179 : i32}} { - func @main(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "tf.AddV2"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", name = "add"} : (tensor, tensor) -> tensor - return %0 : tensor - } - } -)"; - -TEST(CompileSerializedMlirToXlaHloTest, TupleArgs) { - std::vector arg_shapes(2, TensorShape()); - XlaCompiler::CompilationResult compilation_result; - - Status s = CompileSerializedMlirToXlaHlo( - kBinaryAddModule, arg_shapes, "XLA_CPU_JIT", - /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); - TF_ASSERT_OK(s); - - const xla::HloModuleConfig module_config( - compilation_result.computation->GetProgramShape().ValueOrDie()); - auto status_or_hlo_module = xla::HloModule::CreateFromProto( - compilation_result.computation->proto(), module_config); - TF_ASSERT_OK(status_or_hlo_module.status()); - constexpr char expected_hlo_module_string[] = R"(HloModule main.6 - -ENTRY %main.6 (arg_tuple.1: (f32[], f32[])) -> (f32[]) { - %arg_tuple.1 = (f32[], f32[]) parameter(0) - %get-tuple-element.2 = f32[] get-tuple-element((f32[], f32[]) %arg_tuple.1), index=0 - %get-tuple-element.3 = f32[] get-tuple-element((f32[], f32[]) %arg_tuple.1), index=1 - %add.4 = f32[] add(f32[] %get-tuple-element.2, f32[] %get-tuple-element.3) - ROOT %tuple.5 = (f32[]) tuple(f32[] %add.4) -} - -)"; - EXPECT_EQ(expected_hlo_module_string, - status_or_hlo_module.ValueOrDie()->ToString()); - - // Expect an in order input mapping. - EXPECT_EQ(compilation_result.input_mapping, std::vector({0, 1})); - - // Expect a single tuple-shape, containing two F32 scalars. - EXPECT_EQ(compilation_result.xla_input_shapes.size(), 1); - xla::Shape expected_input_shape = - xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {}), - xla::ShapeUtil::MakeShape(xla::F32, {})}); - EXPECT_EQ(compilation_result.xla_input_shapes.front(), expected_input_shape); - - // Expect output shape is a tuple shape containing a single F32 Scalar type. - const xla::Shape output_shape = - xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {}); - const xla::Shape tuple_output_shape = - xla::ShapeUtil::MakeTupleShape({output_shape}); - EXPECT_EQ(compilation_result.xla_output_shape, tuple_output_shape); - - // Expect exactly 1 OutputDescription. - EXPECT_EQ(compilation_result.outputs.size(), 1); - const XlaCompiler::OutputDescription& output_desc = - compilation_result.outputs.front(); - EXPECT_EQ(output_desc.type, DataType::DT_FLOAT); - EXPECT_EQ(output_desc.shape, TensorShape()); - EXPECT_FALSE(output_desc.is_constant); - EXPECT_FALSE(output_desc.is_tensor_list); - - // Expect no resource updates from computation. 
- EXPECT_TRUE(compilation_result.resource_updates.empty()); -} - -TEST(CompileSerializedMlirToXlaHloTest, IndividualArgs) { - std::vector arg_shapes(2, TensorShape()); - XlaCompiler::CompilationResult compilation_result; - - Status s = CompileSerializedMlirToXlaHlo( - kBinaryAddModule, arg_shapes, "XLA_CPU_JIT", - /*use_tuple_args=*/false, TestShapeRepresentation, &compilation_result); - TF_ASSERT_OK(s); - - const xla::HloModuleConfig module_config( - compilation_result.computation->GetProgramShape().ValueOrDie()); - auto status_or_hlo_module = xla::HloModule::CreateFromProto( - compilation_result.computation->proto(), module_config); - TF_ASSERT_OK(status_or_hlo_module.status()); - constexpr char expected_hlo_module_string[] = R"(HloModule main.5 - -ENTRY %main.5 (Arg_0.1: f32[], Arg_1.2: f32[]) -> (f32[]) { - %Arg_0.1 = f32[] parameter(0) - %Arg_1.2 = f32[] parameter(1) - %add.3 = f32[] add(f32[] %Arg_0.1, f32[] %Arg_1.2) - ROOT %tuple.4 = (f32[]) tuple(f32[] %add.3) -} - -)"; - EXPECT_EQ(expected_hlo_module_string, - status_or_hlo_module.ValueOrDie()->ToString()); - - // Expect an in order input mapping. - EXPECT_EQ(compilation_result.input_mapping, std::vector({0, 1})); - - // Expect two inputs, each containing a F32 scalar. - EXPECT_EQ(compilation_result.xla_input_shapes.size(), 2); - xla::Shape expected_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {}); - EXPECT_EQ(compilation_result.xla_input_shapes[0], expected_input_shape); - EXPECT_EQ(compilation_result.xla_input_shapes[1], expected_input_shape); - - // Expect output shape is a tuple shape containing a single F32 Scalar type. - const xla::Shape output_shape = - xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {}); - const xla::Shape tuple_output_shape = - xla::ShapeUtil::MakeTupleShape({output_shape}); - EXPECT_EQ(compilation_result.xla_output_shape, tuple_output_shape); - - // Expect exactly 1 OutputDescription. - EXPECT_EQ(compilation_result.outputs.size(), 1); - const XlaCompiler::OutputDescription& output_desc = - compilation_result.outputs.front(); - EXPECT_EQ(output_desc.type, DataType::DT_FLOAT); - EXPECT_EQ(output_desc.shape, TensorShape()); - EXPECT_FALSE(output_desc.is_constant); - EXPECT_FALSE(output_desc.is_tensor_list); - - // Expect no resource updates from computation. - EXPECT_TRUE(compilation_result.resource_updates.empty()); -} - -// Tests that foldable ops are constant-folded to enable legalization of ops -// that require compile time constant operand. -TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) { - // "tf.Shape" can only be folded away after shape inference. tf.Reshape can - // only be lowered when tf.Shape is folded into a constant. 
- constexpr char mlir_module[] = R"( - module attributes {tf.versions = {producer = 179 : i32}} { - func @main(%arg0: tensor<10x19xf32>, %arg1: tensor<19x10xf32> {mhlo.is_same_data_across_replicas}) -> tensor<10x19xf32> { - %0 = "tf.Shape"(%arg0) : (tensor<10x19xf32>) -> tensor<2xi64> - %1 = "tf.Reshape"(%arg1, %0) : (tensor<19x10xf32>, tensor<2xi64>) -> tensor<10x19xf32> - return %1 : tensor<10x19xf32> - } - } - )"; - - std::vector arg_shapes{TensorShape({10, 19}), - TensorShape({19, 10})}; - XlaCompiler::CompilationResult compilation_result; - - Status s = CompileSerializedMlirToXlaHlo( - mlir_module, arg_shapes, "XLA_CPU_JIT", - /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); - TF_ASSERT_OK(s); - - const xla::HloModuleConfig module_config( - compilation_result.computation->GetProgramShape().ValueOrDie()); - auto status_or_hlo_module = xla::HloModule::CreateFromProto( - compilation_result.computation->proto(), module_config); - TF_ASSERT_OK(status_or_hlo_module.status()); - constexpr char expected_hlo_module_string[] = R"(HloModule main.6 - -ENTRY %main.6 (arg_tuple.1: (f32[10,19], f32[19,10])) -> (f32[10,19]) { - %arg_tuple.1 = (f32[10,19]{1,0}, f32[19,10]{1,0}) parameter(0), parameter_replication={false,true} - %get-tuple-element.2 = f32[10,19]{1,0} get-tuple-element((f32[10,19]{1,0}, f32[19,10]{1,0}) %arg_tuple.1), index=0 - %get-tuple-element.3 = f32[19,10]{1,0} get-tuple-element((f32[10,19]{1,0}, f32[19,10]{1,0}) %arg_tuple.1), index=1 - %reshape.4 = f32[10,19]{1,0} reshape(f32[19,10]{1,0} %get-tuple-element.3) - ROOT %tuple.5 = (f32[10,19]{1,0}) tuple(f32[10,19]{1,0} %reshape.4) -} - -)"; - EXPECT_EQ(expected_hlo_module_string, - status_or_hlo_module.ValueOrDie()->ToString()); -} - -TEST(CompileSerializedMlirToXlaHloTest, ShapeInference) { - constexpr char mlir_module[] = R"( - module attributes {tf.versions = {producer = 179 : i32}} { - func @main(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor { - %0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", transpose_a = false, transpose_b = false} : (tensor<*xf32>, tensor) -> tensor - return %0 : tensor - } - } - )"; - - std::vector arg_shapes{TensorShape({10, 17}), - TensorShape({17, 19})}; - XlaCompiler::CompilationResult compilation_result; - - Status s = CompileSerializedMlirToXlaHlo( - mlir_module, arg_shapes, "XLA_CPU_JIT", - /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); - TF_ASSERT_OK(s); - - const xla::HloModuleConfig module_config( - compilation_result.computation->GetProgramShape().ValueOrDie()); - auto status_or_hlo_module = xla::HloModule::CreateFromProto( - compilation_result.computation->proto(), module_config); - TF_ASSERT_OK(status_or_hlo_module.status()); - - constexpr char expected_signature[] = - R"((arg_tuple.1: (f32[10,17], f32[17,19])) -> (f32[10,19]))"; - EXPECT_THAT(status_or_hlo_module.ValueOrDie()->ToString(), - ::testing::HasSubstr(expected_signature)); -} - -TEST(CompileSerializedMlirToXlaHloTest, ShapeInferenceAfterLegalization) { - constexpr char mlir_module[] = R"( - module attributes {tf.versions = {producer = 179 : i32}} { - func @main(%arg0: tensor<8x16x16x64xbf16>, %arg1: tensor<64xf32>) -> (tensor<8x16x16x64xbf16>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<*xf32>) { - %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg1, %arg1, %arg1) {data_format = "NHWC", device = "", epsilon = 9.99999974E-5 : f32, exponential_avg_factor = 1.000000e+00 : f32, is_training = false} : (tensor<8x16x16x64xbf16>, tensor<64xf32>, tensor<64xf32>, 
tensor<64xf32>, tensor<64xf32>) -> (tensor<8x16x16x64xbf16>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<*xf32>) - return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<8x16x16x64xbf16>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<*xf32> - } - } - )"; - - std::vector arg_shapes{TensorShape({8, 16, 16, 64}), - TensorShape({64})}; - XlaCompiler::CompilationResult compilation_result; - - Status s = CompileSerializedMlirToXlaHlo( - mlir_module, arg_shapes, "XLA_CPU_JIT", - /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); - TF_ASSERT_OK(s); - - const xla::HloModuleConfig module_config( - compilation_result.computation->GetProgramShape().ValueOrDie()); - auto status_or_hlo_module = xla::HloModule::CreateFromProto( - compilation_result.computation->proto(), module_config); - TF_ASSERT_OK(status_or_hlo_module.status()); - - constexpr char expected_signature[] = - R"(-> (bf16[8,16,16,64], f32[64], f32[64], f32[64], f32[64], f32[0]))"; - EXPECT_THAT(status_or_hlo_module.ValueOrDie()->ToString(), - ::testing::HasSubstr(expected_signature)); -} - -TEST(CompileSerializedMlirToXlaHloTest, ConstantFoldHook) { - constexpr char mlir_module[] = R"( -module attributes {tf.versions = {producer = 179 : i32}} { - func @main() -> (tensor<0xi32>, tensor<0xi32>) { - %0 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> - %r0, %r1 = "tf.BroadcastGradientArgs"(%0, %0) {T = i32} : (tensor<0xi32>, tensor<0xi32>) -> (tensor<0xi32>, tensor<0xi32>) - return %r0, %r1 : tensor<0xi32>, tensor<0xi32> - } -} -)"; - - std::vector arg_shapes(2, TensorShape()); - XlaCompiler::CompilationResult compilation_result; - - Status s = CompileSerializedMlirToXlaHlo( - mlir_module, arg_shapes, "XLA_CPU_JIT", - /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); - TF_ASSERT_OK(s); - - const xla::HloModuleConfig module_config( - compilation_result.computation->GetProgramShape().ValueOrDie()); - auto status_or_hlo_module = xla::HloModule::CreateFromProto( - compilation_result.computation->proto(), module_config); - TF_ASSERT_OK(status_or_hlo_module.status()); - constexpr char expected_hlo_module_string[] = R"(HloModule main.4 - -ENTRY %main.4 (arg_tuple.1: ()) -> (s32[0], s32[0]) { - %arg_tuple.1 = () parameter(0) - %constant.2 = s32[0]{0} constant({}) - ROOT %tuple.3 = (s32[0]{0}, s32[0]{0}) tuple(s32[0]{0} %constant.2, s32[0]{0} %constant.2) -} - -)"; - EXPECT_EQ(expected_hlo_module_string, - status_or_hlo_module.ValueOrDie()->ToString()); -} - -// The following xla::OpSharding protos are used: -// Serialized string: -// "\08\03\1A\02\01\02\22\02\00\01" -// Proto debug string: -// type: OTHER -// tile_assignment_dimensions: 1 -// tile_assignment_dimensions: 2 -// tile_assignment_devices: 0 -// tile_assignment_devices: 1 -// -// Serialized string: -// "\08\01\1A\01\01\22\01\00" -// Proto debug string: -// type: MAXIMAL -// tile_assignment_dimensions: 1 -// tile_assignment_devices: 0 -// -// Serialized string: -// "" -// Proto debug string (empty but would equivalent to): -// type: REPLICATED -TEST(CompileSerializedMlirToXlaHloTest, ArgumentSharding) { - constexpr char mlir_module[] = R"( -module attributes {tf.versions = {producer = 179 : i32}} { - func @main(%arg0: tensor<128x10xf32> {mhlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, %arg1: tensor<10x1024xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<128x1024xf32> {mhlo.sharding = ""}) { - return - } -} -)"; - - std::vector arg_shapes{TensorShape({128, 
10}), - TensorShape({10, 1024}), - TensorShape({128, 1024})}; - XlaCompiler::CompilationResult compilation_result; - - Status s = CompileSerializedMlirToXlaHlo( - mlir_module, arg_shapes, "XLA_CPU_JIT", - /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); - TF_ASSERT_OK(s); - - const xla::HloModuleConfig module_config( - compilation_result.computation->GetProgramShape().ValueOrDie()); - auto status_or_hlo_module = xla::HloModule::CreateFromProto( - compilation_result.computation->proto(), module_config); - TF_ASSERT_OK(status_or_hlo_module.status()); - constexpr char expected_hlo_module_string[] = R"(HloModule main.6 - -ENTRY %main.6 (arg_tuple.1: (f32[128,10], f32[10,1024], f32[128,1024])) -> () { - %arg_tuple.1 = (f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) parameter(0), sharding={{devices=[1,2]0,1}, {maximal device=0}, {replicated}} - %get-tuple-element.2 = f32[128,10]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=0 - %get-tuple-element.3 = f32[10,1024]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=1 - %get-tuple-element.4 = f32[128,1024]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=2 - ROOT %tuple.5 = () tuple() -} - -)"; - EXPECT_EQ(expected_hlo_module_string, - status_or_hlo_module.ValueOrDie()->ToString()); -} - -TEST(CompileSerializedMlirToXlaHloTest, BadArgumentSharding) { - constexpr char mlir_module[] = R"( -module attributes {tf.versions = {producer = 179 : i32}} { - func @main(%arg0: tensor<128x10xf32> {mhlo.sharding = "bad_sharding"}) { - return - } -} -)"; - - std::vector arg_shapes{TensorShape({128, 10})}; - XlaCompiler::CompilationResult compilation_result; - - Status s = CompileSerializedMlirToXlaHlo( - mlir_module, arg_shapes, "XLA_CPU_JIT", - /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); - ASSERT_FALSE(s.ok()); - EXPECT_EQ(s.error_message(), - "failed to parse argument sharding 0 'bad_sharding'"); -} - -TEST(CompileSerializedMlirToXlaHloTest, ResultSharding) { - constexpr char mlir_module[] = R"( -module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 351 : i32}} { - func @main(%arg0: tensor<128x10xf32>, %arg1: tensor<10x1024xf32>, %arg2: tensor<128x1024xf32>) -> (tensor<128x10xf32> {mhlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, tensor<10x1024xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<128x1024xf32> {mhlo.sharding = ""}) { - return %arg0, %arg1, %arg2 : tensor<128x10xf32>, tensor<10x1024xf32>, tensor<128x1024xf32> - } -} -)"; - - std::vector arg_shapes{TensorShape({128, 10}), - TensorShape({10, 1024}), - TensorShape({128, 1024})}; - XlaCompiler::CompilationResult compilation_result; - - Status s = CompileSerializedMlirToXlaHlo( - mlir_module, arg_shapes, "XLA_CPU_JIT", - /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); - TF_ASSERT_OK(s); - - const xla::HloModuleConfig module_config( - compilation_result.computation->GetProgramShape().ValueOrDie()); - auto status_or_hlo_module = xla::HloModule::CreateFromProto( - compilation_result.computation->proto(), module_config); - TF_ASSERT_OK(status_or_hlo_module.status()); - constexpr char expected_hlo_module_string[] = R"(HloModule main.9 - -ENTRY %main.9 (arg_tuple.1: (f32[128,10], f32[10,1024], f32[128,1024])) -> (f32[128,10], f32[10,1024], f32[128,1024]) { - %arg_tuple.1 = (f32[128,10]{1,0}, f32[10,1024]{1,0}, 
f32[128,1024]{1,0}) parameter(0) - %get-tuple-element.2 = f32[128,10]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=0 - %reshape.5 = f32[128,10]{1,0} reshape(f32[128,10]{1,0} %get-tuple-element.2) - %get-tuple-element.3 = f32[10,1024]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=1 - %reshape.6 = f32[10,1024]{1,0} reshape(f32[10,1024]{1,0} %get-tuple-element.3) - %get-tuple-element.4 = f32[128,1024]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=2 - %reshape.7 = f32[128,1024]{1,0} reshape(f32[128,1024]{1,0} %get-tuple-element.4) - ROOT %tuple.8 = (f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) tuple(f32[128,10]{1,0} %reshape.5, f32[10,1024]{1,0} %reshape.6, f32[128,1024]{1,0} %reshape.7), sharding={{devices=[1,2]0,1}, {maximal device=0}, {replicated}} -} - -)"; - EXPECT_EQ(expected_hlo_module_string, - status_or_hlo_module.ValueOrDie()->ToString()); -} - -// Verify that conversion from Graph to MLIR and empty shape representation -// function is successful. -TEST(CompileGraphToXlaHlo, Basic) { - FunctionLibraryDefinition flib_def(OpRegistry::Global(), {}); - Graph graph(OpRegistry::Global()); - - Node* arg = test::graph::Arg(&graph, 0, DT_FLOAT); - test::graph::Retval(&graph, 0, arg); - - XlaCompiler::CompilationResult result; - XlaCompiler::Argument compiler_arg; - compiler_arg.kind = XlaCompiler::Argument::kParameter; - compiler_arg.shape = TensorShape(); - - TF_ASSERT_OK( - CompileGraphToXlaHlo(graph, /*args=*/{compiler_arg}, "XLA_CPU_JIT", - /*use_tuple_args=*/false, flib_def, GraphDebugInfo(), - /*shape_representation_fn=*/nullptr, &result)); - - const xla::HloModuleConfig module_config( - result.computation->GetProgramShape().ValueOrDie()); - auto status_or_hlo_module = xla::HloModule::CreateFromProto( - result.computation->proto(), module_config); - ASSERT_TRUE(status_or_hlo_module.ok()); - - constexpr char expected_hlo_module_string[] = R"(HloModule main.3 - -ENTRY %main.3 (Arg_0.1: f32[]) -> (f32[]) { - %Arg_0.1 = f32[] parameter(0) - ROOT %tuple.2 = (f32[]) tuple(f32[] %Arg_0.1) -} - -)"; - - EXPECT_EQ(expected_hlo_module_string, - status_or_hlo_module.ValueOrDie()->ToString()); -} - -// Tests a conversion from Graph to MLIR with resource arguments. 
-TEST(CompileGraphToXlaHlo, Resources) { - FunctionLibraryDefinition flib_def(OpRegistry::Global(), {}); - Graph graph(OpRegistry::Global()); - - Scope scope = Scope::NewRootScope().ExitOnError(); - auto val = ops::_Arg(scope.WithOpName("arg0"), DT_FLOAT, 0); - auto var = ops::_Arg(scope.WithOpName("arg1"), DT_RESOURCE, 1); - auto assign = - ops::AssignVariableOp(scope.WithOpName("assign_variable"), var, val); - TF_ASSERT_OK(scope.ToGraph(&graph)); - - XlaCompiler::CompilationResult result; - XlaCompiler::Argument arg0; - arg0.kind = XlaCompiler::Argument::kParameter; - arg0.shape = TensorShape({2}); - XlaCompiler::Argument arg1; - arg1.kind = XlaCompiler::Argument::kResource; - arg1.shape = TensorShape({2}); - arg1.type = DT_FLOAT; - - TF_ASSERT_OK( - CompileGraphToXlaHlo(graph, /*args=*/{arg0, arg1}, "XLA_CPU_JIT", - /*use_tuple_args=*/false, flib_def, GraphDebugInfo(), - /*shape_representation_fn=*/nullptr, &result)); - - EXPECT_EQ(result.outputs.size(), 0); - ASSERT_EQ(result.resource_updates.size(), 1); - const auto& resource_update = result.resource_updates[0]; - EXPECT_EQ(resource_update.input_index, 1); - EXPECT_EQ(resource_update.modified, true); - EXPECT_EQ(resource_update.shape, TensorShape({2})); - EXPECT_EQ(resource_update.type, DT_FLOAT); - - const xla::HloModuleConfig module_config( - result.computation->GetProgramShape().ValueOrDie()); - auto status_or_hlo_module = xla::HloModule::CreateFromProto( - result.computation->proto(), module_config); - ASSERT_TRUE(status_or_hlo_module.ok()); - - constexpr char expected_hlo_module_string[] = - R"(HloModule main.4, input_output_alias={ {0}: (1, {}, may-alias) } - -ENTRY %main.4 (Arg_0.1: f32[2], Arg_1.2: f32[2]) -> (f32[2]) { - %Arg_1.2 = f32[2]{0} parameter(1) - %Arg_0.1 = f32[2]{0} parameter(0) - ROOT %tuple.3 = (f32[2]{0}) tuple(f32[2]{0} %Arg_0.1) -} - -)"; - - EXPECT_EQ(expected_hlo_module_string, - status_or_hlo_module.ValueOrDie()->ToString()); -} - -} // namespace -} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_attr.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_attr.cc new file mode 100644 index 00000000000..98bfbbe608a --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_attr.cc @@ -0,0 +1,113 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h" + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/platform/errors.h" + +namespace tensorflow { + +// Converts non func AttrValue proto into an MLIR attribute. 
Func attribute is +// excluded in this function because the function might be renamed when the +// function definition is imported. +StatusOr ConvertNonFuncAttributeValue(const AttrValue& value, + mlir::Builder* builder) { + switch (value.value_case()) { + case AttrValue::kI: + return builder->getI64IntegerAttr(value.i()); + case AttrValue::kS: + return builder->getStringAttr(value.s()); + case AttrValue::kF: + return builder->getFloatAttr(builder->getF32Type(), value.f()); + case AttrValue::kB: + return builder->getBoolAttr(value.b()); + case AttrValue::kType: { + mlir::Type type; + TF_RETURN_IF_ERROR(ConvertDataType(value.type(), *builder, &type)); + return mlir::TypeAttr::get(type); + } + case AttrValue::kShape: + return ConvertTensorShapeProto(value.shape(), builder->getContext()); + case AttrValue::kTensor: + return ConvertTensorProto(value.tensor(), builder); + case AttrValue::kList: { + absl::InlinedVector attrs; + for (const auto& item : value.list().i()) + attrs.push_back(builder->getI64IntegerAttr(item)); + for (const auto& item : value.list().s()) + attrs.push_back(builder->getStringAttr(item)); + for (const auto& item : value.list().f()) + attrs.push_back(builder->getFloatAttr(builder->getF32Type(), item)); + for (const auto& item : value.list().b()) + attrs.push_back(builder->getBoolAttr(item)); + for (const auto& item : value.list().type()) { + mlir::Type type; + TF_RETURN_IF_ERROR(ConvertDataType(DataType(item), *builder, &type)); + attrs.push_back(mlir::TypeAttr::get(type)); + } + for (const auto& item : value.list().shape()) { + TF_ASSIGN_OR_RETURN( + auto attr, ConvertTensorShapeProto(item, builder->getContext())); + attrs.push_back(attr); + } + for (const auto& item : value.list().tensor()) { + TF_ASSIGN_OR_RETURN(auto attr, ConvertTensorProto(item, builder)); + attrs.push_back(attr); + } + if (!value.list().func().empty()) { + return tensorflow::errors::Unimplemented( + absl::StrCat("Attribute ", value.DebugString())); + } + return builder->getArrayAttr( + llvm::makeArrayRef(attrs.begin(), attrs.end())); + } + case AttrValue::VALUE_NOT_SET: + return builder->getUnitAttr(); + // kPlaceholder is not implemented. + default: + return tensorflow::errors::Unimplemented( + absl::StrCat("Attribute ", value.DebugString())); + } +} + +StatusOr ConvertAttributeValue(const AttrValue& value, + mlir::Builder* builder) { + switch (value.value_case()) { + case AttrValue::kFunc: { + // TODO(b/156546237): Unify kFunc/NameAttrList attribute representation. + // Currently kFunc/NameAttrList attributes in a kList/repeated AttrValue + // will not use this representation. + mlir::NamedAttrList attrs; + for (const auto& func_attr : value.func().attr()) { + TF_ASSIGN_OR_RETURN(auto attr, + ConvertAttributeValue(func_attr.second, builder)); + attrs.push_back(builder->getNamedAttr(func_attr.first, attr)); + } + auto func_attrs = builder->getDictionaryAttr(attrs); + return mlir::TF::FuncAttr::get(builder->getContext(), value.func().name(), + func_attrs); + } + default: + return ConvertNonFuncAttributeValue(value, builder); + } +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h b/tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h new file mode 100644 index 00000000000..c95ed60273d --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h @@ -0,0 +1,39 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CONVERT_ATTR_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CONVERT_ATTR_H_ + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace tensorflow { + +using stream_executor::port::StatusOr; + +// Converts non func AttrValue proto into an MLIR attribute. Func attribute is +// excluded in this function because the function might be renamed when the +// function definition is imported. +StatusOr ConvertNonFuncAttributeValue(const AttrValue& value, + mlir::Builder* builder); + +// Converts all kinds of AttrValue proto into an MLIR attribute. +StatusOr ConvertAttributeValue(const AttrValue& value, + mlir::Builder* builder); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CONVERT_ATTR_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc index 05e1f059029..98328212c88 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc @@ -214,6 +214,20 @@ mlir::TF::ShapeAttr ConvertTypeToTensorShapeAttr(const mlir::Type& type) { return mlir::TF::ShapeAttr::get(type.getContext(), ArrayRef()); } +// Converts the tensor shape proto into an MLIR shape attribute. +StatusOr ConvertTensorShapeProto(const TensorShapeProto& shape, + mlir::MLIRContext* context) { + if (shape.unknown_rank()) + return mlir::TF::ShapeAttr::get(context, llvm::None); + + llvm::SmallVector dims; + dims.reserve(shape.dim().size()); + for (const auto& dim : shape.dim()) { + dims.push_back(dim.size()); + } + return mlir::TF::ShapeAttr::get(context, llvm::makeArrayRef(dims)); +} + // Converts an MLIR dense string elements attribute to a TensorFlow tensor // proto. void ConvertStringElementsAttr( diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h index e7cde4db936..294453ebcfd 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h @@ -48,6 +48,10 @@ PartialTensorShape ConvertTypeToTensorShape(const mlir::Type& type); // Converts an MLIR shaped type to a TensorFlow shape attribute. mlir::TF::ShapeAttr ConvertTypeToTensorShapeAttr(const mlir::Type& type); +// Converts a TensorFlow shape proto to an MLIR shape attribute. +StatusOr ConvertTensorShapeProto(const TensorShapeProto& shape, + mlir::MLIRContext* context); + // Converts an MLIR elements attribute to a TensorFlow tensor proto.
Status ConvertToTensorProto(mlir::ElementsAttr attr, TensorProto* output_tensor); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc index febf2bc096d..6c1cab435d3 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc @@ -182,4 +182,55 @@ std::string DumpRawStringToFile(llvm::StringRef name, llvm::StringRef content, return filepath; } +void SetCrashReproducer(mlir::PassManager& pm, llvm::StringRef dir_path) { + std::string path = dir_path.str(); + if (path.empty()) { + if (getenv("MLIR_CRASH_REPRODUCER_DIRECTORY")) + path = getenv("MLIR_CRASH_REPRODUCER_DIRECTORY"); + else if (getenv("TEST_UNDECLARED_OUTPUTS_DIR")) + path = "sponge"; + } + if (path.empty()) { + LOG_FIRST_N(INFO, 1) << "disabling MLIR crash reproducer, set env var " + "`MLIR_CRASH_REPRODUCER_DIRECTORY` to enable."; + return; + } + + // Output dirs "sponge" (case-insensitive) have a special meaning: Dump into + // the directory specified by the environment variable + // TEST_UNDECLARED_OUTPUTS_DIR. + string lower_path = absl::AsciiStrToLower(path); + if (lower_path == "sponge") { + if (!tensorflow::io::GetTestUndeclaredOutputsDir(&path)) { + LOG(ERROR) << "MLIR crash reproducer is set to '" << dir_path.str() + << "', but environment variable TEST_UNDECLARED_OUTPUTS_DIR " + "is not set, so cannot dump anywhere."; + return; + } + } + + auto* env = tensorflow::Env::Default(); + auto status = env->RecursivelyCreateDir(path); + if (!status.ok()) { + LOG(WARNING) << "cannot create directory '" + path + + "': " + status.error_message(); + return; + } + + path += "/mlir_reproducer_"; + + if (!tensorflow::Env::Default()->CreateUniqueFileName(&path, ".mlir")) { + LOG(WARNING) + << "cannot create unique filename, won't enable MLIR crash reproducer."; + return; + } + pm.enableCrashReproducerGeneration(path, /*genLocalReproducer=*/false); +} + +void applyTensorflowAndCLOptions(mlir::PassManager& pm, + llvm::StringRef dir_path) { + mlir::applyPassManagerCLOptions(pm); + SetCrashReproducer(pm, dir_path); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h index 726eed8974e..133285864f6 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h @@ -20,6 +20,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/core/platform/status.h" namespace tensorflow { @@ -64,6 +65,22 @@ std::string GetDumpDirFromEnvVar(); std::string DumpRawStringToFile(llvm::StringRef name, llvm::StringRef content, llvm::StringRef dirname = ""); +// Enable the crash reproducer on the provided PassManager to the provided +// directory path. If the provided path is empty, it is retrieved from the +// environment variable `MLIR_CRASH_REPRODUCER_DIRECTORY`. If the provided path +// is the string "sponge", the file will be included in the sponge "Output +// Files" by looking up the environment to infer the directory path. +void SetCrashReproducer(mlir::PassManager& pm, llvm::StringRef dir_path = ""); + +// This applies both the PassManagerCLOptions provided by MLIR along with any +// tensorflow specific options. 
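+// A minimal usage sketch (hypothetical, assuming an already-loaded `module`):
+//   mlir::PassManager pm(module.getContext());
+//   tensorflow::applyTensorflowAndCLOptions(pm);  // also wires up the crash reproducer
+//   tensorflow::CreateConvertMlirToXlaHloPipeline(pm, "XLA_CPU_JIT",
+//                                                 /*custom_legalization_passes=*/{});
+//   if (mlir::failed(pm.run(module))) { /* handle the reported diagnostics */ }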
+// +// Note that this function should be in a more appropriate file, but it is +// unclear what a proper file would be as no other functions would currently be +// in the file also. +void applyTensorflowAndCLOptions(mlir::PassManager& pm, + llvm::StringRef dir_path = ""); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DUMP_MLIR_UTIL_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc index 67c2aebf121..cad5f2bae98 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc @@ -127,11 +127,12 @@ Status ConvertAttribute(const mlir::FlatSymbolRefAttr& attr, AttrValue* value) { return Status::OK(); } -Status ConvertAttribute(const mlir::TF::FuncAttr& attr, AttrValue* value) { +Status ConvertAttribute(const mlir::TF::FuncAttr& attr, bool remove_ref_type, + AttrValue* value) { TF_RETURN_IF_ERROR( ConvertAttribute(attr.GetName().cast(), value)); TF_RETURN_IF_ERROR(ConvertAttributes(attr.GetAttrs().getValue(), - /*attrs_to_ignore=*/{}, + /*attrs_to_ignore=*/{}, remove_ref_type, value->mutable_func()->mutable_attr())); return Status::OK(); } @@ -159,15 +160,18 @@ Status ConvertAttribute(const mlir::StringAttr& attr, AttrValue* value) { return Status::OK(); } -Status ConvertAttribute(mlir::Type type, AttrValue* value) { +Status ConvertAttribute(mlir::Type type, bool remove_ref_type, + AttrValue* value) { DataType dtype; TF_RETURN_IF_ERROR(ConvertToDataType(type, &dtype)); + if (tensorflow::IsRefType(dtype)) dtype = tensorflow::RemoveRefType(dtype); value->set_type(dtype); return Status::OK(); } -Status ConvertAttribute(const mlir::TypeAttr& type, AttrValue* value) { - return ConvertAttribute(type.getValue(), value); +Status ConvertAttribute(const mlir::TypeAttr& type, bool remove_ref_type, + AttrValue* value) { + return ConvertAttribute(type.getValue(), remove_ref_type, value); } Status ConvertAttribute(const mlir::UnitAttr& attr, AttrValue* value) { @@ -175,7 +179,8 @@ Status ConvertAttribute(const mlir::UnitAttr& attr, AttrValue* value) { return Status::OK(); } -Status ConvertAttribute(const mlir::ArrayAttr& attr, AttrValue* value) { +Status ConvertAttribute(const mlir::ArrayAttr& attr, bool remove_ref_type, + AttrValue* value) { auto* list = value->mutable_list(); for (mlir::Attribute a : attr.getValue()) { if (auto attr = a.dyn_cast()) { @@ -215,7 +220,8 @@ Status ConvertAttribute(const mlir::ArrayAttr& attr, AttrValue* value) { if (auto shaped_type = elt_type.dyn_cast()) { elt_type = shaped_type.getElementType(); } - TF_RETURN_IF_ERROR(ConvertAttribute(elt_type, &attr_val)); + TF_RETURN_IF_ERROR( + ConvertAttribute(elt_type, remove_ref_type, &attr_val)); list->add_type(attr_val.type()); } else if (auto attr = a.dyn_cast()) { AttrValue attr_val; @@ -228,18 +234,6 @@ Status ConvertAttribute(const mlir::ArrayAttr& attr, AttrValue* value) { return Status::OK(); } -// Updates NodeDef constructed out of an MLIR Case/IfW/While op to map it to -// either TensorFlow StatelessX or X op depending on the additional attribute. -void UpdateCompositeOp(NodeDef* node_def) { - auto it = node_def->mutable_attr()->find("is_stateless"); - if (it != node_def->attr().end()) { - if (it->second.b()) { - *node_def->mutable_op() = "Stateless" + node_def->op(); - } - node_def->mutable_attr()->erase(it); - } -} - // Returns true if the executor/control dialect op should map to Ref node in // TensorFlow Graph. 
For control dialect NextIteration it uses the 1st operand // type. For executor dialect NextIteration it uses the 2nd operand type. For @@ -291,7 +285,6 @@ StatusOr GetTensorFlowOpName(llvm::StringRef op_name) { } StatusOr> GetOperationNodeDef( - const absl::flat_hash_set& attrs_to_ignore, mlir::Operation* inst, llvm::StringRef name) { auto node_def = absl::make_unique(); // Note: we do not use NodeBuilder or NodeDefBuilder as that would require @@ -321,6 +314,14 @@ StatusOr> GetOperationNodeDef( node_def->set_name(name.str()); node_def->set_op(std::string(op_name.str())); + // Update NodeDef constructed out of an MLIR Case/If/While op to map it to + // either TensorFlow StatelessX or X op depending on the additional attribute. + if (llvm::isa(inst)) { + auto stateless = inst->getAttrOfType("is_stateless"); + if (stateless && stateless.getValue()) + *node_def->mutable_op() = "Stateless" + node_def->op(); + } + // Add inputs to the NodeDef based on the number of operands. This is required // as later when edges are added to the Node using Graph::AddEdge the // associated NodeDef is not updated. @@ -331,27 +332,17 @@ StatusOr> GetOperationNodeDef( node_def->set_device(std::string(attr.getValue())); } - // Add the node attributes. - TF_RETURN_WITH_CONTEXT_IF_ERROR( - ConvertAttributes(inst->getAttrs(), attrs_to_ignore, - node_def->mutable_attr()), - "while converting attributes for node: ", name.str()); - // Add the node debug info. TF_RETURN_IF_ERROR(ConvertLocation( inst->getLoc(), node_def->mutable_experimental_debug_info())); - if (node_def->op() == "Case") UpdateCompositeOp(node_def.get()); - if (node_def->op() == "If") UpdateCompositeOp(node_def.get()); - if (node_def->op() == "While") UpdateCompositeOp(node_def.get()); - return node_def; } Status ConvertAttributes( const llvm::ArrayRef attrs, const absl::flat_hash_set& attrs_to_ignore, - AttrValueMap* values) { + bool remove_ref_type, AttrValueMap* values) { AttrValueMap func_call_attrs; for (const mlir::NamedAttribute& named_attr : attrs) { auto name_strref = named_attr.first.str(); @@ -376,7 +367,7 @@ Status ConvertAttributes( continue; } if (auto func_attr = attr.dyn_cast()) { - TF_RETURN_IF_ERROR(ConvertAttribute(func_attr, &value)); + TF_RETURN_IF_ERROR(ConvertAttribute(func_attr, remove_ref_type, &value)); func_call_attrs[string(name)] = value; continue; } @@ -388,11 +379,13 @@ Status ConvertAttributes( TF_RETURN_IF_ERROR( llvm::TypeSwitch(attr) .Case( - [&](auto derived_attr) { - return ConvertAttribute(derived_attr, &value); - }) + mlir::StringAttr, mlir::ElementsAttr, mlir::UnitAttr, + mlir::TF::ShapeAttr>([&](auto derived_attr) { + return ConvertAttribute(derived_attr, &value); + }) + .Case([&](auto derived_attr) { + return ConvertAttribute(derived_attr, remove_ref_type, &value); + }) .Default([&](mlir::Attribute) { return errors::Unimplemented( "Unhandled attribute kind for attribute '", name_strref, @@ -419,28 +412,6 @@ Status ConvertAttributes( return Status::OK(); } -// Sets type attribute with the given name. If the attribute already exists with -// a different value, returns an error. 
-Status SetTypeAttribute(absl::string_view name, mlir::Type type, - AttrValueMap* values) { - DataType dtype; - TF_RETURN_IF_ERROR(ConvertScalarTypeToDataType(type, &dtype)); - if (tensorflow::IsRefType(dtype)) dtype = tensorflow::RemoveRefType(dtype); - AttrValue value; - value.set_type(dtype); - - auto result = values->insert({string(name), value}); - if (!result.second) { - DataType actual_dtype = result.first->second.type(); - if (actual_dtype != dtype) { - return errors::InvalidArgument("Expected ", DataType_Name(dtype), " '", - name, "' attribute but found ", - DataType_Name(actual_dtype)); - } - } - return Status::OK(); -} - Status SetShapeAttribute(absl::string_view name, mlir::ShapedType shaped_type, AttrValueMap* values) { tensorflow::TensorShapeProto tshape; @@ -469,26 +440,6 @@ Status SetShapeAttribute(absl::string_view name, mlir::ShapedType shaped_type, return Status::OK(); } -Status SetSizeAttribute(absl::string_view name, size_t size, - AttrValueMap* values) { - AttrValue value; - value.set_i(size); - - auto result = values->insert({string(name), value}); - if (!result.second) { - // This should be extremely rare as it means we are adding the same - // attribute multiple times/have some redundancy in representing this - // attribute. - size_t actual_size = result.first->second.i(); - // Just check via string output as we shouldn't get here and if we do they - // should be trivially the same, else fail. - if (actual_size != size) - return errors::InvalidArgument("Expected '", name, "' attribute to be ", - size, " but found ", actual_size); - } - return Status::OK(); -} - bool IsLegacyCallInstruction(mlir::Operation* inst) { return llvm::dyn_cast(inst); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h index 58fe39fa4e8..d1e0fd12f26 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h @@ -50,12 +50,8 @@ Status AddTensorFlowOpPrefix(std::string); StatusOr GetTensorFlowOpName(llvm::StringRef); // Converts an MLIR operation to TensorFlow NodeDef with given node name. This -// name should be unique to the graph it is being inserted into. `op_name_func` -// is to map the op name of `inst` to its op name in TensorFlow. "name" and -// "device" attributes are ignored by default. Use attrs_to_ignore to specify -// any other attributes that should be ignored. +// name should be unique to the graph it is being inserted into. StatusOr> GetOperationNodeDef( - const absl::flat_hash_set& attrs_to_ignore, mlir::Operation* inst, llvm::StringRef name); // Converts MLIR attributes with values to their tensorflow equivalent. @@ -64,23 +60,13 @@ StatusOr> GetOperationNodeDef( Status ConvertAttributes( const llvm::ArrayRef attrs, const absl::flat_hash_set& attrs_to_ignore, - AttrValueMap* values); - -// Sets type attribute with the given name. If the attribute already exists with -// a different value, returns an error. -Status SetTypeAttribute(absl::string_view name, mlir::Type type, - AttrValueMap* values); + bool remove_ref_type, AttrValueMap* values); // Sets shape attribute with the given name. If the attribute already exists // with a different value, returns an error. Status SetShapeAttribute(absl::string_view name, mlir::ShapedType shape, AttrValueMap* values); -// Sets the given size_t value as an integer attribute with the given name. -// If the attribute already exists with a different value, returns an error. 
-Status SetSizeAttribute(absl::string_view name, size_t size, - AttrValueMap* values); - // Returns true if the given instruction is an mlir::TF::LegacyCallOp or the // result of such an operation transformed by the // ExecutorToControlDialectConversion pass. diff --git a/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.cc new file mode 100644 index 00000000000..8e9495c0454 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.cc @@ -0,0 +1,56 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" + +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/Parser.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/platform/errors.h" + +namespace tensorflow { + +std::string SerializeMlirModule(mlir::ModuleOp module_op) { + std::string serialized_mlir_module; + llvm::raw_string_ostream os(serialized_mlir_module); + mlir::OpPrintingFlags print_flags; + print_flags.enableDebugInfo(); + module_op.print(os, print_flags); + return std::move(os.str()); +} + +Status DeserializeMlirModule(llvm::StringRef serialized_mlir_module, + mlir::MLIRContext* mlir_context, + mlir::OwningModuleRef* mlir_module) { + TF_RET_CHECK(!serialized_mlir_module.empty()) + << "unexpected empty serialized MLIR module string"; + TF_RET_CHECK(mlir_module) << "unexpected null MLIR module pointer"; + + // Make sure we catch any error reported by MLIR and forward it to the TF + // error reporting system. + mlir::StatusScopedDiagnosticHandler error_handler(mlir_context); + + // Parse the module. + *mlir_module = mlir::parseSourceString(serialized_mlir_module, mlir_context); + if (!*mlir_module) + return error_handler.Combine( + errors::InvalidArgument("could not parse MLIR module")); + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h new file mode 100644 index 00000000000..12d1c39132e --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h @@ -0,0 +1,39 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_SERIALIZE_MLIR_MODULE_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_SERIALIZE_MLIR_MODULE_UTILS_H_ + +#include + +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// Prints an MLIR module `module_op` and returns it as a string. +std::string SerializeMlirModule(mlir::ModuleOp module_op); + +// Parses an MLIR module from `serialized_mlir_module` into `mlir_module` with +// context `mlir_context`. +Status DeserializeMlirModule(llvm::StringRef serialized_mlir_module, + mlir::MLIRContext* mlir_context, + mlir::OwningModuleRef* mlir_module); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_SERIALIZE_MLIR_MODULE_UTILS_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.cc new file mode 100644 index 00000000000..d82d61ecf9e --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.cc @@ -0,0 +1,414 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h" + +#include +#include +#include +#include + +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FormatVariadic.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h" +#include "tensorflow/compiler/mlir/utils/array_container_utils.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/types.h" + +#define DEBUG_TYPE "tf-shape-inference-utils" + +using ::tensorflow::int64; +using tensorflow::shape_inference::DimensionHandle; +using tensorflow::shape_inference::InferenceContext; +using tensorflow::shape_inference::ShapeHandle; + +namespace mlir { +namespace TF { + +namespace { + +// Extracts attributes from a MLIR operation, including derived attributes, into +// one NamedAttrList. +NamedAttrList GetAllAttributesFromOperation(Operation* op) { + NamedAttrList attr_list; + attr_list.append(op->getAttrDictionary().getValue()); + + if (auto derived = dyn_cast(op)) { + auto materialized = derived.materializeDerivedAttributes(); + attr_list.append(materialized.getValue()); + } + + return attr_list; +} + +// Extracts a PartialTensorShape from the MLIR type. +Optional GetShapeFromMlirType(Type t) { + if (auto ranked_type = t.dyn_cast()) { + // Convert the MLIR shape indices (int64_t) to TensorFlow indices + // (int64). + ArrayRef shape = ranked_type.getShape(); + SmallVector tf_shape(shape.begin(), shape.end()); + return tensorflow::PartialTensorShape( + MutableArrayRefToSpan(tf_shape)); + } + return None; +} + +// Gets the subtype's shape and data type for `type`. Templated to support both +// ResourceType and VariantType. 
+template +std::unique_ptr>> +GetSubtypesHelper(Type type) { + auto type_with_subtypes = + type.cast().getElementType().dyn_cast(); + if (!type_with_subtypes || type_with_subtypes.getSubtypes().empty()) { + return nullptr; + } + auto shapes_and_types = std::make_unique>>(); + for (auto subtype : type_with_subtypes.getSubtypes()) { + auto shape = GetShapeFromMlirType(subtype); + // handle_shapes_and_types requires all shapes to be known. So if any + // subtype is unknown, clear the vector. + if (!shape) { + shapes_and_types = nullptr; + break; + } + tensorflow::DataType dtype; + auto status = + tensorflow::ConvertToDataType(subtype.getElementType(), &dtype); + assert(status.ok() && "Unknown element type"); + shapes_and_types->emplace_back(*shape, dtype); + } + return shapes_and_types; +} + +// Gets the subtype's shape and data type for `type`. +std::unique_ptr>> +GetSubtypes(Type type) { + auto subclasses = GetSubtypesHelper(type); + if (subclasses) return subclasses; + return GetSubtypesHelper(type); +} + +// Returns a shape inference function call failure at `location`. +LogicalResult EmitErrorFromShapeFunction(Optional location, + StringRef op_name, + StringRef error_message) { + LLVM_DEBUG(llvm::dbgs() << "Shape inference error for '" << op_name + << "': " << error_message << "\n"); + return emitOptionalError( + location, + llvm::formatv( + "TensorFlow shape inference function errored for op '{0}': {1}", + op_name, error_message) + .str()); +} + +// Extracts shape from a shape handle and inference context. +Optional> GetShapeFromHandle(InferenceContext& context, + const ShapeHandle& sh) { + if (!context.RankKnown(sh)) return None; + SmallVector shape; + for (int dim : llvm::seq(0, context.Rank(sh))) + shape.push_back(context.Value(context.Dim(sh, dim))); + return shape; +} + +// Creates a tensor type from a shape handle and element type. +TensorType CreateTensorType(InferenceContext& context, const ShapeHandle& sh, + Type element_type) { + auto shape = GetShapeFromHandle(context, sh); + if (shape.hasValue()) + return RankedTensorType::get(shape.getValue(), element_type); + return UnrankedTensorType::get(element_type); +} + +// Creates a ShapedTypeComponent from a shape handle and element type. +ShapedTypeComponents CreateShapedTypeComponents(InferenceContext& context, + const ShapeHandle& sh, + Type element_type) { + auto shape = GetShapeFromHandle(context, sh); + if (shape.hasValue()) + return ShapedTypeComponents(shape.getValue(), element_type); + return ShapedTypeComponents(element_type); +} + +} // namespace + +LogicalResult InferReturnTypeComponentsForTFOp( + Optional location, Operation* op, int64_t graph_version, + OperandAsConstantFn operand_as_constant_fn, + OpResultAsShapeFn op_result_as_shape_fn, + ResultElementTypeFn result_element_type_fn, + SmallVectorImpl& inferred_return_shapes) { + assert(op->getName().getDialect() == + TensorFlowDialect::getDialectNamespace()); + + auto op_name_or = + tensorflow::GetTensorFlowOpName(op->getName().getStringRef()); + if (!op_name_or.ok()) { + LLVM_DEBUG(llvm::dbgs() << "Skipping inference for unregistered op '" + << op->getName().getStringRef() << "'.\n"); + return emitOptionalError(location, "op is unregistered"); + } + llvm::StringRef op_name = op_name_or.ConsumeValueOrDie(); + + // Get information from the registry and check if we have a shape function for + // this op. 
+ const tensorflow::OpRegistrationData* op_reg_data = + tensorflow::OpRegistry::Global()->LookUp(op_name.str()); + if (!op_reg_data) { + LLVM_DEBUG(llvm::dbgs() << "Skipping inference for unregistered op '" + << op_name << "'.\n"); + return emitOptionalError(location, "op is unregistered"); + } + if (!op_reg_data->shape_inference_fn) { + LLVM_DEBUG(llvm::dbgs() + << "Skipping inference for op without shape function '" + << op_name << "'.\n"); + return emitOptionalError(location, "missing shape function"); + } + + // Convert the operation attributes to be able to use the InferenceContext + // and the TensorFlow shape function. + tensorflow::AttrValueMap attributes; + auto attr_status = tensorflow::GetAttrValuesFromOperation( + op, op_name, op_reg_data, /*ignore_unregistered_attrs=*/true, + &attributes); + if (!attr_status.ok()) { + LLVM_DEBUG(llvm::dbgs() << "Error creating attribute map for '" << op_name + << "': " << attr_status.error_message() << "\n"); + return emitOptionalError(location, attr_status.error_message()); + } + + // Collect an array with input values for constant operands and input shapes + // for all the operands. + const int num_operands = op->getNumOperands(); + std::vector input_tensors(num_operands); + std::vector input_shapes(num_operands); + std::vector tensors(num_operands); + std::vector>>> + handle_shapes_and_types(num_operands); + for (auto it : llvm::enumerate(op->getOperands())) { + Value operand = it.value(); + size_t index = it.index(); + + // If the operand is constant, then convert it to Tensor. + if (auto attr = operand_as_constant_fn(operand)) { + tensorflow::Tensor* input_tensor = &tensors[index]; + auto status = + tensorflow::ConvertToTensor(attr.cast(), input_tensor); + if (status.ok()) { + input_tensors[index] = input_tensor; + } else { + LLVM_DEBUG(llvm::dbgs() << "Error converting input " << index + << " of op '" << op_name << "' to Tensor: " + << status.error_message() << "\n"); + } + } + + Type operand_type = operand.getType(); + if (auto shape = GetShapeFromMlirType(operand_type)) { + input_shapes[index] = *shape; + } + // Collect the handle shapes and types for a resource/variant. + handle_shapes_and_types[index] = GetSubtypes(operand_type); + } + + // Perform the shape inference using an InferenceContext with the input + // shapes. This object is abstracting the information that the ShapeInference + // function operates on. + InferenceContext c(graph_version, tensorflow::AttrSlice(&attributes), + op_reg_data->op_def, input_shapes, input_tensors, + /*input_tensors_as_shapes=*/{}, handle_shapes_and_types); + auto status = c.Run(op_reg_data->shape_inference_fn); + if (!status.ok()) + return EmitErrorFromShapeFunction(location, op_name, + status.error_message()); + + // Determine if, during shape computation, the shape functions attempted to + // query an input operand as shape where the input was not known/constant. + bool requires_inputs = + any_of(llvm::seq(0, c.num_inputs()), [&](int input) { + return c.requested_input_tensor_as_partial_shape(input) && + !input_tensors[input]; + }); + if (requires_inputs) { + LLVM_DEBUG(llvm::dbgs() << "\trequired input\n"); + std::vector input_tensors_as_shapes; + for (int input : llvm::seq(0, c.num_inputs())) { + if (c.requested_input_tensor_as_partial_shape(input) && + !input_tensors[input]) { + LLVM_DEBUG(llvm::dbgs() << "Requesting " << input << " as shape\n"); + auto op_result = op->getOperand(input).dyn_cast(); + if (!op_result) continue; + // Resize on first valid shape computed. 
+ input_tensors_as_shapes.resize(c.num_inputs()); + auto handle = op_result_as_shape_fn(c, op_result); + LLVM_DEBUG(llvm::dbgs() << "Requested " << input << " as shape " + << (handle.Handle() ? "found" : "not found")); + if (handle.Handle()) input_tensors_as_shapes[input] = handle; + } + } + + // Attempt to compute the unknown operands as shapes. + // Note: in the case where no partial outputs could be computed, this + // would be empty. + if (!input_tensors_as_shapes.empty()) { + c.set_input_tensors_as_shapes(input_tensors_as_shapes); + auto status = c.Run(op_reg_data->shape_inference_fn); + if (!status.ok()) + return EmitErrorFromShapeFunction(location, op_name, + status.error_message()); + } + } + + // Update the shape for each of the operation result if the InferenceContext + // has more precise shapes recorded. + for (int output : llvm::seq(0, c.num_outputs())) { + ShapeHandle shape_handle = c.output(output); + LLVM_DEBUG(llvm::dbgs() << "Inferred output " << output << " : " + << c.DebugString(shape_handle) << "\n"); + + Type new_element_type = result_element_type_fn(output); + // Populate the handle shapes for a resource/variant. + if (new_element_type && + new_element_type.isa()) { + auto handle_shapes_types = c.output_handle_shapes_and_types(output); + if (handle_shapes_types) { + SmallVector subtypes; + Builder b(op->getContext()); + for (const auto& shape_n_type : *handle_shapes_types) { + Type element_type; + auto status = + tensorflow::ConvertDataType(shape_n_type.dtype, b, &element_type); + assert(status.ok() && "Unknown element type"); + subtypes.push_back( + CreateTensorType(c, shape_n_type.shape, element_type)); + } + if (new_element_type.isa()) { + new_element_type = TF::ResourceType::get(subtypes, op->getContext()); + } else { + new_element_type = TF::VariantType::get(subtypes, op->getContext()); + } + } + } + inferred_return_shapes.push_back( + CreateShapedTypeComponents(c, shape_handle, new_element_type)); + } + + return success(); +} + +LogicalResult InferReturnTypeComponentsForTFOp( + Optional location, Operation* op, int64_t graph_version, + SmallVectorImpl& inferred_return_shapes) { + if (auto type_op = dyn_cast(op)) { + auto attributes = GetAllAttributesFromOperation(op); + SmallVector inferred_return_types; + auto result = type_op.inferReturnTypes( + op->getContext(), location, op->getOperands(), + DictionaryAttr::get(attributes, op->getContext()), op->getRegions(), + inferred_return_types); + if (failed(result)) return failure(); + + inferred_return_shapes.resize(inferred_return_types.size()); + for (auto inferred_return_type : llvm::enumerate(inferred_return_types)) { + if (auto shaped_type = + inferred_return_type.value().dyn_cast()) { + if (shaped_type.hasRank()) { + inferred_return_shapes[inferred_return_type.index()] = + ShapedTypeComponents(shaped_type.getShape(), + shaped_type.getElementType()); + } else { + inferred_return_shapes[inferred_return_type.index()] = + ShapedTypeComponents(shaped_type.getElementType()); + } + } + } + + return success(); + } + + if (auto shape_type_op = dyn_cast(op)) { + auto attributes = GetAllAttributesFromOperation(op); + return shape_type_op.inferReturnTypeComponents( + op->getContext(), location, op->getOperands(), + DictionaryAttr::get(attributes, op->getContext()), op->getRegions(), + inferred_return_shapes); + } + + auto operand_as_constant_fn = [](Value operand) -> Attribute { + Attribute attr; + if (matchPattern(operand, m_Constant(&attr))) return attr; + return nullptr; + }; + + auto op_result_as_shape_fn = 
[](InferenceContext& ic, + OpResult op_result) -> ShapeHandle { + auto rt = op_result.getType().dyn_cast(); + if (!rt || rt.getRank() != 1 || !rt.hasStaticShape()) return {}; + + std::vector dims(rt.getDimSize(0), ic.UnknownDim()); + Attribute attr; + if (matchPattern(op_result, m_Constant(&attr))) { + auto elements = attr.dyn_cast(); + if (elements) + for (auto element : llvm::enumerate(elements.getIntValues())) + dims[element.index()] = ic.MakeDim(element.value().getSExtValue()); + } + return ic.MakeShape(dims); + }; + + auto result_element_type_fn = [](int) -> Type { return nullptr; }; + + return InferReturnTypeComponentsForTFOp( + location, op, graph_version, operand_as_constant_fn, + op_result_as_shape_fn, result_element_type_fn, inferred_return_shapes); +} + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h new file mode 100644 index 00000000000..eda2bc49514 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h @@ -0,0 +1,76 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_SHAPE_INFERENCE_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_SHAPE_INFERENCE_UTILS_H_ + +#include + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/core/framework/shape_inference.h" + +namespace mlir { +namespace TF { + +// Function that takes in a value and extracts a constant from it, if available. +// If the value cannot be resolved as a constant, a nullptr will be returned. +// Certain shape functions require constant values as arguments. +using OperandAsConstantFn = llvm::function_ref; + +// Function that takes in an operation result and computes a shape (can be +// partial) value. Certain shape functions require shape values as arguments. +using OpResultAsShapeFn = + llvm::function_ref; + +// Function that takes a result index and returns the element type. Element +// types are necessary for handle types (resource, variant). +using ResultElementTypeFn = llvm::function_ref; + +// Runs TensorFlow shape inference associated to the op type registered in the +// TensorFlow op registry based on the Graph version, operands, and attributes. +// Invoking this shape function will create conversions of parameters to the +// TensorFlow Graph equivalent data structures and back to MLIR equivalent data +// structures. 
This does not use shape inference natively implemented in MLIR; +// it is a temporary measure until shape functions are reimplemented/migrated +// from the TensorFlow op registry to MLIR. +LogicalResult InferReturnTypeComponentsForTFOp( + Optional location, Operation* op, int64_t graph_version, + OperandAsConstantFn operand_as_constant_fn, + OpResultAsShapeFn op_result_as_shape_fn, + ResultElementTypeFn result_element_type_fn, + SmallVectorImpl& inferred_return_shapes); + +// Runs TensorFlow shape inference for an operation for a given Graph version. +// If an operation implements the `InferTypeOpInterface` or +// `InferShapedTypeOpInterface` interfaces, those are used instead but with +// derived attributes populated. Otherwise the above function is used, but with +// default `operand_as_constant_fn` and `op_result_as_shape_fn` that only +// extract a value if the operands are constant (no partial evaluation), and an +// empty `result_element_type_fn`. Element types with subtypes (DT_RESOURCE, +// DT_VARIANT) are not supported. +LogicalResult InferReturnTypeComponentsForTFOp( + Optional location, Operation* op, int64_t graph_version, + SmallVectorImpl& inferred_return_shapes); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_SHAPE_INFERENCE_UTILS_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc new file mode 100644 index 00000000000..bcc3fe62f99 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc @@ -0,0 +1,334 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#include +#include +#include + +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Parser.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Translation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" +#include "tensorflow/compiler/mlir/utils/string_container_utils.h" +#include "tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h" +#include "tensorflow/compiler/tf2xla/xla_argument.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" + +// NOLINTNEXTLINE +llvm::cl::opt input_types( + "tf-xla-input-types", + llvm::cl::desc("XLA input argument types (kinds), separated by ','. " + "Supported types include ['parameter', 'resource']. 
If " + "empty, all arguments are assumed to be parameters."), + llvm::cl::init("")); + +namespace tensorflow { + +namespace { + +mlir::LogicalResult PrintHloModuleText( + const XlaCompilationResult& compilation_result, llvm::raw_ostream& output) { + const xla::HloModuleConfig module_config( + compilation_result.computation->GetProgramShape().ValueOrDie()); + auto status_or_hlo_module = xla::HloModule::CreateFromProto( + compilation_result.computation->proto(), module_config); + if (!status_or_hlo_module.ok()) { + LOG(ERROR) << "Conversion to HLO module failed: " + << status_or_hlo_module.status().ToString(); + return mlir::failure(); + } + + xla::HloModule* hlo_module = status_or_hlo_module.ValueOrDie().get(); + + output << hlo_module->ToString(); + + if (!compilation_result.input_mapping.empty()) + output << "// InputMapping {" + << absl::StrJoin(compilation_result.input_mapping, ", ") << "}\n"; + + for (const auto& xla_input_shape : compilation_result.xla_input_shapes) + output << "// XlaInputShape " << xla_input_shape.ToString() << '\n'; + + output << "// XlaOutputShape " + << compilation_result.xla_output_shape.ToString() << '\n'; + + for (const auto& xla_output_description : compilation_result.outputs) { + output << "// XlaOutputDescription type=" + << DataTypeString(xla_output_description.type) << " shape=(" + << absl::StrJoin(xla_output_description.shape.dim_sizes(), ", ") + << ')'; + if (xla_output_description.input_index >= 0) + output << " input_index=" << xla_output_description.input_index; + if (xla_output_description.is_constant) output << " constant"; + if (xla_output_description.is_tensor_list) output << " tensor_list"; + output << '\n'; + } + + for (const auto& resource_update : compilation_result.resource_updates) { + output << "// ResourceUpdate input_index=" << resource_update.input_index + << " type=" << DataTypeString(resource_update.type) << " shape=(" + << absl::StrJoin(resource_update.shape.dim_sizes(), " ") << ')'; + if (resource_update.modified) output << " modified"; + output << '\n'; + } + + return mlir::success(); +} + +Status ParseArgumentShapes( + absl::string_view input_shapes_str, + llvm::SmallVectorImpl& arg_shapes) { + arg_shapes.clear(); + std::vector> input_shapes_vector; + TF_RETURN_IF_ERROR(ParseNodeShapes(input_shapes_str, input_shapes_vector)); + arg_shapes.resize(input_shapes_vector.size()); + for (const auto& shape : llvm::enumerate(input_shapes_vector)) + TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape( + shape.value(), &arg_shapes[shape.index()].shape)); + + return Status::OK(); +} + +Status ParseDataTypes(absl::string_view data_types_str, + llvm::SmallVectorImpl& data_types) { + data_types.clear(); + std::vector input_dtypes_vector; + TF_RETURN_IF_ERROR(ParseNodeDataTypes(data_types_str, input_dtypes_vector)); + data_types.resize(input_dtypes_vector.size(), DT_INVALID); + for (auto data_type : llvm::enumerate(input_dtypes_vector)) { + if (!DataType_Parse(data_type.value(), &data_types[data_type.index()])) + return errors::InvalidArgument("Invalid dtype at index ", + data_type.index(), ": ", + data_type.value()); + const auto& resolved_dtype = data_types[data_type.index()]; + if (resolved_dtype == DT_INVALID || resolved_dtype == DT_STRING || + resolved_dtype == DT_RESOURCE || resolved_dtype == DT_VARIANT || + IsRefType(resolved_dtype)) + return errors::InvalidArgument("Unsupported dtype at index ", + data_type.index(), ": ", + data_type.value()); + } + + return Status::OK(); +} + +Status ParseArgumentKinds( + absl::string_view input_types_str, + 
llvm::SmallVectorImpl& argument_kinds) { + argument_kinds.clear(); + if (input_types_str.empty()) return Status::OK(); + + std::vector argument_kind_strs = + absl::StrSplit(input_types_str, ','); + argument_kinds.reserve(argument_kind_strs.size()); + for (const auto& argument_kind_str : llvm::enumerate(argument_kind_strs)) { + const auto& value = argument_kind_str.value(); + if (value == "parameter") { + argument_kinds.push_back(XlaArgument::Kind::kParameter); + } else if (value == "resource") { + argument_kinds.push_back(XlaArgument::Kind::kResource); + } else { + return errors::InvalidArgument( + "Unsupported TF/XLA argument kind at index ", + argument_kind_str.index(), ": ", value); + } + } + + return Status::OK(); +} + +Status ParseXlaArguments(absl::string_view input_shapes_str, + absl::string_view input_dtypes_str, + absl::string_view arg_kinds_str, + llvm::SmallVectorImpl& xla_arguments) { + xla_arguments.clear(); + std::vector> input_shapes_vector; + TF_RETURN_IF_ERROR( + tensorflow::ParseNodeShapes(input_shapes_str, input_shapes_vector)); + llvm::SmallVector dtypes_vector; + TF_RETURN_IF_ERROR(ParseDataTypes(input_dtypes_str, dtypes_vector)); + llvm::SmallVector arg_kinds_vector; + TF_RETURN_IF_ERROR(ParseArgumentKinds(arg_kinds_str, arg_kinds_vector)); + + if (input_shapes_vector.empty()) + input_shapes_vector.resize(dtypes_vector.size()); + + if (arg_kinds_vector.empty()) + arg_kinds_vector.resize(input_shapes_vector.size(), + XlaArgument::Kind::kParameter); + + if (input_shapes_vector.size() != dtypes_vector.size() || + input_shapes_vector.size() != arg_kinds_vector.size()) + return errors::InvalidArgument( + "Input shapes, dtypes, and types/kinds must be of the same " + "length, but got ", + input_shapes_vector.size(), ", ", dtypes_vector.size(), ", and ", + arg_kinds_vector.size(), " respectively"); + + xla_arguments.resize(input_shapes_vector.size()); + for (const auto& arg_components : + llvm::zip(xla_arguments, input_shapes_vector, dtypes_vector, + arg_kinds_vector)) { + XlaArgument& arg = std::get<0>(arg_components); + TensorShape shape; + TF_RETURN_IF_ERROR( + TensorShapeUtils::MakeShape(std::get<1>(arg_components), &shape)); + arg.shape = std::move(shape); + arg.type = std::get<2>(arg_components); + arg.kind = std::get<3>(arg_components); + } + + return Status::OK(); +} + +} // anonymous namespace + +static mlir::LogicalResult MlirTfToHloTextTranslateFunction( + mlir::ModuleOp module_op, llvm::raw_ostream& output) { + if (!module_op) return mlir::failure(); + + llvm::SmallVector arg_shapes; + auto args_status = + ParseArgumentShapes(mlir::StringRefToView(input_shapes), arg_shapes); + if (!args_status.ok()) { + LOG(ERROR) << args_status.ToString(); + return mlir::failure(); + } + + XlaCompilationResult compilation_result; + auto compilation_status = CompileMlirToXlaHlo( + module_op, arg_shapes, /*device_type=*/"XLA_CPU_JIT", emit_use_tuple_arg, + emit_return_tuple, IdentityShapeRepresentationFn(), &compilation_result, + /*custom_legalization_passes=*/{}); + if (!compilation_status.ok()) { + LOG(ERROR) << "TF/XLA compilation failed: " + << compilation_status.ToString(); + return mlir::failure(); + } + + return PrintHloModuleText(compilation_result, output); +} + +static mlir::LogicalResult MlirTfGraphToHloTextTranslateFunction( + mlir::ModuleOp module_op, llvm::raw_ostream& output) { + if (!module_op) return mlir::failure(); + + llvm::SmallVector xla_arguments; + auto args_status = ParseXlaArguments( + mlir::StringRefToView(input_shapes), 
mlir::StringRefToView(input_dtypes), + mlir::StringRefToView(input_types), xla_arguments); + if (!args_status.ok()) { + LOG(ERROR) << args_status.ToString(); + return mlir::failure(); + } + + XlaCompilationResult compilation_result; + auto compilation_status = CompileGraphToXlaHlo( + module_op, xla_arguments, /*device_type=*/"XLA_CPU_JIT", + emit_use_tuple_arg, emit_return_tuple, IdentityShapeRepresentationFn(), + &compilation_result, /*custom_legalization_passes=*/{}); + if (!compilation_status.ok()) { + LOG(ERROR) << "TF/XLA compilation failed: " + << compilation_status.ToString(); + return mlir::failure(); + } + + return PrintHloModuleText(compilation_result, output); +} + +static void RegisterMlirInputDialects(mlir::DialectRegistry& registry) { + registry.insert(); +} + +static void RegisterGraphInputDialects(mlir::DialectRegistry& registry) { + RegisterMlirInputDialects(registry); + registry.insert(); +} + +static mlir::OwningModuleRef SerializedMlirStringAttrToMlirModuleTranslate( + llvm::StringRef input, mlir::MLIRContext* context) { + mlir::Attribute attr = mlir::parseAttribute(input, context); + if (!attr || !attr.isa()) { + LOG(ERROR) << "Input is not parsable as a MLIR StringAttr."; + return nullptr; + } + auto str_attr = attr.cast(); + + RegisterMlirInputDialects(context->getDialectRegistry()); + mlir::OwningModuleRef module_ref; + auto status = + DeserializeMlirModule(str_attr.getValue().str(), context, &module_ref); + if (!status.ok()) { + LOG(ERROR) << status.ToString(); + return nullptr; + } + + return module_ref; +} + +static mlir::LogicalResult MlirModuleToSerializedMlirStringAttrTranslate( + mlir::ModuleOp module_op, llvm::raw_ostream& output) { + output << "\""; + std::string serialized_module = SerializeMlirModule(module_op); + llvm::printEscapedString(serialized_module, output); + output << "\""; + return mlir::success(); +} + +} // namespace tensorflow + +static mlir::TranslateFromMLIRRegistration MlirTfToHloTextTranslate( + "mlir-tf-to-hlo-text", tensorflow::MlirTfToHloTextTranslateFunction, + tensorflow::RegisterMlirInputDialects); + +static mlir::TranslateFromMLIRRegistration MlirTfGraphToHloTextTranslate( + "mlir-tf-graph-to-hlo-text", + tensorflow::MlirTfGraphToHloTextTranslateFunction, + tensorflow::RegisterGraphInputDialects); + +static mlir::TranslateToMLIRRegistration SerializedMlirStringAttrToMlirModule( + "mlir-tf-str-attr-to-mlir", + tensorflow::SerializedMlirStringAttrToMlirModuleTranslate); + +static mlir::TranslateFromMLIRRegistration MlirModuleToSerializedMlirStringAttr( + "mlir-tf-mlir-to-str-attr", + tensorflow::MlirModuleToSerializedMlirStringAttrTranslate, + tensorflow::RegisterMlirInputDialects); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc index 843d491c330..3516e3a65d9 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc @@ -374,9 +374,8 @@ GetGeneralTPUExecutionDeviceAssignment( return (x + bound_x * (y + bound_y * z)) * bound_core + core; }; - std::vector used_device_ids( - location_to_id(bound_x - 1, bound_y - 1, bound_z - 1, bound_core - 1), - false); + std::vector used_device_ids(bound_x * bound_y * bound_z * bound_core, + false); TPUDevicesAndHosts devices_and_hosts( num_replicas, llvm::SmallVector( num_cores_per_replica, TPUDeviceAndHost())); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc 
b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc index 19eb5b2c476..8cf06259142 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc @@ -760,8 +760,8 @@ TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceTPUReplicate) { devices; auto replicate = builder.create( mlir::UnknownLoc::get(&context), /*num_replicas=*/2, devices, - llvm::ArrayRef, mlir::Type>>{}, - llvm::ArrayRef{}, llvm::ArrayRef{}); + llvm::ArrayRef>{}, + mlir::ValueRange{}, mlir::TypeRange{}); builder.setInsertionPoint(&replicate.body().front(), replicate.body().front().begin()); diff --git a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc index 144e22750ca..e5408cef828 100644 --- a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc @@ -13,40 +13,33 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/InitLLVM.h" -#include "llvm/Support/SourceMgr.h" -#include "llvm/Support/ToolOutputFile.h" #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project -#include "mlir/IR/AsmState.h" // from @llvm-project #include "mlir/InitAllDialects.h" // from @llvm-project #include "mlir/InitAllPasses.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Support/FileUtilities.h" // from @llvm-project #include "mlir/Support/MlirOptMain.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h" +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/register_passes.h" #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" #include "tensorflow/core/platform/init_main.h" -#include "tensorflow/core/platform/logging.h" int main(int argc, char **argv) { tensorflow::InitMlir y(&argc, &argv); mlir::registerAllPasses(); + mlir::mhlo::registerAllMhloPasses(); + mlir::lmhlo::registerAllLmhloPasses(); mlir::DialectRegistry registry; mlir::registerAllDialects(registry); + mlir::RegisterAllTensorFlowDialects(registry); + mlir::mhlo::registerAllMhloDialects(registry); registry.insert(); - registry.insert(); registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); + registry.insert(); return failed( mlir::MlirOptMain(argc, argv, "TensorFlow pass driver\n", registry)); } diff --git a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc index 9b0b3aaa82b..3ea92a70ec7 100644 --- a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc @@ -111,8 +111,6 @@ int main(int argc, char** argv) { if (import_saved_model_object_graph) { mlir::MLIRContext context; -
context.loadAllGloballyRegisteredDialects(); - auto module_or = tensorflow::SavedModelObjectGraphToMlirImport( input_filename, tags, exported_names, &context); if (!module_or.status().ok()) return 1; @@ -120,8 +118,6 @@ int main(int argc, char** argv) { module_or.ConsumeValueOrDie()->print(output->os()); } else if (import_saved_model_signature_defs) { mlir::MLIRContext context; - context.loadAllGloballyRegisteredDialects(); - auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport( input_filename, tags, exported_names, &context, upgrade_legacy); if (!module_or.status().ok()) return 1; @@ -141,7 +137,6 @@ int main(int argc, char** argv) { llvm::SourceMgr sourceMgr; sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc()); mlir::MLIRContext context; - context.loadAllGloballyRegisteredDialects(); mlir::SourceMgrDiagnosticHandler diagnostic_handler(sourceMgr, &context); return (*requested_translation)(sourceMgr, os, &context); }; diff --git a/tensorflow/compiler/mlir/tfjs/BUILD b/tensorflow/compiler/mlir/tfjs/BUILD index 4db960085ec..34686cc0f68 100644 --- a/tensorflow/compiler/mlir/tfjs/BUILD +++ b/tensorflow/compiler/mlir/tfjs/BUILD @@ -1,3 +1,9 @@ +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "filegroup") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "get_compatible_with_cloud") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//third_party/mlir:tblgen.bzl", "gentbl") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") @@ -16,6 +22,7 @@ filegroup( gentbl( name = "tfjs_inc_gen", + compatible_with = get_compatible_with_cloud(), tbl_outs = [ ( "-gen-op-decls", @@ -70,6 +77,7 @@ cc_library( gentbl( name = "tfjs_optimize_inc_gen", + compatible_with = get_compatible_with_cloud(), tbl_outs = [ ( "-gen-rewriters", @@ -117,7 +125,6 @@ cc_library( ":tfjs_optimize", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:decode_constant_pass", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass", "//tensorflow/compiler/mlir/tensorflow:tf_saved_model_passes", @@ -141,7 +148,6 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", "//tensorflow/compiler/mlir/tensorflow:export_utils", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:framework", "//tensorflow/core:graph", @@ -179,7 +185,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", - "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", diff --git a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.cc b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.cc index 331bed09dce..5ea3f51b475 100644 --- a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.cc +++ b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.cc @@ -15,12 +15,12 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h" -namespace mlir { -namespace tfjs { - #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.cc.inc" +namespace mlir { +namespace tfjs { + //===----------------------------------------------------------------------===// // TFJSDialect //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h index 9c98c9b0e19..bc52e3a0c7a 100644 --- a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h +++ b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h @@ -29,15 +29,9 @@ limitations under the License. #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project -namespace mlir { -namespace tfjs { - #include "tensorflow/compiler/mlir/tfjs/ir/tfjs_dialect.h.inc" #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h.inc" -} // namespace tfjs -} // namespace mlir - #endif // TENSORFLOW_COMPILER_MLIR_TFJS_IR_TFJS_OPS_H_ diff --git a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.td b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.td index 134aa010d8c..e2539c2f6d8 100644 --- a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.td +++ b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.td @@ -39,7 +39,7 @@ def TFJSDialect : Dialect { TF graphs to be deployed on TFJS. }]; - let cppNamespace = "tfjs"; + let cppNamespace = "::mlir::tfjs"; } //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tfjs/tests/BUILD b/tensorflow/compiler/mlir/tfjs/tests/BUILD index 5789480c3ba..979a9b773f2 100644 --- a/tensorflow/compiler/mlir/tfjs/tests/BUILD +++ b/tensorflow/compiler/mlir/tfjs/tests/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow:tensorflow.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package(licenses = ["notice"]) diff --git a/tensorflow/compiler/mlir/tfjs/tests/e2e/BUILD b/tensorflow/compiler/mlir/tfjs/tests/e2e/BUILD index 5c8d37da2f0..1fc3d51cb24 100644 --- a/tensorflow/compiler/mlir/tfjs/tests/e2e/BUILD +++ b/tensorflow/compiler/mlir/tfjs/tests/e2e/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow:tensorflow.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") licenses(["notice"]) diff --git a/tensorflow/compiler/mlir/tfjs/transforms/optimize.cc b/tensorflow/compiler/mlir/tfjs/transforms/optimize.cc index a3678f7d154..5d3ee121577 100644 --- a/tensorflow/compiler/mlir/tfjs/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/tfjs/transforms/optimize.cc @@ -50,7 +50,7 @@ void Optimize::runOnFunction() { auto *ctx = &getContext(); auto func = getFunction(); - populateWithGenerated(ctx, &patterns); + populateWithGenerated(ctx, patterns); applyPatternsAndFoldGreedily(func, patterns); } } // namespace diff --git a/tensorflow/compiler/mlir/tfr/BUILD b/tensorflow/compiler/mlir/tfr/BUILD new file mode 100644 index 00000000000..2861dd92d5d --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/BUILD @@ -0,0 +1,362 @@ +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_py_test") +load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension") +load("//tensorflow/compiler/mlir/tfr:build_defs.bzl", "gen_op_libraries") +load( + "//third_party/mlir:tblgen.bzl", + "gentbl", +) +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") + +package( + default_visibility = [ + ":friends", + ], + licenses = ["notice"], # Apache 
2.0 +) + +package_group( + name = "friends", + includes = ["//third_party/mlir:subpackages"], + packages = [ + "//learning/brain/experimental/mlir/tfr/...", + "//tensorflow/c/...", + "//tensorflow/compiler/...", + ], +) + +filegroup( + name = "tfr_ops_td_files", + srcs = [ + "ir/tfr_ops.td", + "//tensorflow/compiler/mlir/tensorflow:ir/tf_op_base.td", + "//tensorflow/compiler/mlir/tensorflow:ir/tf_op_interfaces.td", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:include/mlir/Dialect/Shape/IR/ShapeBase.td", + "@llvm-project//mlir:include/mlir/Dialect/Shape/IR/ShapeOps.td", + "@llvm-project//mlir:include/mlir/IR/SymbolInterfaces.td", + "@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td", + "@llvm-project//mlir:include/mlir/Interfaces/ControlFlowInterfaces.td", + "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td", + ], +) + +gentbl( + name = "tfr_ops_inc_gen", + tbl_outs = [ + ( + "-gen-op-decls", + "ir/tfr_ops.h.inc", + ), + ( + "-gen-op-defs", + "ir/tfr_ops.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "ir/tfr_ops.td", + td_srcs = [ + ":tfr_ops_td_files", + ], +) + +cc_library( + name = "tfr", + srcs = [ + "ir/tfr_ops.cc", + "ir/tfr_ops.cc.inc", + ], + hdrs = [ + "ir/tfr_ops.h", + "ir/tfr_ops.h.inc", + "ir/tfr_types.h", + ], + deps = [ + ":tfr_ops_inc_gen", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Shape", + "@llvm-project//mlir:SideEffects", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + ], +) + +cc_library( + name = "utils", + srcs = [ + "utils/utils.cc", + ], + hdrs = [ + "utils/utils.h", + ], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Support", + ], +) + +cc_library( + name = "passes", + srcs = [ + "passes/canonicalize.cc", + "passes/decompose.cc", + "passes/raise_to_tf.cc", + ], + hdrs = [ + "passes/passes.h", + ], + deps = [ + ":tfr", + ":utils", + "//tensorflow/compiler/mlir/tensorflow", + "@com_google_absl//absl/memory", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFToStandard", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + ], + alwayslink = 1, +) + +tf_cc_binary( + name = "tfr-opt", + srcs = ["passes/tfr_opt.cc"], + deps = [ + ":passes", + ":tfr", + "//tensorflow/compiler/mlir:init_mlir", + "//tensorflow/compiler/mlir:passes", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", + "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Shape", + "@llvm-project//mlir:StandardOps", + ], +) + +glob_lit_tests( + data = [ + ":test_utilities", + "@llvm-project//mlir:run_lit.sh", + ], + driver = "//tensorflow/compiler/mlir:run_lit.sh", + test_file_exts = ["mlir"], +) + +# Bundle together all of the test utilities that are used by tests. 
+filegroup( + name = "test_utilities", + testonly = True, + data = [ + "//tensorflow/compiler/mlir/tfr:tfr-opt", + "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:not", + ], +) + +cc_library( + name = "tfr_decompose_ctx", + srcs = ["integration/tfr_decompose_ctx.cc"], + hdrs = ["integration/tfr_decompose_ctx.h"], + deps = [ + ":passes", + ":tfr", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:convert_attr", + "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", + "//tensorflow/compiler/mlir/tensorflow:convert_type", + "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Shape", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:TransformUtils", + ], +) + +tf_cc_test( + name = "tfr_decompose_ctx_test", + srcs = ["integration/tfr_decompose_ctx_test.cc"], + deps = [ + ":tfr_decompose_ctx", + "//tensorflow/compiler/xla:test", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/types:span", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "graph_decompose_pass", + srcs = ["integration/graph_decompose_pass.cc"], + hdrs = ["integration/graph_decompose_pass.h"], + deps = [ + ":tfr_decompose_ctx", + "//tensorflow/compiler/mlir:mlir_graph_optimization_pass", + "//tensorflow/stream_executor/lib", + "@llvm-project//mlir:IR", + ], + alwayslink = 1, +) + +tf_py_test( + name = "graph_decompose_test", + size = "small", + srcs = ["integration/graph_decompose_test.py"], + data = ["//tensorflow/compiler/mlir/tfr/resources:decomposition_lib"], + python_version = "PY3", + tags = [ + "no_pip", + "no_windows", # TODO(b/170752141) + "nomac", # TODO(b/170752141) + ], + deps = [ + "//tensorflow/compiler/mlir/tfr/resources:composite_ops", + "//tensorflow/python/eager:def_function", + ], +) + +cc_library( + name = "node_expansion_pass", + srcs = ["integration/node_expansion_pass.cc"], + hdrs = ["integration/node_expansion_pass.h"], + deps = [ + ":tfr_decompose_ctx", + "//tensorflow/core/common_runtime/eager:core", + "//tensorflow/core/common_runtime/eager:eager_op_rewrite_registry", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) + +tf_py_test( + name = "node_expansion_test", + size = "small", + srcs = ["integration/node_expansion_test.py"], + data = ["//tensorflow/compiler/mlir/tfr/resources:decomposition_lib"], + python_version = "PY3", + tags = [ + "no_pip", + "no_windows", # TODO(b/170752141) + "nomac", # TODO(b/170752141) + ], + deps = [ + "//tensorflow/compiler/mlir/tfr/resources:composite_ops", + ], +) + +tf_python_pybind_extension( + name = "tfr_wrapper", + srcs = ["python/tfr_wrapper.cc"], + module_name = "tfr_wrapper", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tfr", + "//tensorflow/python:pybind11_lib", + "//tensorflow/python:pybind11_status", + "@llvm-project//llvm:Support", + 
"@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Shape", + "@llvm-project//mlir:StandardOps", + "@pybind11", + ], +) + +py_library( + name = "composite", + srcs = ["python/composite.py"], + srcs_version = "PY2AND3", +) + +py_library( + name = "tfr_gen", + srcs = ["python/tfr_gen.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/compiler/mlir/tfr:tfr_wrapper", + ], +) + +tf_py_test( + name = "tfr_gen_test", + size = "small", + srcs = ["python/tfr_gen_test.py"], + python_version = "PY3", + tags = ["no_pip"], + deps = [ + ":composite", + ":tfr_gen", + "//tensorflow/compiler/mlir/python/mlir_wrapper:filecheck_wrapper", + "//tensorflow/compiler/mlir/tfr/resources:test_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:math_ops", + ], +) + +py_library( + name = "op_reg_gen", + srcs = ["python/op_reg_gen.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_test( + name = "op_reg_gen_test", + size = "small", + srcs = ["python/op_reg_gen_test.py"], + python_version = "PY3", + srcs_version = "PY2AND3", + deps = [ + ":composite", + ":op_reg_gen", + "//tensorflow/compiler/mlir/python/mlir_wrapper:filecheck_wrapper", + ], +) + +py_library( + name = "test_utils", + srcs = ["python/test_utils.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +gen_op_libraries( + name = "one_op", + src = "define_op_template.py", +) diff --git a/tensorflow/compiler/mlir/tfr/build_defs.bzl b/tensorflow/compiler/mlir/tfr/build_defs.bzl new file mode 100644 index 00000000000..2b92d8a652a --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/build_defs.bzl @@ -0,0 +1,116 @@ +"""BUILD extension for TF composition project.""" + +load("//tensorflow:tensorflow.bzl", "py_binary", "tf_custom_op_library", "tf_gen_op_wrapper_py") +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") + +def gen_op_libraries( + name, + src, + deps = [], + tags = [], + test = False): + """gen_op_libraries() generates all cc and py libraries for composite op source. + + Args: + name: used as the name component of all the generated libraries. + src: File contains the composite ops. + deps: Libraries the 'src' depends on. + tags: + test: + """ + if not src.endswith(".py") or name == src[:-3]: + fail("'src' %s conflicts with op Python wrapper. Rename it to be different from 'name'." 
% src) + + gen_op_lib_exec = src[:-3] # Strip off the .py + py_binary( + name = gen_op_lib_exec, + srcs = [src], + srcs_version = "PY2AND3", + python_version = "PY3", + deps = [ + "//tensorflow/compiler/mlir/tfr:op_reg_gen", + "//tensorflow/compiler/mlir/tfr:tfr_gen", + "//tensorflow/compiler/mlir/tfr:composite", + ] + deps, + ) + + registed_op = "registed_" + name + native.genrule( + name = registed_op, + srcs = [], + outs = [name + ".inc.cc"], + cmd = "$(location %s) --output=$@ --gen_register_op=true" % gen_op_lib_exec, + exec_tools = [":" + gen_op_lib_exec], + tags = tags, + ) + + native.cc_library( + name = name + "_cc", + testonly = test, + srcs = [":" + registed_op], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], + alwayslink = 1, + ) + + tf_custom_op_library( + name = name + ".so", + srcs = [":" + registed_op], + ) + + tf_gen_op_wrapper_py( + name = "gen_" + name, + out = "gen_" + name + ".py", + deps = [ + ":%s_cc" % name, + ], + ) + + tf_custom_op_py_library( + name = name, + dso = [":%s.so" % name], + kernels = [":%s_cc" % name], + srcs_version = "PY2AND3", + deps = [ + ":gen_%s" % name, + ], + ) + + # Link the register op and rebuild the binary + gen_tfr_lib_exec = gen_op_lib_exec + "_with_op_library" + py_binary( + name = gen_tfr_lib_exec, + main = src, + srcs = [src], + srcs_version = "PY2AND3", + python_version = "PY3", + deps = [ + "//tensorflow/compiler/mlir/tfr:op_reg_gen", + "//tensorflow/compiler/mlir/tfr:tfr_gen", + "//tensorflow/compiler/mlir/tfr:composite", + ":%s" % name, + ] + deps, + ) + + native.genrule( + name = name + "_mlir", + srcs = [], + outs = [name + ".mlir"], + cmd = "$(location %s) --output=$@ --gen_register_op=false" % gen_tfr_lib_exec, + exec_tools = [":" + gen_tfr_lib_exec], + tags = tags, + ) + + native.py_library( + name = name + "_py", + srcs = [src], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/compiler/mlir/tfr:op_reg_gen", + "//tensorflow/compiler/mlir/tfr:tfr_gen", + "//tensorflow/compiler/mlir/tfr:composite", + ] + deps, + ) diff --git a/tensorflow/compiler/mlir/tfr/define_op_template.py b/tensorflow/compiler/mlir/tfr/define_op_template.py new file mode 100644 index 00000000000..c0db2981d2d --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/define_op_template.py @@ -0,0 +1,64 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
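The gen_op_libraries macro above fans a single ops-definition file out into several targets: a genrule that emits the <name>.inc.cc op registration, a <name>_cc kernel library plus <name>.so custom-op library, a gen_<name> Python wrapper, and a <name>.mlir genrule that holds the TFR definitions. The sketch below shows how such a target is typically consumed from Python; it assumes a hypothetical gen_op_libraries(name = "my_ops", src = "my_ops_defs.py") target, and the module and path names are illustrative only (the same loading pattern appears in the tests later in this change).

    # Consuming a hypothetical gen_op_libraries(name = "my_ops", ...) target.
    import os

    from tensorflow.python.framework import load_library

    import gen_my_ops  # generated op wrappers (gen_<name>.py), illustrative name

    # The custom-op kernel library sits next to the generated wrapper module.
    _lib_dir = os.path.dirname(gen_my_ops.__file__)
    _lib_name = os.path.basename(gen_my_ops.__file__)[4:].replace('.py', '.so')
    load_library.load_op_library(os.path.join(_lib_dir, _lib_name))

    # Point the TFR passes at the directory containing my_ops.mlir so the
    # composite ops can be decomposed when no kernel handles them directly.
    os.environ['TF_MLIR_TFR_LIB_DIR'] = 'path/to/my_ops_mlir'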
+"""A template to define composite ops.""" + +# pylint: disable=g-direct-tensorflow-import + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys + +from tensorflow.compiler.mlir.tfr.python.composite import Composite +from tensorflow.compiler.mlir.tfr.python.op_reg_gen import gen_register_op +from tensorflow.compiler.mlir.tfr.python.tfr_gen import tfr_gen_from_module +from tensorflow.python.platform import app +from tensorflow.python.platform import flags + +FLAGS = flags.FLAGS + +flags.DEFINE_string( + 'output', None, + 'Path to write the genereated register op file and MLIR file.') + +flags.DEFINE_bool('gen_register_op', True, + 'Generate register op cc file or tfr mlir file.') + +flags.mark_flag_as_required('output') + + +@Composite('TestRandom', derived_attrs=['T: numbertype'], outputs=['o: T']) +def _composite_random_op(): + pass + + +def main(_): + if FLAGS.gen_register_op: + assert FLAGS.output.endswith('.cc') + generated_code = gen_register_op(sys.modules[__name__], '_composite_') + else: + assert FLAGS.output.endswith('.mlir') + generated_code = tfr_gen_from_module(sys.modules[__name__], '_composite_') + + dirname = os.path.dirname(FLAGS.output) + if not os.path.exists(dirname): + os.makedirs(dirname) + with open(FLAGS.output, 'w') as f: + f.write(generated_code) + + +if __name__ == '__main__': + app.run(main=main) diff --git a/tensorflow/compiler/mlir/tfr/examples/mnist/BUILD b/tensorflow/compiler/mlir/tfr/examples/mnist/BUILD new file mode 100644 index 00000000000..eeaee926c87 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/examples/mnist/BUILD @@ -0,0 +1,60 @@ +load("//tensorflow:tensorflow.bzl", "py_binary") +load("//tensorflow:tensorflow.bzl", "tf_py_test") +load("//tensorflow/compiler/mlir/tfr:build_defs.bzl", "gen_op_libraries") + +package( + default_visibility = [ + ":friends", + ], + licenses = ["notice"], # Apache 2.0 +) + +package_group( + name = "friends", + includes = ["//third_party/mlir:subpackages"], + packages = [ + "//tensorflow/compiler/mlir/tfr/...", + ], +) + +gen_op_libraries( + name = "mnist_ops", + src = "ops_defs.py", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +tf_py_test( + name = "mnist_ops_test", + size = "small", + srcs = ["mnist_ops_test.py"], + data = [":mnist_ops_mlir"], + python_version = "PY3", + srcs_version = "PY2AND3", + tags = [ + "no_pip", + "no_windows", # TODO(b/170752141) + "nomac", # TODO(b/170752141) + ], + deps = [ + ":mnist_ops", + ":mnist_ops_py", + "//tensorflow:tensorflow_py", + "//tensorflow/compiler/mlir/tfr:test_utils", + ], +) + +py_binary( + name = "mnist_train", + srcs = ["mnist_train.py"], + data = [":mnist_ops_mlir"], + python_version = "PY3", + deps = [ + ":mnist_ops", + ":mnist_ops_py", + "//tensorflow:tensorflow_py", + "@absl_py//absl:app", + "@absl_py//absl/flags", + ], +) diff --git a/tensorflow/compiler/mlir/tfr/examples/mnist/mnist_ops_test.py b/tensorflow/compiler/mlir/tfr/examples/mnist/mnist_ops_test.py new file mode 100644 index 00000000000..d25b424279f --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/examples/mnist/mnist_ops_test.py @@ -0,0 +1,126 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for tensorflow.compiler.mlir.tfr.examples.mnist.ops_defs.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tensorflow as tf + +from tensorflow.compiler.mlir.tfr.examples.mnist import gen_mnist_ops +from tensorflow.compiler.mlir.tfr.examples.mnist import ops_defs +from tensorflow.compiler.mlir.tfr.python import test_utils +from tensorflow.python.framework import load_library +from tensorflow.python.platform import test + +_lib_dir = os.path.dirname(gen_mnist_ops.__file__) +_lib_name = os.path.basename(gen_mnist_ops.__file__)[4:].replace('.py', '.so') +load_library.load_op_library(os.path.join(_lib_dir, _lib_name)) + + +class MnistOpsDefsTest(test_utils.OpsDefsTest): + + def test_new_conv2d_relu(self): + input_ = tf.random.uniform([1, 4, 4, 1]) + filter_ = tf.random.uniform([2, 2, 1, 8]) + bias = tf.zeros([8]) + kwargs = { + 'input_': input_, + 'filter_': filter_, + 'bias': bias, + 'stride_w': 2, + 'stride_h': 2, + 'dilation_w': 1, + 'dilation_h': 1, + 'padding': 'SAME', + 'act': 'RELU' + } + + self._assertOpAndComposite([input_, filter_, bias], + tf.function(gen_mnist_ops.new_conv2d), + ops_defs._composite_conv_add_relu, kwargs) + + def test_new_conv2d_relu6(self): + input_ = tf.random.uniform([1, 4, 4, 1]) + filter_ = tf.random.uniform([2, 2, 1, 8]) + bias = tf.zeros([8]) + kwargs = { + 'input_': input_, + 'filter_': filter_, + 'bias': bias, + 'stride_w': 2, + 'stride_h': 2, + 'dilation_w': 1, + 'dilation_h': 1, + 'padding': 'SAME', + 'act': 'RELU6' + } + + self._assertOpAndComposite([input_, filter_, bias], + tf.function(gen_mnist_ops.new_conv2d), + ops_defs._composite_conv_add_relu, kwargs) + + def test_new_conv2d_tanh(self): + self.skipTest('Fix tanh gradients') + input_ = tf.random.uniform([1, 4, 4, 1]) + filter_ = tf.random.uniform([2, 2, 1, 8]) + bias = tf.zeros([8]) + kwargs = { + 'input_': input_, + 'filter_': filter_, + 'bias': bias, + 'stride_w': 2, + 'stride_h': 2, + 'dilation_w': 1, + 'dilation_h': 1, + 'padding': 'SAME', + 'act': 'TANH' + } + + self._assertOpAndComposite([input_, filter_, bias], + tf.function(gen_mnist_ops.new_conv2d), + ops_defs._composite_conv_add_relu, kwargs) + + def test_new_fully_connected(self): + input_ = tf.random.uniform([2, 4]) + filter_ = tf.random.uniform([3, 4]) + bias = tf.zeros([3]) + kwargs = {'input_': input_, 'filter_': filter_, 'bias': bias, 'act': 'RELU'} + + self._assertOpAndComposite([input_, filter_, bias], + tf.function(gen_mnist_ops.new_fully_connected), + ops_defs._composite_fully_connected, kwargs) + + def test_new_max_pool(self): + input_ = tf.random.uniform([8, 4, 4, 1]) + kwargs = { + 'input_': input_, + 'stride_w': 2, + 'stride_h': 2, + 'filter_width': 1, + 'filter_height': 1, + 'padding': 'SAME', + } + + self._assertOpAndComposite([input_], + tf.function(gen_mnist_ops.new_max_pool), + ops_defs._composite_max_pool, kwargs) + + +if __name__ == '__main__': + os.environ[ + 'TF_MLIR_TFR_LIB_DIR'] = 'tensorflow/compiler/mlir/tfr/examples/mnist' + test.main() diff --git a/tensorflow/compiler/mlir/tfr/examples/mnist/mnist_train.py 
b/tensorflow/compiler/mlir/tfr/examples/mnist/mnist_train.py new file mode 100644 index 00000000000..a4adcf86d5b --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/examples/mnist/mnist_train.py @@ -0,0 +1,179 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MNIST model float training script with TensorFlow graph execution.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +from absl import app +from absl import flags + +import tensorflow as tf +import tensorflow_datasets as tfds +from tensorflow.compiler.mlir.tfr.examples.mnist import gen_mnist_ops +from tensorflow.compiler.mlir.tfr.examples.mnist import ops_defs # pylint: disable=unused-import +from tensorflow.python.framework import load_library + +flags.DEFINE_integer('train_steps', 200, 'Number of steps in training.') + +_lib_dir = os.path.dirname(gen_mnist_ops.__file__) +_lib_name = os.path.basename(gen_mnist_ops.__file__)[4:].replace('.py', '.so') +load_library.load_op_library(os.path.join(_lib_dir, _lib_name)) + +# MNIST dataset parameters. +num_classes = 10 # total classes (0-9 digits). +num_features = 784 # data features (img shape: 28*28). +num_channels = 1 + +# Training parameters. +learning_rate = 0.01 +display_step = 10 +batch_size = 128 + +# Network parameters. +n_hidden_1 = 32 # 1st conv layer number of neurons. +n_hidden_2 = 64 # 2nd conv layer number of neurons. +n_hidden_3 = 1024 # 1st fully connected layer of neurons. +flatten_size = num_features // 16 * n_hidden_2 + +seed = 66478 + +weights = { + 'f1': + tf.Variable( + tf.random.truncated_normal([5, 5, num_channels, n_hidden_1], + stddev=0.1, + seed=seed)), + 'f2': + tf.Variable( + tf.random.truncated_normal([5, 5, n_hidden_1, n_hidden_2], + stddev=0.1, + seed=seed)), + 'f3': + tf.Variable( + tf.random.truncated_normal([n_hidden_3, flatten_size], + stddev=0.1, + seed=seed)), + 'f4': + tf.Variable( + tf.random.truncated_normal([num_classes, n_hidden_3], + stddev=0.1, + seed=seed)), +} + +biases = { + 'b1': tf.Variable(tf.zeros([n_hidden_1])), + 'b2': tf.Variable(tf.zeros([n_hidden_2])), + 'b3': tf.Variable(tf.zeros([n_hidden_3])), + 'b4': tf.Variable(tf.zeros([num_classes])), +} + + +class FloatModel(tf.Module): + """Float inference for mnist model.""" + + @tf.function + def __call__(self, data): + """The Model definition.""" + x = tf.reshape(data, [-1, 28, 28, 1]) + + # 2D convolution, with 'SAME' padding (i.e. the output feature map has + # the same size as the input). + + # NOTE: The data/x/input is always specified in floating point precision. + # output shape: [-1, 28, 28, 32] + conv1 = gen_mnist_ops.new_conv2d(x, weights['f1'], biases['b1'], 1, 1, 1, 1, + 'SAME', 'RELU') + + # Max pooling. The kernel size spec {ksize} also follows the layout of + # the data. Here we have a pooling window of 2, and a stride of 2. 
+ # output shape: [-1, 14, 14, 32] + max_pool1 = gen_mnist_ops.new_max_pool(conv1, 2, 2, 2, 2, 'SAME') + + # output shape: [-1, 14, 14, 64] + conv2 = gen_mnist_ops.new_conv2d(max_pool1, weights['f2'], biases['b2'], 1, + 1, 1, 1, 'SAME', 'RELU') + + # output shape: [-1, 7, 7, 64] + max_pool2 = gen_mnist_ops.new_max_pool(conv2, 2, 2, 2, 2, 'SAME') + + # Reshape the feature map cuboid into a 2D matrix to feed it to the + # fully connected layers. + # output shape: [-1, 7*7*64] + reshape = tf.reshape(max_pool2, [-1, flatten_size]) + + # output shape: [-1, 1024] + fc1 = gen_mnist_ops.new_fully_connected(reshape, weights['f3'], + biases['b3'], 'RELU') + # output shape: [-1, 10] + return gen_mnist_ops.new_fully_connected(fc1, weights['f4'], biases['b4']) + + +def grad(model, inputs, labels, trainable_variables): + with tf.GradientTape() as tape: + logits = model(inputs) + loss_value = tf.reduce_mean( + tf.nn.softmax_cross_entropy_with_logits(labels, logits)) + grads = tape.gradient(loss_value, trainable_variables) + correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1)) + accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) + return accuracy, loss_value, grads + + +def training_step(model, inputs, labels, optimizer, step): + trainable_variables = list(weights.values()) + list(biases.values()) + accuracy, loss_value, grads = grad(model, inputs, labels, trainable_variables) + if step % display_step == 0: + print('Step %d:' % step) + print(' Loss = %f' % loss_value) + print(' Batch accuracy: %f' % accuracy) + optimizer.apply_gradients(zip(grads, trainable_variables)) + + +def get_next_batch(iter_): + features = next(iter_) + images, labels = features['image'], features['label'] + return (mnist_preprocess(images), tf.one_hot(labels, num_classes)) + + +def mnist_preprocess(x): + x_float = tf.cast(x, tf.float32) + return x_float / 255.0 + + +def train(model, dataset, optimizer): + iter_ = iter(dataset) + for step in range(flags.FLAGS.train_steps): + inputs, labels = get_next_batch(iter_) + training_step(model, inputs, labels, optimizer, step) + + +def main(_): + # TODO(fengliuai): put this in some automatically generated code. + os.environ[ + 'TF_MLIR_TFR_LIB_DIR'] = 'tensorflow/compiler/mlir/tfr/examples/mnist' + # Create an mnist float model with the specified float state. + model = FloatModel() + optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate) + + ds_train = tfds.load('mnist', split='train', shuffle_files=True) + ds_train = ds_train.shuffle(1024).batch(batch_size).prefetch(64) + + train(model, ds_train, optimizer) + + +if __name__ == '__main__': + app.run(main) diff --git a/tensorflow/compiler/mlir/tfr/examples/mnist/ops_defs.py b/tensorflow/compiler/mlir/tfr/examples/mnist/ops_defs.py new file mode 100644 index 00000000000..0cf4678892e --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/examples/mnist/ops_defs.py @@ -0,0 +1,217 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
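The layer-size constants in mnist_train.py are worth a quick check: two stride-2 max pools shrink the 28x28 input to 7x7, so the flattened feature map has 7 * 7 * 64 = 3136 elements, which is exactly what num_features // 16 * n_hidden_2 computes. A small sanity check of that arithmetic, using the constants defined above:

    # Sanity check of the shape arithmetic used by FloatModel in mnist_train.py.
    num_features = 784  # 28 * 28 input pixels
    n_hidden_2 = 64     # channels after the second conv layer

    pooled_side = 28 // 2 // 2                      # two stride-2 max pools: 28 -> 14 -> 7
    flatten_size = num_features // 16 * n_hidden_2  # as defined in mnist_train.py

    assert pooled_side == 7
    assert flatten_size == pooled_side * pooled_side * n_hidden_2 == 3136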
+"""Defines all the new composite ops used in the mnist example.""" + +# pylint: disable=g-direct-tensorflow-import +# pylint: disable=missing-function-docstring + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys + +import tensorflow as tf + +from tensorflow.compiler.mlir.tfr.python import composite +from tensorflow.compiler.mlir.tfr.python.op_reg_gen import gen_register_op +from tensorflow.compiler.mlir.tfr.python.tfr_gen import tfr_gen_from_module +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import gen_nn_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import app +from tensorflow.python.platform import flags + +Composite = composite.Composite +FLAGS = flags.FLAGS + +flags.DEFINE_string( + 'output', None, + 'Path to write the genereated register op file and MLIR file.') + +flags.DEFINE_bool('gen_register_op', True, + 'Generate register op cc file or tfr mlir file.') + + +@Composite( + 'NewConv2D', + inputs=['input_: T', 'filter_: T', 'bias: T'], + attrs=[ + 'stride_w: int', 'stride_h: int', 'dilation_w: int', 'dilation_h: int', + 'padding: {"SAME", "VALID"}', 'act: {"", "RELU", "RELU6", "TANH"} = ""' + ], + derived_attrs=['T: {float, int8}'], + outputs=['o: T']) +def _composite_conv_add_relu(input_, filter_, bias, stride_w, stride_h, + dilation_w, dilation_h, padding, act): + res = tf.raw_ops.Conv2D( + input=input_, + filter=filter_, + strides=[1, stride_w, stride_h, 1], + dilations=[1, dilation_w, dilation_h, 1], + padding=padding) + res = tf.raw_ops.Add(x=res, y=bias) + if act == 'RELU': + return tf.raw_ops.Relu(features=res) + elif act == 'RELU6': + return tf.raw_ops.Relu6(features=res) + elif act == 'TANH': + return tf.raw_ops.Tanh(x=res) + else: + return res + + +@tf.RegisterGradient('NewConv2D') +def _conv_add_relu_grad(op, grad): + act = op.get_attr('act') + y = op.outputs[0] + if act == 'RELU': + grad = gen_nn_ops.relu_grad(grad, y) + elif act == 'RELU6': + grad = gen_nn_ops.relu6_grad(grad, y) + elif act == 'TANH': + y = math_ops.conj(y) + grad = gen_math_ops.tanh_grad(y, grad) + + broadcast_shape = tf.shape(y) + input_value_shape = tf.shape(op.inputs[2]) + _, reduction_axes = tf.raw_ops.BroadcastGradientArgs( + s0=broadcast_shape, s1=input_value_shape) + updates_grad_reshaped = tf.reduce_sum( + grad, axis=reduction_axes, keepdims=True) + bias_grad = tf.reshape(updates_grad_reshaped, input_value_shape) + + dilations = [1, op.get_attr('dilation_w'), op.get_attr('dilation_h'), 1] + strides = [1, op.get_attr('stride_w'), op.get_attr('stride_h'), 1] + padding = op.get_attr('padding') + shape_0, shape_1 = tf.shape_n([op.inputs[0], op.inputs[1]]) + return [ + tf.compat.v1.nn.conv2d_backprop_input( + shape_0, + op.inputs[1], + grad, + strides=strides, + padding=padding, + dilations=dilations, + data_format='NHWC'), + tf.compat.v1.nn.conv2d_backprop_filter( + op.inputs[0], + shape_1, + grad, + strides=strides, + padding=padding, + dilations=dilations, + data_format='NHWC'), bias_grad + ] + + +@Composite( + 'NewFullyConnected', + inputs=['input_: T', 'filter_: T', 'bias: T'], + attrs=['act: {"", "RELU", "RELU6", "TANH"} = ""'], + derived_attrs=['T: {float, int8}'], + outputs=['o: T']) +def _composite_fully_connected(input_, filter_, bias, act): + res = tf.raw_ops.MatMul( + a=input_, b=filter_, transpose_a=False, transpose_b=True) + res = tf.raw_ops.Add(x=res, y=bias) + if act == 'RELU': + return tf.raw_ops.Relu(features=res) + elif act == 
'RELU6': + return tf.raw_ops.Relu6(features=res) + elif act == 'TANH': + return tf.raw_ops.Tanh(x=res) + else: + return res + + +@tf.RegisterGradient('NewFullyConnected') +def _fully_connected_grad(op, grad): + act = op.get_attr('act') + y = op.outputs[0] + if act == 'RELU': + grad = gen_nn_ops.relu_grad(grad, y) + elif act == 'RELU6': + grad = gen_nn_ops.relu6_grad(grad, y) + elif act == 'TANH': + y = math_ops.conj(y) + grad = gen_math_ops.tanh_grad(y, grad) + + broadcast_shape = tf.shape(y) + input_value_shape = tf.shape(op.inputs[2]) + _, reduction_axes = tf.raw_ops.BroadcastGradientArgs( + s0=broadcast_shape, s1=input_value_shape) + updates_grad_reshaped = tf.reduce_sum( + grad, axis=reduction_axes, keepdims=True) + bias_grad = tf.reshape(updates_grad_reshaped, input_value_shape) + + a = math_ops.conj(op.inputs[0]) + b = math_ops.conj(op.inputs[1]) + grad_a = gen_math_ops.mat_mul(grad, b) + grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True) + return [grad_a, grad_b, bias_grad] + + +@Composite( + 'NewMaxPool', + inputs=['input_: T'], + attrs=[ + 'stride_w: int', 'stride_h: int', 'filter_width: int', + 'filter_height: int', 'padding: {"SAME", "VALID"}' + ], + derived_attrs=['T: {float, int8}'], + outputs=['o: T']) +def _composite_max_pool(input_, stride_w, stride_h, filter_width, filter_height, + padding): + ksize = [1, filter_width, filter_height, 1] + strides = [1, stride_w, stride_h, 1] + return tf.raw_ops.MaxPool( + input=input_, ksize=ksize, strides=strides, padding=padding) + + +@tf.RegisterGradient('NewMaxPool') +def _max_pool_grad(op, grad): + filter_width = op.get_attr('filter_width') + filter_height = op.get_attr('filter_height') + stride_w = op.get_attr('stride_w') + stride_h = op.get_attr('stride_h') + padding = op.get_attr('padding') + return tf.raw_ops.MaxPoolGrad( + orig_input=op.inputs[0], + orig_output=op.outputs[0], + grad=grad, + ksize=[1, filter_width, filter_height, 1], + strides=[1, stride_w, stride_h, 1], + padding=padding, + data_format='NHWC') + + +def main(_): + if FLAGS.gen_register_op: + assert FLAGS.output.endswith('.cc') + generated_code = gen_register_op(sys.modules[__name__], '_composite_') + else: + assert FLAGS.output.endswith('.mlir') + generated_code = tfr_gen_from_module(sys.modules[__name__], '_composite_',) + + dirname = os.path.dirname(FLAGS.output) + if not os.path.exists(dirname): + os.makedirs(dirname) + with open(FLAGS.output, 'w') as f: + f.write(generated_code) + + +if __name__ == '__main__': + app.run(main=main) diff --git a/tensorflow/compiler/mlir/tfr/examples/pad/BUILD b/tensorflow/compiler/mlir/tfr/examples/pad/BUILD new file mode 100644 index 00000000000..ef08caff939 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/examples/pad/BUILD @@ -0,0 +1,45 @@ +load("//tensorflow:tensorflow.bzl", "tf_py_test") +load("//tensorflow/compiler/mlir/tfr:build_defs.bzl", "gen_op_libraries") + +package( + default_visibility = [ + ":friends", + ], + licenses = ["notice"], # Apache 2.0 +) + +package_group( + name = "friends", + includes = ["//third_party/mlir:subpackages"], + packages = [ + "//tensorflow/compiler/mlir/tfr/...", + ], +) + +gen_op_libraries( + name = "pad_ops", + src = "ops_defs.py", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +tf_py_test( + name = "pad_ops_test", + size = "small", + srcs = ["pad_ops_test.py"], + data = [":pad_ops_mlir"], + python_version = "PY3", + srcs_version = "PY2AND3", + tags = [ + "no_pip", + "no_windows", # TODO(b/170752141) + "nomac", # TODO(b/170752141) + ], + deps = [ + ":pad_ops", + ":pad_ops_py", 
+ "//tensorflow:tensorflow_py", + "//tensorflow/compiler/mlir/tfr:test_utils", + ], +) diff --git a/tensorflow/compiler/mlir/tfr/examples/pad/ops_defs.py b/tensorflow/compiler/mlir/tfr/examples/pad/ops_defs.py new file mode 100644 index 00000000000..4b072a58f08 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/examples/pad/ops_defs.py @@ -0,0 +1,168 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Defines the mirror pad and mirror pad grad.""" + +# pylint: disable=g-direct-tensorflow-import +# pylint: disable=missing-function-docstring + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys + +import tensorflow as tf + +from tensorflow.compiler.mlir.tfr.python import composite +from tensorflow.compiler.mlir.tfr.python.op_reg_gen import gen_register_op +from tensorflow.compiler.mlir.tfr.python.tfr_gen import tfr_gen_from_module +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.platform import app +from tensorflow.python.platform import flags + +Composite = composite.Composite +FLAGS = flags.FLAGS + +flags.DEFINE_string( + 'output', None, + 'Path to write the genereated register op file and MLIR file.') + +flags.DEFINE_bool('gen_register_op', True, + 'Generate register op cc file or tfr mlir file.') + + +@Composite( + 'NewMirrorPad', + inputs=['input_: T', 'paddings: Tpaddings'], + attrs=['mode: {"REFLECT", "SYMMETRIC"}'], + derived_attrs=['T: type', 'Tpaddings: {int32, int64} = DT_INT32'], + outputs=['output: T']) +def _composite_mirror_pad(input_, paddings, mode): + shape = input_.shape.as_list() + for i in range(len(shape)): + rdims = tf.raw_ops.OneHot( + indices=i, depth=len(shape), on_value=True, off_value=False, axis=-1) + rarray = tf.raw_ops.Reverse(tensor=input_, dims=rdims) + + left_padding_size = tf.raw_ops.GatherNd(params=paddings, indices=[i, 0]) + right_padding_size = tf.raw_ops.GatherNd(params=paddings, indices=[i, 1]) + + if mode == 'REFLECT': + left_padding, _ = tf.raw_ops.SplitV( + value=rarray, + size_splits=[left_padding_size, -1], + axis=i, + num_split=2) + _, right_padding = tf.raw_ops.SplitV( + value=rarray, + size_splits=[-1, right_padding_size], + axis=i, + num_split=2) + else: + _, left_padding = tf.raw_ops.SplitV( + value=rarray, + size_splits=[-1, left_padding_size], + axis=i, + num_split=2) + right_padding, _ = tf.raw_ops.SplitV( + value=rarray, + size_splits=[right_padding_size, -1], + axis=i, + num_split=2) + + input_ = tf.raw_ops.Concat( + concat_dim=i, values=[left_padding, input_, right_padding]) + return input_ + + +@tf.RegisterGradient('NewMirrorPad') +def _mirror_pad_grad(op, grad): + mode = op.get_attr('mode') + return [gen_array_ops.mirror_pad_grad(grad, op.inputs[1], mode=mode), None] + + +@Composite( + 'NewMirrorPadGrad', + inputs=['input_: T', 'paddings: Tpaddings'], + attrs=['mode: {"REFLECT", "SYMMETRIC"}'], + derived_attrs=['T: type', 'Tpaddings: {int32, int64} = DT_INT32'], + 
outputs=['output: T']) +def _composite_mirror_pad_grad(input_, paddings, mode): + shape = input_.shape.as_list() + for i in range(len(shape)): + rdims = tf.raw_ops.OneHot( + indices=i, depth=len(shape), on_value=True, off_value=False, axis=-1) + left_padding_size = tf.raw_ops.GatherNd(params=paddings, indices=[i, 0]) + right_padding_size = tf.raw_ops.GatherNd(params=paddings, indices=[i, 1]) + + left_padding, core, right_padding = tf.raw_ops.SplitV( + value=input_, + size_splits=[left_padding_size, -1, right_padding_size], + axis=i, + num_split=3) + reversed_left_padding = tf.raw_ops.Reverse(tensor=left_padding, dims=rdims) + reversed_right_padding = tf.raw_ops.Reverse( + tensor=right_padding, dims=rdims) + zero_like = tf.raw_ops.ZerosLike(x=core) + left_offset, _ = tf.raw_ops.SplitV( + value=zero_like, + size_splits=[-1, left_padding_size], + axis=i, + num_split=2) + right_offset, _ = tf.raw_ops.SplitV( + value=zero_like, + size_splits=[-1, right_padding_size], + axis=i, + num_split=2) + + if mode == 'REFLECT': + from_left_padding = tf.raw_ops.Concat( + concat_dim=i, values=[left_offset, reversed_left_padding]) + from_right_padding = tf.raw_ops.Concat( + concat_dim=i, values=[reversed_right_padding, right_offset]) + else: + from_left_padding = tf.raw_ops.Concat( + concat_dim=i, values=[reversed_left_padding, left_offset]) + from_right_padding = tf.raw_ops.Concat( + concat_dim=i, values=[right_offset, reversed_right_padding]) + input_ = tf.raw_ops.AddN( + inputs=[from_left_padding, core, from_right_padding]) + + return input_ + + +@tf.RegisterGradient('NewMirrorPadGrad') +def _mirror_pad_grad_grad(op, grad): + mode = op.get_attr('mode') + return [gen_array_ops.mirror_pad(grad, op.inputs[1], mode=mode), None] + + +def main(_): + if FLAGS.gen_register_op: + assert FLAGS.output.endswith('.cc') + generated_code = gen_register_op(sys.modules[__name__], '_composite_') + else: + assert FLAGS.output.endswith('.mlir') + generated_code = tfr_gen_from_module(sys.modules[__name__], '_composite_') + + dirname = os.path.dirname(FLAGS.output) + if not os.path.exists(dirname): + os.makedirs(dirname) + with open(FLAGS.output, 'w') as f: + f.write(generated_code) + + +if __name__ == '__main__': + app.run(main=main) diff --git a/tensorflow/compiler/mlir/tfr/examples/pad/pad_ops_test.py b/tensorflow/compiler/mlir/tfr/examples/pad/pad_ops_test.py new file mode 100644 index 00000000000..11f6e0acbf2 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/examples/pad/pad_ops_test.py @@ -0,0 +1,96 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
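The NewMirrorPad composition above rebuilds, axis by axis, the behaviour of the builtin MirrorPad kernel: in REFLECT mode the edge value is excluded from the padding, in SYMMETRIC mode it is repeated. A small worked example against the builtin op (expected values written out by hand) shows the two modes the composition has to reproduce:

    import tensorflow as tf

    # One unit of padding on each side of a 1-D tensor.
    x = tf.constant([1., 2., 3.])
    paddings = tf.constant([[1, 1]])

    reflect = tf.raw_ops.MirrorPad(input=x, paddings=paddings, mode='REFLECT')
    symmetric = tf.raw_ops.MirrorPad(input=x, paddings=paddings, mode='SYMMETRIC')

    print(reflect.numpy())    # [2. 1. 2. 3. 2.] -- edge values are not repeated
    print(symmetric.numpy())  # [1. 1. 2. 3. 3.] -- edge values are repeated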
+"""Tests for tensorflow.compiler.mlir.tfr.examples.pad.ops_defs.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +from absl.testing import parameterized +import tensorflow as tf + +from tensorflow.compiler.mlir.tfr.examples.pad import gen_pad_ops +from tensorflow.compiler.mlir.tfr.examples.pad import ops_defs +from tensorflow.compiler.mlir.tfr.python import test_utils +from tensorflow.python.framework import load_library +from tensorflow.python.platform import test + +_lib_dir = os.path.dirname(gen_pad_ops.__file__) +_lib_name = os.path.basename(gen_pad_ops.__file__)[4:].replace('.py', '.so') +load_library.load_op_library(os.path.join(_lib_dir, _lib_name)) + + +class PadOpsDefsTest(test_utils.OpsDefsTest, parameterized.TestCase): + + @parameterized.named_parameters(('ReflectMode', 'REFLECT'), + ('SymmetricMode', 'SYMMETRIC')) + def test_mirror_pad(self, mode): + input_ = tf.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.float32) + paddings = tf.constant([[ + 1, + 1, + ], [2, 2]]) + kwargs = { + 'input': input_, + 'paddings': paddings, + 'mode': mode, + } + kwargs_ = { + 'input_': input_, + 'paddings': paddings, + 'mode': mode, + } + # Make sure the composition python function is correct + self._assertOpAndComposite([input_], tf.raw_ops.MirrorPad, + ops_defs._composite_mirror_pad, kwargs_, kwargs) + # Make sure the translation and decomposition is correct + self._assertOpAndComposite([input_], + tf.function(gen_pad_ops.new_mirror_pad), + ops_defs._composite_mirror_pad, kwargs_) + + @parameterized.named_parameters(('ReflectMode', 'REFLECT'), + ('SymmetricMode', 'SYMMETRIC')) + def test_mirror_pad_grad(self, mode): + input_ = tf.constant([[2, 1, 1, 2, 3, 3, 2], [2, 1, 1, 2, 3, 3, 2], + [5, 4, 4, 5, 6, 6, 5], [5, 4, 4, 5, 6, 6, 5]], + dtype=tf.float32) + paddings = tf.constant([[ + 1, + 1, + ], [2, 2]]) + kwargs = { + 'input': input_, + 'paddings': paddings, + 'mode': mode, + } + kwargs_ = { + 'input_': input_, + 'paddings': paddings, + 'mode': mode, + } + # Make sure the composition python function is correct + self._assertOpAndComposite([input_], tf.raw_ops.MirrorPadGrad, + ops_defs._composite_mirror_pad_grad, kwargs_, + kwargs) + # Make sure the translation and decomposition is correct + self._assertOpAndComposite([input_], + tf.function(gen_pad_ops.new_mirror_pad_grad), + ops_defs._composite_mirror_pad_grad, kwargs_) + + +if __name__ == '__main__': + os.environ[ + 'TF_MLIR_TFR_LIB_DIR'] = 'tensorflow/compiler/mlir/tfr/examples/pad' + test.main() diff --git a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc new file mode 100644 index 00000000000..99890e9f621 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc @@ -0,0 +1,55 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h" + +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace tensorflow { +namespace tfr { + +bool GraphDecomposePass::IsEnabled(const ConfigProto& config_proto) const { + const char* tfr_lib_env_val = getenv(std::string(kTFRLibEnv).c_str()); + return tfr_lib_env_val != nullptr; +} + +Status GraphDecomposePass::Run(const ConfigProto& config_proto, + mlir::ModuleOp module) { + if (!IsEnabled(config_proto)) { + LOG_FIRST_N(INFO, 1) << "Skipping Graph Decomposition Pass, decomposition " + "library was not found"; + return Status::OK(); + } + + LOG_FIRST_N(INFO, 1) << "Run Graph Decomposition Passes"; + + TF_RETURN_IF_ERROR(DecomposeGraph(module)); + + LOG_FIRST_N(INFO, 1) << "Finish Graph Decomposition Passes"; + + return Status::OK(); +} + +namespace { +constexpr int kMlirGraphDecomposePassPriority = -1; + +static mlir_pass_registration::MlirOptimizationPassRegistration + register_mlir_graph_decompose_pass(kMlirGraphDecomposePassPriority, + std::make_unique<GraphDecomposePass>()); +} // namespace + +} // namespace tfr +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h new file mode 100644 index 00000000000..dd93e99f04b --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_GRAPH_DECOMPOSE_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_GRAPH_DECOMPOSE_PASS_H_ + +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h" +#include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace tensorflow { +namespace tfr { + +// An optimization pass that decomposes the composite ops in a module according +// to the decomposition library. Currently the decomposition library is loaded +// each time the pass runs. A special environment variable is set to locate the +// decomposition library. +class GraphDecomposePass : public MlirOptimizationPass { + public: + llvm::StringRef name() const override { return "tfr"; } + + // Whether to run this pass. If this is enabled, the GraphDef will be imported + // to MLIR even if no TF composition file is found. + bool IsEnabled(const ConfigProto& config_proto) const override; + + // This should be used as a thin mapper around mlir::ModulePass::runOnModule + // API integrated with the Tensorflow runtime.
+ Status Run(const ConfigProto& config_proto, mlir::ModuleOp module) override; +}; + +} // namespace tfr +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_GRAPH_DECOMPOSE_PASS_H_ diff --git a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_test.py b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_test.py new file mode 100644 index 00000000000..d573b8e7195 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_test.py @@ -0,0 +1,83 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for tensorflow.compiler.mlir.tfr.integrattion.graph_decompose.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.compiler.mlir.tfr.resources import gen_composite_ops +from tensorflow.python.eager import def_function +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import load_library +from tensorflow.python.framework import ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import test + +_lib_dir = os.path.dirname(gen_composite_ops.__file__) +_lib_name = os.path.basename(gen_composite_ops.__file__)[4:].replace( + '.py', '.so') +load_library.load_op_library(os.path.join(_lib_dir, _lib_name)) + + +class GraphDecomposeTest(test.TestCase): + + def testAddN(self): + add = def_function.function(gen_composite_ops.my_add_n) + t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t3 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + sq1 = add([t1]) + sq2 = add([t1, t2]) + sq3 = add([t1, t2, t3]) + self.assertAllEqual(sq1.numpy().reshape(-1), [1, 2, 3, 4]) + self.assertAllEqual(sq2.numpy().reshape(-1), [2, 4, 6, 8]) + self.assertAllEqual(sq3.numpy().reshape(-1), [3, 6, 9, 12]) + + def testBiasedDense(self): + biased_dense = def_function.function(gen_composite_ops.my_biased_dense) + t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t3 = constant_op.constant([[-10.0, -10.0], [-10.0, -10.0]]) + sq = biased_dense(t1, t2, t3) + self.assertAllEqual(sq.numpy().reshape(-1), [-3, 0, 5, 12]) + + def testBiasedDenseRelu(self): + biased_dense = def_function.function(gen_composite_ops.my_biased_dense) + t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t3 = constant_op.constant([[-10.0, -10.0], [-10.0, -10.0]]) + sq = biased_dense(t1, t2, t3, act='relu') + self.assertAllEqual(sq.numpy().reshape(-1), [0, 0, 5, 12]) + + def testWithKnownKernel(self): + + @def_function.function + def biasd_dense_elu(x, y, z): + dot = gen_composite_ops.my_biased_dense(x, y, z) + return nn_ops.elu(dot) # with known kernel, should not expand. 
+ + t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t3 = constant_op.constant([[-10.0, -10.0], [-10.0, -10.0]]) + sq = biasd_dense_elu(t1, t2, t3) + self.assertAllClose(sq.numpy().reshape(-1), [-0.950213, 0, 5, 12]) + + +if __name__ == '__main__': + os.environ['TF_MLIR_TFR_LIB_DIR'] = 'tensorflow/compiler/mlir/tfr/resources' + ops.enable_eager_execution() + test.main() diff --git a/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.cc b/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.cc new file mode 100644 index 00000000000..61c4d1c8953 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.cc @@ -0,0 +1,69 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.h" + +#include <string> + +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace tensorflow { +namespace tfr { + +Status CompositeOpExpansion::Run(EagerOperation* orig_op, + std::unique_ptr<EagerOperation>* out_op) { + if (!IsEnabled()) return Status::OK(); + if (orig_op->Device() != kVariantDeviceNull) return Status::OK(); + + LOG_FIRST_N(INFO, 1) << "Run Node Expansion Passes"; + + // Get the FunctionDef and insert that into the context + const NodeDef& ndef = orig_op->MutableAttrs()->BuildNodeDef(); + auto& ctx = orig_op->EagerContext(); + Fprint128 cache_key = + orig_op->MutableAttrs()->CacheKey(orig_op->DeviceName()); + // Include soft placement policy in cache key since the placement strategy + // can change and thus affect which kernel is picked. + auto x = FingerprintCat64(cache_key.high64, cache_key.low64); + std::string fname = + absl::StrCat("_expanded_", ndef.name(), "_", std::to_string(x)); + if (!ctx.FindFunctionByName(fname)) { + TF_ASSIGN_OR_RETURN(auto func, ExpandNode(ndef, fname)); + TF_RETURN_IF_ERROR(ctx.AddFunctionDef(func)); + } + + // Rewrite the out_op to be the call op. This is essentially a deep copy of + // the orig_op, except for the op name. + auto* new_op = new EagerOperation(&ctx); + TF_RETURN_IF_ERROR( + new_op->Reset(fname.c_str(), orig_op->DeviceName().c_str())); + for (auto input : orig_op->GetInputs()) { + TF_RETURN_IF_ERROR(new_op->AddInput(input)); + } + new_op->MutableAttrs()->CopyAttributes(orig_op->Attrs()); + out_op->reset(new_op); + + LOG_FIRST_N(INFO, 1) + << "Finish Node Expansion Passes. 
Rewrite the op to call function: " + << fname; + + return Status::OK(); +} + +REGISTER_REWRITE(EagerOpRewriteRegistry::POST_PLACEMENT, CompositeOpExpansion); + +} // namespace tfr +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.h b/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.h new file mode 100644 index 00000000000..b1e4911b541 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.h @@ -0,0 +1,49 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_NODE_EXPANSION_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_NODE_EXPANSION_PASS_H_ + +#include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h" +#include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace tensorflow { +namespace tfr { + +// An eager op rewrite that decomposes composite ops according to the +// decomposition library. Currently the decomposition library is loaded +// each time the pass runs. A special environment variable is set to locate the +// decomposition library. +class CompositeOpExpansion : public EagerOpRewrite { + public: + CompositeOpExpansion(string name, string file, string line) + : EagerOpRewrite(name, file, line) {} + + Status Run(EagerOperation* orig_op, + std::unique_ptr<EagerOperation>* out_op) override; + + private: + // Whether to run this pass. If this is enabled, the NodeDef will be imported + // to MLIR even if no TF composition file is found. + bool IsEnabled() { + const char* tfr_lib_env_val = getenv(string(kTFRLibEnv).c_str()); + return tfr_lib_env_val != nullptr; + } +}; + +} // namespace tfr +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_NODE_EXPANSION_PASS_H_ diff --git a/tensorflow/compiler/mlir/tfr/integration/node_expansion_test.py b/tensorflow/compiler/mlir/tfr/integration/node_expansion_test.py new file mode 100644 index 00000000000..f99b52fe65a --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/integration/node_expansion_test.py @@ -0,0 +1,78 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
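CompositeOpExpansion memoizes its work: the replacement function name embeds a fingerprint of the eager op's cache key, so a composite op is expanded at most once per distinct attribute and device configuration, and later calls reuse the FunctionDef already registered on the context. The Python sketch below illustrates that naming-and-caching idea only; the hashing helper and the cache dictionary are stand-ins, not part of this change.

    import hashlib

    _expanded = {}  # function name -> expanded definition (stands in for the
                    # FunctionDefs registered on the EagerContext).


    def _expansion_name(op_name, attrs):
      # Stand-in for FingerprintCat64 over the op's cache key.
      digest = hashlib.sha256(repr(sorted(attrs.items())).encode()).hexdigest()
      return '_expanded_%s_%s' % (op_name, digest[:16])


    def expand_once(op_name, attrs, expand_fn):
      """Expands a composite op once per distinct attribute combination."""
      fname = _expansion_name(op_name, attrs)
      if fname not in _expanded:
        _expanded[fname] = expand_fn(op_name, attrs, fname)
      return fname, _expanded[fname]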
+"""Tests for tensorflow.compiler.mlir.tfr.integrattion.node_expansion.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.compiler.mlir.tfr.resources import gen_composite_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import load_library +from tensorflow.python.framework import ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import test + +_lib_dir = os.path.dirname(gen_composite_ops.__file__) +_lib_name = os.path.basename(gen_composite_ops.__file__)[4:].replace( + '.py', '.so') +load_library.load_op_library(os.path.join(_lib_dir, _lib_name)) + + +class NodeExpansionTest(test.TestCase): + + def testAddN(self): + t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t3 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + sq1 = gen_composite_ops.my_add_n([t1]) + sq2 = gen_composite_ops.my_add_n([t1, t2]) + sq3 = gen_composite_ops.my_add_n([t1, t2, t3]) + self.assertAllEqual(sq1.numpy().reshape(-1), [1, 2, 3, 4]) + self.assertAllEqual(sq2.numpy().reshape(-1), [2, 4, 6, 8]) + self.assertAllEqual(sq3.numpy().reshape(-1), [3, 6, 9, 12]) + + def testBiasedDense(self): + t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t3 = constant_op.constant([[-10.0, -10.0], [-10.0, -10.0]]) + sq = gen_composite_ops.my_biased_dense(t1, t2, t3) + self.assertAllEqual(sq.numpy().reshape(-1), [-3, 0, 5, 12]) + + def testBiasedDenseRelu(self): + t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t3 = constant_op.constant([[-10.0, -10.0], [-10.0, -10.0]]) + sq = gen_composite_ops.my_biased_dense(t1, t2, t3, act='relu') + self.assertAllEqual(sq.numpy().reshape(-1), [0, 0, 5, 12]) + + def testWithKnownKernel(self): + + def biasd_dense_elu(x, y, z): + dot = gen_composite_ops.my_biased_dense(x, y, z) + return nn_ops.elu(dot) # with known kernel, should not expand. + + t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t3 = constant_op.constant([[-10.0, -10.0], [-10.0, -10.0]]) + sq = biasd_dense_elu(t1, t2, t3) + self.assertAllClose(sq.numpy().reshape(-1), [-0.950213, 0, 5, 12]) + + +if __name__ == '__main__': + os.environ['TF_MLIR_TFR_LIB_DIR'] = 'tensorflow/compiler/mlir/tfr/resources' + ops.enable_eager_execution() + test.main() diff --git a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc new file mode 100644 index 00000000000..61e96548579 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc @@ -0,0 +1,222 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SMLoc.h" +#include "llvm/Support/SourceMgr.h" +#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Identifier.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Verifier.h" // from @llvm-project +#include "mlir/Parser.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" +#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h" +#include "tensorflow/compiler/mlir/tfr/passes/passes.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/util/env_var.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace tensorflow { +namespace tfr { + +const char* const kTFRLibEnv = "TF_MLIR_TFR_LIB_DIR"; + +StatusOr> TFRDecomposeContext::Get( + mlir::MLIRContext* mlir_ctx) { + Env* env = Env::Default(); + std::string tfr_lib_dir; + TF_RETURN_IF_ERROR(ReadStringFromEnvVar( + kTFRLibEnv, "tensorflow/compiler/mlir/tfr/resources", &tfr_lib_dir)); + string composite_mlir_dir = io::JoinPath(env->GetRunfilesDir(), tfr_lib_dir); + std::vector files; + TF_RETURN_IF_ERROR(env->GetChildren(composite_mlir_dir, &files)); + if (files.empty()) { + return errors::Internal(absl::StrCat( + "Failed to find the decomposition lib from path ", composite_mlir_dir)); + } + std::string tfr_raw_text; + for (const auto& file : files) { + string fullpath = io::JoinPath(composite_mlir_dir, file); + if (env->MatchPath(fullpath, io::JoinPath(composite_mlir_dir, "*.mlir"))) { + std::string text; + TF_RETURN_IF_ERROR(ReadFileToString(env, fullpath, &text)); + tfr_raw_text.append(text); + } + } + + auto ctx = TFRDecomposeContext::GetFromText(tfr_raw_text, mlir_ctx); + if (!ctx) { + return errors::Internal(absl::StrCat( + "Failed to load the imported decomposition lib: ", tfr_raw_text)); + } + return ctx; +} + +std::unique_ptr TFRDecomposeContext::GetFromText( + StringPiece tfr_raw_text, mlir::MLIRContext* mlir_ctx) { + mlir_ctx->allowUnregisteredDialects(/*allow=*/true); + // Load dialects involved in the conversion + mlir::DialectRegistry& registry = 
mlir_ctx->getDialectRegistry(); + // clang-format off + registry.insert(); + // clang-format on + + // Load the TFR functions in a mlir::ModuleOp + auto memory_buffer = llvm::MemoryBuffer::getMemBuffer( + llvm::StringRef(tfr_raw_text.data(), tfr_raw_text.size())); + llvm::SourceMgr source_mgr; + source_mgr.AddNewSourceBuffer(std::move(memory_buffer), llvm::SMLoc()); + mlir::OwningModuleRef module = mlir::parseSourceFile(source_mgr, mlir_ctx); + // The MLIRContext owns the module + auto module_op = module.release(); + + // Create the context + return absl::make_unique(module_op); +} + +StatusOr TFRDecomposeContext::ExpandNode(const NodeDef& node_def, + StringPiece func_name) { + const OpDef* op_def; + TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def)); + DataTypeVector input_dtys, output_dtys; + TF_RETURN_IF_ERROR(InputTypesForNode(node_def, *op_def, &input_dtys)); + TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, *op_def, &output_dtys)); + + mlir::MLIRContext* context = tfr_module_.getContext(); + llvm::SmallVector input_tys, output_tys; + mlir::Builder builder(context); + for (auto ty : input_dtys) { + mlir::Type elt_ty; + TF_RETURN_IF_ERROR(ConvertDataType(ty, builder, &elt_ty)); + mlir::TensorType mlir_ty = mlir::UnrankedTensorType::get(elt_ty); + input_tys.push_back(mlir_ty); + } + for (auto ty : output_dtys) { + mlir::Type elt_ty; + TF_RETURN_IF_ERROR(ConvertDataType(ty, builder, &elt_ty)); + mlir::TensorType mlir_ty = mlir::UnrankedTensorType::get(elt_ty); + output_tys.push_back(mlir_ty); + } + llvm::SmallVector attrs; + for (const auto& attr : node_def.attr()) { + TF_ASSIGN_OR_RETURN(auto mlir_attr, + ConvertAttributeValue(attr.second, &builder)); + attrs.push_back({mlir::Identifier::get(attr.first, context), mlir_attr}); + } + + mlir::Location loc = mlir::UnknownLoc::get(context); + mlir::ModuleOp module = mlir::ModuleOp::create(loc); + mlir::FunctionType func_type = + mlir::FunctionType::get(input_tys, output_tys, context); + llvm::StringRef func_name_str(func_name.data(), func_name.size()); + auto func = mlir::FuncOp::create(loc, func_name_str, func_type, {}); + module.push_back(func); + func.addEntryBlock(); + mlir::OpBuilder op_builder(func.getBody()); + + // Create the TF op + const std::string tf_op_full_name = absl::StrCat("tf.", node_def.op()); + mlir::OperationState op_state(loc, tf_op_full_name); + op_state.addOperands(func.getArguments()); + op_state.addTypes(output_tys); + op_state.addAttributes(attrs); + mlir::Operation* tf_op = op_builder.createOperation(op_state); + op_builder.create(loc, tf_op->getResults()); + + // Run the decompose passes on the module + TF_RETURN_IF_ERROR(DecomposeGraph(module)); + + // Export the result as a FunctionDef. + FunctionDef func_def; + TF_RETURN_IF_ERROR( + ConvertMlirFunctionToFunctionLibraryDef(func, export_confs_, &func_def)); + module.erase(); + return func_def; +} + +Status TFRDecomposeContext::DecomposeGraph(mlir::ModuleOp user_module) { + // Call the decompose passes by using the external symbol table. + if (failed(pm_.run(user_module))) { + return errors::Internal("Failed to run the decompose passes."); + } + return Status::OK(); +} + +// Constructor of the decompose context. +TFRDecomposeContext::TFRDecomposeContext(mlir::ModuleOp tfr_module) + : tfr_module_(tfr_module), pm_(tfr_module_.getContext()) { + mlir::OpPassManager& func_pm = pm_.nest(); + + // Prepare the imported graph. 
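+  // The passes below strip the tf_executor wrapper so the decompose and raise
+  // passes can operate on plain TF ops; the last two passes restore the
+  // executor form so the decomposed module can be exported again.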
+ func_pm.addPass(mlir::CreateExecutorDialectToFunctionalConversionPass()); + + // Run TFR lowering, inlining and raising to tf. + func_pm.addPass(mlir::TFR::CreateDecomposeTFOpsPass(tfr_module_)); + func_pm.addPass(mlir::TFR::CreateRaiseToTFOpsPass( + tfr_module_, /*materialize_derived_attrs=*/true)); + + // Prepare to be exported. + func_pm.addPass(mlir::CreateFunctionalToExecutorDialectConversionPass()); + pm_.addPass(mlir::CreateBreakUpIslandsPass()); +} + +void TFRDecomposeContext::Destroy() { tfr_module_.erase(); } + +StatusOr ExpandNode(const NodeDef& node_def, + StringPiece func_name) { + mlir::MLIRContext mlir_ctx; + TF_ASSIGN_OR_RETURN(auto ctx, TFRDecomposeContext::Get(&mlir_ctx)); + return ctx->ExpandNode(node_def, func_name); +} + +Status DecomposeGraph(mlir::ModuleOp user_module) { + mlir::MLIRContext* mlir_ctx = user_module.getContext(); + TF_ASSIGN_OR_RETURN(auto ctx, TFRDecomposeContext::Get(mlir_ctx)); + return ctx->DecomposeGraph(user_module); +} + +} // namespace tfr +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h new file mode 100644 index 00000000000..6e33bbf0b0c --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h @@ -0,0 +1,81 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_TFR_DECOMPOSE_CTX_H_ +#define TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_TFR_DECOMPOSE_CTX_H_ + +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace tensorflow { +namespace tfr { + +extern const char* const kTFRLibEnv; + +using stream_executor::port::StatusOr; + +// An wrapper for all the objects used to decompose a module (graph mode) and +// node_def (eager mode). Note that this class owns the decomposition library. +class TFRDecomposeContext { + public: + // The entry function to get a decompose context. All the required passes have + // been initialized. + static StatusOr> Get( + mlir::MLIRContext* mlir_ctx); + + // Constructor of the decompose context. To share the decompose library, the + // whole decompose TFR function library is loaded. + explicit TFRDecomposeContext(mlir::ModuleOp tfr_module); + + // Constructs the decompose context from the tfr text module and the mlir + // context. The tfr text module is added to the mlir context. + static std::unique_ptr GetFromText( + StringPiece tfr_raw_text, mlir::MLIRContext* mlir_ctx); + + // Decomposes the op in the NodeDef to a set of primitive ops according to the + // decompose library in the context. 
Wrap the decomposed result in a + // FunctionDef. + StatusOr ExpandNode(const NodeDef& node_def, + StringPiece func_name); + + // Runs the decompose passes on the user_module. + Status DecomposeGraph(mlir::ModuleOp user_module); + + // Erases the tfr_module created. + void Destroy(); + + private: + mlir::ModuleOp tfr_module_; + mlir::PassManager pm_; + + GraphExportConfig export_confs_; +}; + +// Decomposes the NodeDef to a set of primitive ops according to the decompose +// library loaded. Wrap the decomposed result in a FunctionDef. +StatusOr ExpandNode(const NodeDef& node_def, + StringPiece func_name); + +// Decomposes the ops in the ModuleOp to a set of primitive ops according to +// decompose library in the context. +Status DecomposeGraph(mlir::ModuleOp user_module); + +} // namespace tfr +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_TFR_DECOMPOSE_CTX_H_ diff --git a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc new file mode 100644 index 00000000000..3d83b8d5535 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc @@ -0,0 +1,162 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h" + +#include +#include + +#include "absl/types/span.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +using testing::ElementsAreArray; +using testing::Test; +using NodeAndType = std::pair; + +namespace tensorflow { + +REGISTER_OP("MyAddN") + .Input("inputs: N * T") + .Output("sum: T") + .Attr("N: int >= 1") + .Attr("T: {numbertype, variant}") + .SetIsCommutative() + .SetIsAggregate() + .SetShapeFn(shape_inference::UnchangedShape); + +REGISTER_OP("RiscAdd") + .Input("x: T") + .Input("y: T") + .Output("z: T") + .Attr( + "T: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, " + "complex64, complex128, string}") + .SetShapeFn(shape_inference::UnchangedShape); + +namespace { + +constexpr char tfr_raw_text[] = R"( + +tfr.func @tf__my_add_n(%values: !tfr.tensor_list, + %n: i64 {tfr.name="N"}) -> !tfr.tensor { + %index = constant 0 : index + %cst = constant 1 : i64 + %eq = cmpi "eq", %n, %cst : i64 + %v1 = tfr.get_element %values[%index] : (!tfr.tensor_list, index) -> !tfr.tensor + %res = scf.if %eq -> !tfr.tensor { + scf.yield %v1 : !tfr.tensor + } else { + %step = index_cast %cst : i64 to index + %end = index_cast %n : i64 to index + %reduce = scf.for %i = %step to %end step %step iter_args(%reduce_iter=%v1) -> !tfr.tensor { + %v = tfr.get_element %values[%i] : (!tfr.tensor_list, index) -> !tfr.tensor + %reduce_next = tfr.call @tf__risc_add(%reduce_iter, %v) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor + scf.yield %reduce_next : !tfr.tensor + } + scf.yield %reduce : !tfr.tensor + } + tfr.return %res : !tfr.tensor +} + +tfr.func @tf__risc_add_(!tfr.tensor, !tfr.tensor) -> !tfr.tensor attributes{T} +)"; + +class TFRDecomposeContextTest : public Test { + protected: + void SetUp() override { + test_ctx_ = tfr::TFRDecomposeContext::GetFromText(tfr_raw_text, &ctx_); + } + + void TearDown() override { test_ctx_->Destroy(); } + + mlir::MLIRContext ctx_; + std::unique_ptr test_ctx_; +}; + +std::vector NodesSequenceOf(const FunctionDef& graph) { + std::vector nodes; + for (auto& node : graph.node_def()) { + nodes.push_back({node.op(), node.attr().at("T").type()}); + } + return nodes; +} + +TEST_F(TFRDecomposeContextTest, FLOAT_1_ins) { + std::vector src_list; + src_list.emplace_back("input", 0, DT_FLOAT); + NodeDef test_node; + auto status = NodeDefBuilder("float_add", "MyAddN") + .Input(src_list) + .Finalize(&test_node); + EXPECT_TRUE(status.ok()); + auto decomposed = test_ctx_->ExpandNode(test_node, "test"); + EXPECT_TRUE(decomposed.ok()); + std::vector expected_results{{"Identity", DT_FLOAT}}; + EXPECT_THAT(NodesSequenceOf(decomposed.ValueOrDie()), + ElementsAreArray(expected_results)); +} + 
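+// Illustrative sketch, not part of the original change: with two inputs the
+// scf.for loop in @tf__my_add_n runs exactly once, so a single RiscAdd is
+// expected. The test name and the expected node list are assumptions that
+// follow the pattern of the surrounding tests.
+TEST_F(TFRDecomposeContextTest, FLOAT_2_ins_Sketch) {
+  std::vector<NodeDefBuilder::NodeOut> src_list;
+  src_list.emplace_back("in0", 0, DT_FLOAT);
+  src_list.emplace_back("in1", 0, DT_FLOAT);
+  NodeDef test_node;
+  auto status = NodeDefBuilder("float_add_2", "MyAddN")
+                    .Input(src_list)
+                    .Finalize(&test_node);
+  EXPECT_TRUE(status.ok());
+  auto decomposed = test_ctx_->ExpandNode(test_node, "test");
+  EXPECT_TRUE(decomposed.ok());
+  // One RiscAdd combines the two inputs; no Identity is needed.
+  std::vector<NodeAndType> expected_results{{"RiscAdd", DT_FLOAT}};
+  EXPECT_THAT(NodesSequenceOf(decomposed.ValueOrDie()),
+              ElementsAreArray(expected_results));
+}
+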
+TEST_F(TFRDecomposeContextTest, FLOAT_3_ins) { + std::vector src_list; + src_list.emplace_back("in0", 0, DT_FLOAT); + src_list.emplace_back("in1", 0, DT_FLOAT); + src_list.emplace_back("in2", 0, DT_FLOAT); + NodeDef test_node; + auto status = NodeDefBuilder("float_add_3", "MyAddN") + .Input(src_list) + .Finalize(&test_node); + EXPECT_TRUE(status.ok()); + auto decomposed = test_ctx_->ExpandNode(test_node, "test"); + EXPECT_TRUE(decomposed.ok()); + + std::vector expected_results{{"RiscAdd", DT_FLOAT}, + {"RiscAdd", DT_FLOAT}}; + EXPECT_THAT(NodesSequenceOf(decomposed.ValueOrDie()), + ElementsAreArray(expected_results)); +} + +TEST_F(TFRDecomposeContextTest, INT32_3_ins) { + std::vector src_list; + src_list.emplace_back("in0", 0, DT_INT32); + src_list.emplace_back("in1", 0, DT_INT32); + src_list.emplace_back("in2", 0, DT_INT32); + NodeDef test_node; + auto status = + NodeDefBuilder("int_add", "MyAddN").Input(src_list).Finalize(&test_node); + EXPECT_TRUE(status.ok()); + auto decomposed = test_ctx_->ExpandNode(test_node, "test"); + EXPECT_TRUE(decomposed.ok()); + + std::vector expected_results{{"RiscAdd", DT_INT32}, + {"RiscAdd", DT_INT32}}; + EXPECT_THAT(NodesSequenceOf(decomposed.ValueOrDie()), + ElementsAreArray(expected_results)); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc new file mode 100644 index 00000000000..c0ef5c3b387 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc @@ -0,0 +1,590 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h" + +#include +#include + +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/DialectImplementation.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/FunctionImplementation.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/InliningUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tfr/ir/tfr_types.h" + +namespace mlir { + +namespace TFR { + +//===----------------------------------------------------------------------===// +// InlinerInterface +//===----------------------------------------------------------------------===// + +namespace { +/// This class defines the interface for inlining within the TFR dialect. +struct TFRInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + // Returns true if the given region 'src' can be inlined into the region + // 'dest' that is attached to an operation registered to the current dialect. + bool isLegalToInline(Region *dest, Region *src, + BlockAndValueMapping &) const final { + return true; + } + + // Returns true if the given operation 'op', that is registered to this + // dialect, can be inlined into the region 'dest' that is attached to an + // operation registered to the current dialect. + bool isLegalToInline(Operation *op, Region *dest, + BlockAndValueMapping &) const final { + return true; + } + + // Handle the given inlined terminator by replacing it with a new operation + // as necessary. Required when the region has only one block. + void handleTerminator(Operation *op, + ArrayRef valuesToRepl) const final { + auto retValOp = dyn_cast(op); + if (!retValOp) return; + + for (auto ret_value : llvm::zip(valuesToRepl, retValOp.operands())) { + std::get<0>(ret_value).replaceAllUsesWith(std::get<1>(ret_value)); + } + } + + // Attempts to materialize a conversion for a type mismatch between a call + // from this dialect, and a callable region. This method should generate an + // operation that takes 'input' as the only operand, and produces a single + // result of 'resultType'. If a conversion can not be generated, nullptr + // should be returned. 
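+  // For the TFR dialect this materializes a tfr.cast: when the callee expects
+  // a TFR type but the call site passes a builtin tensor value, the cast
+  // created below bridges the two types so inlining can proceed.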
+ Operation *materializeCallConversion(OpBuilder &builder, Value input, + Type result_type, + Location conversion_loc) const final { + if (!result_type.isa()) return nullptr; + return builder.create(conversion_loc, result_type, input); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// TFR Dialect +//===----------------------------------------------------------------------===// + +TFRDialect::TFRDialect(MLIRContext *context) + : Dialect(/*name=*/"tfr", context, TypeID::get()) { + addTypes(); + addOperations< +#define GET_OP_LIST +#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc.inc" + >(); + + addInterfaces(); +} + +bool TFRType::classof(Type type) { + return llvm::isa(type.getDialect()); +} + +//===----------------------------------------------------------------------===// +// Custom op methods +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(ConstantTensorOp op) { + auto input_type = op.arg().getType(); + auto output_type = op.out().getType(); + + if (auto output_tensor_type = output_type.dyn_cast()) { + return success(); + } + + auto output_tensor_type = output_type.dyn_cast(); + if (!output_tensor_type || !output_tensor_type.hasStaticShape()) { + op.emitError("output type should be static and ranked."); + return failure(); + } + + if (output_tensor_type.getRank() == 0) { + bool same_scalar = output_tensor_type.getElementType() == input_type; + if (!same_scalar) { + op.emitError("input and output should have the same scalar types."); + } + return success(same_scalar); + } + + if (auto input_vector_type = input_type.dyn_cast()) { + bool same_element_type = output_tensor_type.getElementType() == + input_vector_type.getElementType(); + bool same_shape = + output_tensor_type.getShape() == input_vector_type.getShape(); + if (!same_element_type || !same_shape) { + op.emitError("input and output should have same shape and element type."); + } + return success(same_element_type && same_shape); + } + + op.emitError("input can not be converted to an output tensor."); + return failure(); +} + +static LogicalResult Verify(TFRFuncOp func) { + // Collect all attribute names used by the tensor and tensor list arguments + // and returns. Also, collect the names of all the attribute arguments as the + // defined list. Later on, the used attribute names will be verified to be in + // the defined list. + llvm::SmallVector used_attrs; + + // While scanning the arguments, record the start/end indices of each argument + // type, so the order can be verified as well. + // TODO(fengliuai): the attribute arguments with default values need to be + // at the end? 
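+  // A valid signature therefore orders its arguments as, for example,
+  //   (!tfr.tensor, !tfr.tensor, !tfr.tensor_list, i32 {tfr.name="N"})
+  // i.e. tensors first, then at most one tensor list, then attributes.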
+ int first_tensor = -1, last_tensor = -1, first_tensor_list = -1, + last_tensor_list = -1, first_attr = -1; + + for (auto arg : llvm::enumerate(func.getType().getInputs())) { + Type arg_type = arg.value(); + + if (auto tensor = arg_type.dyn_cast()) { + if (first_tensor == -1) { + first_tensor = arg.index(); + } + last_tensor = arg.index(); + auto used = tensor.getAttrKeys(); + used_attrs.append(used.begin(), used.end()); + continue; + } + + if (auto tensor_list = arg_type.dyn_cast()) { + if (first_tensor_list == -1) { + first_tensor_list = arg.index(); + } + last_tensor_list = arg.index(); + auto used = tensor_list.getAttrKeys(); + used_attrs.append(used.begin(), used.end()); + continue; + } + + if (!arg_type.isa()) { + if (first_attr == -1) { + first_attr = arg.index(); + } + auto name = + func.getArgAttrOfType(arg.index(), kAttrArgumentNameAttr); + if (!name) { + func.emitError( + llvm::Twine(arg.index()) + + " attribute argument doesn't have a tfr.name attribute."); + return failure(); + } + continue; + } + + func.emitError("Builtin TensorType isn't allowed as the argument."); + return failure(); + } + + // Verify the argument order: tensors, tensor list, attributes; and also + // verify there is at most one tensor list argument. + if (first_tensor_list != -1 && first_tensor_list < last_tensor) { + func.emitError( + "tfr.tensor argument should be before tfr.tensor_list argument."); + return failure(); + } + if (first_attr != -1 && first_attr < last_tensor_list) { + func.emitError( + "tfr.tensor_list argument should be before non tensor arguments."); + return failure(); + } + if (first_tensor_list != last_tensor_list) { + func.emitError("More than one tfr.tensor_list argument isn't allowed."); + return failure(); + } + + // Verify the result order: tensor, tensor list, and also verify at most one + // tensor list result. + bool seen_tensor_list = false; + for (auto result_type : func.getType().getResults()) { + if (auto tensor = result_type.dyn_cast()) { + if (seen_tensor_list) { + func.emitError( + "tfr.tensor result should be before tfr.tensor_list result."); + return failure(); + } + auto used = tensor.getAttrKeys(); + used_attrs.append(used.begin(), used.end()); + continue; + } + + if (auto tensor_list = result_type.dyn_cast()) { + if (seen_tensor_list) { + func.emitError("More than one tfr.tensor_list result isn't allowed."); + return failure(); + } + seen_tensor_list = true; + auto used = tensor_list.getAttrKeys(); + used_attrs.append(used.begin(), used.end()); + continue; + } + + func.emitError( + "None tfr.tensor/tfr.tensor_list results aren't allowed as a " + "result."); + return failure(); + } + + // Verify that all the used attributes are in the attribute arguments. 
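+  // For example, a !tfr.tensor<T> argument uses the attribute key "T", so the
+  // function must also define "T" (e.g. attributes{T} on an external
+  // tfr.func).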
+ llvm::SmallVector undefined_attrs; + for (auto attr : used_attrs) { + if (!func.getAttr(attr.getValue())) { + undefined_attrs.push_back(attr); + } + } + if (!undefined_attrs.empty()) { + llvm::SmallVector attr_names(undefined_attrs.size()); + std::transform(undefined_attrs.begin(), undefined_attrs.end(), + attr_names.begin(), + [](StringAttr attr) { return attr.getValue().str(); }); + func.emitError(llvm::Twine("Undefined attributes are used: ", + llvm::join(attr_names, ","))); + return failure(); + } + + return success(); +} + +static ParseResult ParseFuncOp(OpAsmParser &parser, OperationState *result) { + auto build_func_type = [](Builder &builder, ArrayRef arg_types, + ArrayRef results, impl::VariadicFlag, + std::string &) { + return builder.getFunctionType(arg_types, results); + }; + return impl::parseFunctionLikeOp(parser, *result, /*allowVariadic=*/false, + build_func_type); +} + +static void PrintFuncOp(OpAsmPrinter &p, TFRFuncOp op) { + FunctionType fn_type = op.getType(); + impl::printFunctionLikeOp(p, op, fn_type.getInputs(), /*isVariadic=*/false, + fn_type.getResults()); +} + +} // namespace TFR +} // namespace mlir + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc.inc" + +namespace mlir { +namespace TFR { +struct ConvertConstToTensorConst : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ConstantTensorOp cst_tensor_op, + PatternRewriter &rewriter) const override { + Location loc = cst_tensor_op.getLoc(); + Type out_type = cst_tensor_op.getType(); + Operation *new_cst = nullptr; + + ArrayAttr array; + if (matchPattern(cst_tensor_op.arg(), m_Constant(&array))) { + llvm::DenseSet all_types; + for (auto it : array) { + all_types.insert(it.getType()); + } + if (all_types.size() != 1) return failure(); + ShapedType new_out_type = RankedTensorType::get( + {static_cast(array.size())}, *all_types.begin()); + DenseElementsAttr attr = + DenseElementsAttr::get(new_out_type, array.getValue()); + new_cst = rewriter.create(loc, new_out_type, attr); + if (out_type.isa()) { + new_cst = rewriter.create(loc, out_type, new_cst->getResult(0)); + } + rewriter.replaceOp(cst_tensor_op, new_cst->getResult(0)); + return success(); + } + + Attribute scalar; + if (matchPattern(cst_tensor_op.arg(), m_Constant(&scalar))) { + Type new_out_type = RankedTensorType::get({}, scalar.getType()); + new_cst = rewriter.create(loc, new_out_type, scalar); + if (out_type.isa()) { + new_cst = rewriter.create(loc, out_type, new_cst->getResult(0)); + } + rewriter.replaceOp(cst_tensor_op, new_cst->getResult(0)); + return success(); + } + return failure(); + } +}; + +struct RemoveRedundantCast : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(CastOp cast_op, + PatternRewriter &rewriter) const override { + auto preceding_cast = + llvm::dyn_cast_or_null(cast_op.arg().getDefiningOp()); + if (!preceding_cast) { + return failure(); + } + Value input = preceding_cast.arg(); + Type input_type = input.getType(); + Type output_type = cast_op.getType(); + + // If the two types are the same, the back-to-back tfr.cast ops can be + // removed. 
+ if (input_type == output_type || output_type.isa()) { + rewriter.replaceOp(cast_op, {input}); + return success(); + } + + // If the rank of the input tensor isn't ranked, we replace the pair + // with tf.EnsureShape op so it can be removed after shape inference or + // confirmed at runtime. + if (input_type.isa() && output_type.isa()) { + auto shape = output_type.cast().getShape(); + auto shape_attr = TF::ShapeAttr::get(rewriter.getContext(), shape); + rewriter.replaceOpWithNewOp(cast_op, output_type, + input, shape_attr); + } + + return success(); + } +}; + +struct GetTensorShape : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(GetShapeOp shape_op, + PatternRewriter &rewriter) const override { + Operation *preceding_op = shape_op.arg().getDefiningOp(); + if (auto cast_op = llvm::dyn_cast_or_null(preceding_op)) { + // replace this pair by shape.shape_of, so the folding works. + rewriter.replaceOpWithNewOp(shape_op, cast_op.arg()); + return success(); + } + return failure(); + } +}; + +struct RemoveRedundantGetElement : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(GetElementOp ge_op, + PatternRewriter &rewriter) const override { + IntegerAttr index; + if (!matchPattern(ge_op.index(), m_Constant(&index))) { + return failure(); + } + auto preceding_build_list = llvm::dyn_cast_or_null( + ge_op.tensor_list().getDefiningOp()); + if (!preceding_build_list || + preceding_build_list.getNumOperands() <= index.getInt()) { + return failure(); + } + Value input = preceding_build_list.getOperand(index.getInt()); + Type output_type = ge_op.getType(); + if (input.getType() != output_type && + !output_type.isa()) { + return failure(); + } + rewriter.replaceOp(ge_op, {input}); + return success(); + } +}; + +struct BuildConstantListAsAttr : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(BuildListOp bl_op, + PatternRewriter &rewriter) const override { + SmallVector array_list; + array_list.reserve(bl_op.getNumOperands()); + for (const auto &operand : bl_op.getOperands()) { + Attribute array_elt; + if (!matchPattern(operand, m_Constant(&array_elt))) { + return failure(); + } + array_list.push_back(array_elt); + } + auto array_attr = rewriter.getArrayAttr(array_list); + rewriter.replaceOpWithNewOp(bl_op, array_attr); + return success(); + } +}; + +void ConstantTensorOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + +void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +void GetShapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +void GetElementOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + +void BuildListOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +OpFoldResult TFR::EqualOp::fold(ArrayRef operands) { + assert(operands.size() == 2 && "equal op has two operands"); + auto ctx = getContext(); + if (operands[0] == operands[1]) return BoolAttr::get(/*value=*/true, ctx); + return BoolAttr::get(/*value=*/false, ctx); +} + +OpFoldResult ConstOp::fold(ArrayRef operands) { + assert(operands.empty() && "constant has no operands"); + + // Return the held attribute 
value. + return value(); +} + +// CallableOpInterface +Region *TFRFuncOp::getCallableRegion() { + return isExternal() ? nullptr : &body().front(); +} + +// CallableOpInterface +ArrayRef TFRFuncOp::getCallableResults() { + return getType().getResults(); +} + +//===----------------------------------------------------------------------===// +// Dialect type definitions +//===----------------------------------------------------------------------===// + +// Parses a TFR type. +// tfr_type ::= tensor_type | tensor_list_type | attr_type +// string_list ::= `[` string-literal (, string-literal)+ `]` +// tensor_type ::= `tensor` +// | `tensor<` (string-literal | string_list) '>' +// tensor_list_type ::= `tensor_list` +// | `tensor_list<` (string-literal | string_list) '>' +// attr_type ::= `attr` +Type TFRDialect::parseType(DialectAsmParser &parser) const { + Location loc = parser.getEncodedSourceLoc(parser.getNameLoc()); + MLIRContext *ctx = loc.getContext(); + + StringRef typeNameSpelling; + if (failed(parser.parseKeyword(&typeNameSpelling))) return {}; + llvm::SmallVector attrs; + if (succeeded(parser.parseOptionalLess())) { + bool l_square_parsed = false; + if (succeeded(parser.parseOptionalLSquare())) { + l_square_parsed = true; + } + + do { + StringRef attr; + if (failed(parser.parseKeyword(&attr))) return {}; + attrs.push_back(StringAttr::get(attr, ctx)); + } while (succeeded(parser.parseOptionalComma())); + + if (l_square_parsed && failed(parser.parseRSquare())) { + parser.emitError(parser.getNameLoc(), "expected ']'"); + } + + if (failed(parser.parseGreater())) { + parser.emitError(parser.getNameLoc(), "expected '>'"); + } + } + + if (typeNameSpelling == "tensor") { + return TFRTensorType::getChecked(attrs, loc); + } else if (typeNameSpelling == "tensor_list") { + return TFRTensorListType::getChecked(attrs, loc); + } else if (typeNameSpelling == "attr") { + return TFRAttrType::getChecked(loc); + } else { + parser.emitError(parser.getNameLoc(), "unknown type " + typeNameSpelling); + return {}; + } +} + +void TFRDialect::printType(Type type, DialectAsmPrinter &os) const { + llvm::ArrayRef attrs; + + if (type.isa()) { + os << "attr"; + return; + } + if (auto tensor_ty = type.dyn_cast()) { + attrs = tensor_ty.getAttrKeys(); + os << "tensor"; + } else if (auto tensor_list_ty = type.dyn_cast()) { + attrs = tensor_list_ty.getAttrKeys(); + os << "tensor_list"; + } else { + llvm_unreachable("Unhandled tfr type"); + } + + if (attrs.empty()) return; + os << "<"; + + if (attrs.size() > 1) { + os << "["; + } + + llvm::interleaveComma(attrs, os, + [&](StringAttr attr) { os << attr.getValue(); }); + + if (attrs.size() > 1) { + os << "]"; + } + os << ">"; +} + +} // namespace TFR +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.h b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.h new file mode 100644 index 00000000000..cb36ee28351 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.h @@ -0,0 +1,55 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_OPS_H_ + +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/DialectImplementation.h" // from @llvm-project +#include "mlir/IR/FunctionSupport.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/ControlFlowInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project + +namespace mlir { +namespace TFR { + +constexpr char kAttrArgumentNameAttr[] = "tfr.name"; +constexpr char kAttrArgumentDefaultAttr[] = "tfr.default"; + +class TFRDialect : public Dialect { + public: + explicit TFRDialect(MLIRContext *context); + + static StringRef getDialectNamespace() { return "tfr"; } + + // Parse a type registered to this dialect. + Type parseType(DialectAsmParser &parser) const override; + + // Prints a type registered to this dialect. + void printType(Type ty, DialectAsmPrinter &os) const override; +}; + +} // namespace TFR +} // namespace mlir + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_OPS_H_ diff --git a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td new file mode 100644 index 00000000000..562b3f79955 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td @@ -0,0 +1,435 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is the operation definition file for TFR + +#ifndef DIALECT_TFR_OPS_ +#define DIALECT_TFR_OPS_ + +include "mlir/Dialect/Shape/IR/ShapeBase.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/SymbolInterfaces.td" +include "mlir/Interfaces/CallInterfaces.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td" + +//===----------------------------------------------------------------------===// +// Dialect +//===----------------------------------------------------------------------===// + +def TFR_Dialect : Dialect { + let name = "tfr"; + + let description = [{ + The TensorFlow Composition dialect. 
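+    It provides the types (tfr.tensor, tfr.tensor_list, tfr.attr) and the ops
+    (tfr.func, tfr.call, tfr.cast, ...) used to express a TF op as a
+    composition of other TF ops.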
+ }]; + + let cppNamespace = "::mlir::TFR"; +} + +//===----------------------------------------------------------------------===// +// Type classes +//===----------------------------------------------------------------------===// + +// tensor argument types +class TFR_Type : DialectType()">, + "TFR " # name #" type">, + BuildableType<"$_builder.getType()">; +def TFR_TensorType : TFR_Type<"TFRTensor">; +def TFR_TensorListType : TFR_Type<"TFRTensorList">; +def TFR_AllTensorTypes : Type, "all tensor related types">; + +// attribute argument types +def TFR_AttrType : TFR_Type<"TFRAttr">; +def TFR_AttrScalarType: TypeAlias; +def TFR_AttrVectorType : VectorOf<[TF_ElementType, TFR_AttrType]>; +def TFR_AllAttrTypes : Type, "all attribute related types">; + +// all allowed arguments types +def TFR_allowedArgType : Type, "allowed tfr.call operand types">; + +def TFR_allowedConstValues : Attr, "allowed tfr.constant value"> { + let storageType = "Attribute"; + let returnType = "Attribute"; + let convertFromStorage = "$_self"; + let constBuilderCall = "$0"; +} + +// all allowed result types +def TFR_allowedResultType : TypeAlias; + +// standard tensor type and tfr.tensor types can be casted to each other. +def TFR_singleTensorType : Type, "single tensor or tfr.tensor type">; + +// all allowed build list input types +def TFR_allowedBuiltListType : Type, "single tfr.tensor or tensor element type">; + +// all allowed build list result types +def TFR_allowedListResultType : Type, "tfr.tensor_list or tfr.attr type">; + +//===----------------------------------------------------------------------===// +// Op classes +//===----------------------------------------------------------------------===// + +class TFR_Op traits> : + Op; + +def TFR_CallOp : TFR_Op<"call", [CallOpInterface]> { + let description = [{ + The `call` operation represents a direct call to a function that is within + the same symbol scope as the callee. The operands and result types of the + call must match the specified function type. The callee is encoded as a + symbol reference attribute named "callee". + + Example: + + ```mlir + %2 = tfr.call @my_add(%0, %1) : (tfr.tensor, f32) -> tfr.tensor_list + ``` + + Note that the operands of the `call` operation can only be with tfr.tensor, + tfr.tensor_list, tfr.attr and mlir float and integer types. The results of + the `call` operation can only be with tfr.tensor and tfr.tensor_list types. + }]; + + let arguments = (ins + FlatSymbolRefAttr:$callee, + Variadic:$args); + + let results = (outs + Variadic:$outs); + + let extraClassDeclaration = [{ + StringRef getCallee() { return callee(); } + + // Get the argument operands to the called function. + operand_range getArgOperands() { return args(); } + + // Return the callee of this operation. + CallInterfaceCallable getCallableForCallee() { return calleeAttr(); } + }]; + + let assemblyFormat = [{ + $callee `(` $args `)` attr-dict `:` functional-type($args, results) + }]; +} + +def TFR_CastOp : TFR_Op<"cast", [NoSideEffect]> { + let description = [{ + The `cast` operation converts the operand with built-in tensor type to + tfr.tensor type, or vice versa. + + Example: + + ```mlir + %1 = tfr.cast(%0) : tensor -> !tfr.tensor + %3 = tfr.cast(%1) : !tfr.tensor -> tensor + ``` + }]; + + let arguments = (ins TFR_singleTensorType:$arg); + + let results = (outs TFR_singleTensorType:$out); + + let extraClassDeclaration = [{ + // Return element type of the input tensor type. Only available when the + // input is a MLIR built-in tensor type. 
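+    // For example, casting from a value of type tensor<2xf32> yields the
+    // TypeAttr for f32; casting from a !tfr.tensor value returns a null
+    // attribute.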
+ Attribute getInputElementType() { + if (auto ty = arg().getType().dyn_cast()) { + return TypeAttr::get(ty.getElementType()); + } + return {}; + } + }]; + + let hasCanonicalizer = 1; +} + +def TFR_GetShapeOp : TFR_Op<"get_shape", [NoSideEffect]> { + let description = [{ + The `get_shape` operation gets the shape of a tfr.tensor and returns + !shape.shape type. + + Example: + + ```mlir + %1 = "tfr.get_shape"(%0) : !tfr.tensor -> !shape.shape + %1 = tfr.get_shape %0 -> !shape.shape + ``` + }]; + + let arguments = (ins TFR_TensorType:$arg); + + let results = (outs Shape_ShapeType:$out); + + let assemblyFormat = "$arg attr-dict `->` type($out)"; + + let hasCanonicalizer = 1; +} + +def TFR_GetElementTypeOp : TFR_Op<"get_element_type", [NoSideEffect]> { + let description = [{ + The `get_element_type` operation gets the element type of a tfr.tensor and + returns !tfr.attr. + + Example: + + ```mlir + %1 = "tfr.get_element_type"(%0) : !tfr.tensor -> !tfr.attr + %1 = tfr.get_element_type %0 -> !tfr.attr + ``` + }]; + + let arguments = (ins TFR_TensorType:$arg); + + let results = (outs TFR_AttrType:$out); + + let assemblyFormat = "$arg attr-dict `->` type($out)"; +} + +def TFR_EqualOp : TFR_Op<"equal", [NoSideEffect, SameTypeOperands]> { + let description = [{ + The `equal` operation compares the values of the tfr.attr type arguments. + The operation returns an i1 boolean indicating if the two values are the + same. + Example: + + ```mlir + %x = tfr.equal %lhs, %rhs -> i1 + %x = "tfr.equal"(%lhs, %rhs) : (!tfr.attr, !tfr.attr) -> i1 + ``` + }]; + + let arguments = (ins + TFR_AttrType:$lhs, + TFR_AttrType:$rhs + ); + let results = (outs BoolLike:$result); + + let hasFolder = 1; + + let assemblyFormat = "$lhs `,` $rhs attr-dict `->` type($result)"; +} + +def TFR_ConstOp : TFR_Op<"constant", [ConstantLike, NoSideEffect]> { + let description = [{ + The `attr` operation stores TF op's attribute, which doesn't support + arithmetic operations. + + Example: + + ```mlir + %1 = "tfr.constant"() { value: i32 } : () -> !tfr.attr + %2 = "tfr.constant"() { value: [i32, f32] } : () -> !tfr.attr + %3 = tfr.constant [i32, f32] -> !tfr.attr + %4 = tfr.constant f32 -> !tfr.attr + ``` + }]; + + let arguments = (ins TFR_allowedConstValues:$value); + + let results = (outs TFR_AttrType:$out); + + let hasFolder = 1; + + let builders = [OpBuilder<"Attribute value", + [{ + auto* ctx = value.getContext(); + $_state.addAttribute("value", value); + $_state.addTypes(TFRAttrType::get(ctx)); + }]> + ]; + + let assemblyFormat = [{ + $value attr-dict `->` type($out) + }]; +} + +def TFR_ConstantTensorOp : TFR_Op<"constant_tensor", [NoSideEffect]> { + let description = [{ + The `constant_tensor` operation converts the operand with non-built-in + tensor type to built-in tensor type or tfr.tensor type. If it is built-in + tensor type, the shape shouldn't be changed during the conversion. + + Example: + + ```mlir + %1 = tfr.contant_tensor(%0) : f32 -> tensor + %3 = tfr.contant_tensor(%2) : vector<1xf32> -> tensor<1xf32> + ``` + }]; + + let arguments = (ins TFR_AllAttrTypes:$arg); + + let results = (outs TFR_singleTensorType:$out); + + let hasCanonicalizer = 1; + + let verifier = [{ return Verify(*this); }]; +} + +def TFR_GetElementOp : TFR_Op<"get_element", [NoSideEffect]> { + let description = [{ + The `get_element` operation extracts one tfr.tensor element from a + tfr.tensor_list. 
+ + Example: + + ```mlir + %2 = tfr.get_element %1[%0] : (tfr.tensor, index) -> tfr.tensor + ``` + }]; + + let arguments = (ins + TFR_TensorListType:$tensor_list, + Index:$index); + + let results = (outs TFR_TensorType:$out); + + let hasCanonicalizer = 1; + + let assemblyFormat = [{ + $tensor_list `[` $index `]` attr-dict `:` + `(` type($tensor_list) `,` type($index) `)` `->` type($out) + }]; +} + +def TFR_BuildListOp : TFR_Op<"build_list", [NoSideEffect]> { + let description = [{ + The `build_list` operation builds a tensor list from a list of tensors, or + an tfr.attr from a list of scalars. + + Example: + + ```mlir + %3 = tfr.build_list(%2, %1, %0) : + (tfr.tensor, tfr.tensor, tfr.tensor) -> tfr.tensor_list + %3 = tfr.build_list(%2, %1, %0) : (i32, i32, i32) -> tfr.attr + ``` + }]; + + let arguments = (ins Variadic:$tensors); + + let results = (outs TFR_allowedListResultType:$out); + + let hasCanonicalizer = 1; +} + +//===----------------------------------------------------------------------===// +// Function related classes +//===----------------------------------------------------------------------===// + +def TFR_TFRFuncOp : TFR_Op<"func", [HasParent<"ModuleOp">, + DeclareOpInterfaceMethods, + FunctionLike, + IsolatedFromAbove, Symbol]> { + let summary = "TFR Function defines a composition of other ops"; + + let description = [{ + Defines a function that can be used to decompose an TF function call to + the invocation of a set of other TF ops. + + Syntax: + + ``` + op ::= `tfr.func` symbol-ref-id `(` argument-list `)` (`->` + function-result-list)? function-attributes? region + ``` + + Example: + + ```mlir + tfr.func @foo(%arg0: !tfr.tensor, %arg1: !tfr.tensor_list, + %arg2: int {tfr.name="T", tfr.default=1}) + attributes {qux: "quux"} { + tfr.return + } + ``` + + Note the arguments are ordered by the following rule: + tfr.tensor > tfr.tensor_list > tfr.attr/i32/..., + and only one trfr.tensor_list argument is allowed. + }]; + + let arguments = (ins + TypeAttr:$type, + StrAttr:$sym_name + ); + + let results = (outs); + + // When the regions is empty, the tfr.func is an external function and used + // to model the element type constraints of the tf op. Otherwise, there is one + // region containing the composition. + let regions = (region VariadicRegion:$body); + + let skipDefaultBuilders = 1; + + let builders = [ + OpBuilder<"StringRef name, FunctionType type, " + "ArrayRef attrs = {}"> + ]; + + let extraClassDeclaration = [{ + // FunctionLike trait needs access to the functions below. + friend class OpTrait::FunctionLike; + + // Hooks for the input/output type enumeration in FunctionLike . + unsigned getNumFuncArguments() { return getType().getNumInputs(); } + unsigned getNumFuncResults() { return getType().getNumResults(); } + }]; + + let verifier = [{ return Verify(*this); }]; + let parser = [{ return ParseFuncOp(parser, &result); }]; + let printer = [{ PrintFuncOp(p, *this); }]; +} + +def TFR_TFRReturnOp : TFR_Op<"return", [HasParent<"TFRFuncOp">, NoSideEffect, + ReturnLike, Terminator]> { + let description = [{ + A terminator operation for regions that appear in the body of `tfr.func` + functions. The operands to the `tfr.return` are the result values returned + by an invocation of the `tfr.func`. + + Note that only the tfr.tensor and tfr.tensor_list can be returned. 
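+
+    Example:
+
+    ```mlir
+    tfr.return %res : !tfr.tensor
+    ```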
+ }]; + + let arguments = (ins Variadic:$operands); + + let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; +} + +#endif // DIALECT_TFR_OPS_ diff --git a/tensorflow/compiler/mlir/tfr/ir/tfr_types.h b/tensorflow/compiler/mlir/tfr/ir/tfr_types.h new file mode 100644 index 00000000000..4bda8f34658 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/ir/tfr_types.h @@ -0,0 +1,115 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_TYPES_H_ +#define TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_TYPES_H_ + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeSupport.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project + +namespace mlir { +namespace TFR { + +class TFRType : public Type { + public: + using Type::Type; + + static bool classof(Type type); +}; + +namespace detail { + +struct TFRTypeStorage final + : public TypeStorage, + public llvm::TrailingObjects { + using KeyTy = ArrayRef; + + explicit TFRTypeStorage(unsigned num_attrs) : num_attrs(num_attrs) {} + + static TFRTypeStorage* construct(TypeStorageAllocator& allocator, KeyTy key) { + // Allocate a new storage instance. + auto byteSize = TFRTypeStorage::totalSizeToAlloc(key.size()); + auto rawMem = allocator.allocate(byteSize, alignof(TFRTypeStorage)); + auto result = ::new (rawMem) TFRTypeStorage(key.size()); + + // Copy in the string attributes into the trailing storage. 
+ std::uninitialized_copy(key.begin(), key.end(), + result->getTrailingObjects()); + return result; + } + + bool operator==(const KeyTy& attrs) const { return attrs == GetAttrs(); } + + KeyTy GetAttrs() const { + return {getTrailingObjects(), num_attrs}; + } + + unsigned num_attrs; +}; + +template +class TFRTypeImpl : public Type::TypeBase { + public: + using Base = Type::TypeBase; + using TFRBase = TFRTypeImpl; + using Base::Base; + + static Derived get(ArrayRef attrs, MLIRContext* context) { + return Base::get(context, attrs); + } + + static Derived getChecked(ArrayRef attrs, Location loc) { + return Base::getChecked(loc, attrs); + } + + static Derived get(MLIRContext* context) { return get({}, context); } + + // TODO(fengliuai): fix the implementation + static LogicalResult verifyConstructionInvariants( + Location loc, ArrayRef attrs) { + return success(); + } + + ArrayRef getAttrKeys() { return Base::getImpl()->GetAttrs(); } +}; +} // namespace detail + +class TFRTensorType : public detail::TFRTypeImpl { + public: + using TFRBase::TFRBase; + static std::string getTypeName() { return "TFRTensorType"; } +}; + +class TFRTensorListType : public detail::TFRTypeImpl { + public: + using TFRBase::TFRBase; + static std::string getTypeName() { return "TFRTensorListType"; } +}; + +class TFRAttrType : public Type::TypeBase { + public: + using Base::Base; + static std::string getTypeName() { return "TFRAttrType"; } +}; + +} // namespace TFR +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_TYPES_H_ diff --git a/tensorflow/compiler/mlir/tfr/passes/canonicalize.cc b/tensorflow/compiler/mlir/tfr/passes/canonicalize.cc new file mode 100644 index 00000000000..d399a10a35e --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/passes/canonicalize.cc @@ -0,0 +1,160 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include + +#include "llvm/Support/raw_ostream.h" +#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" // from @llvm-project +#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/InliningUtils.h" // from @llvm-project +#include "mlir/Transforms/LoopUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h" +#include "tensorflow/compiler/mlir/tfr/passes/passes.h" + +//===----------------------------------------------------------------------===// +// Canonicalization patterns for the scf.for and scf.if ops. They are used to +// optimize the control flow in the tfr function. Technically, both patterns +// should be upstreamed to be part of the op definition. +// TODO(fengliuai): sync with the llvm upstream for both patterns. +// +namespace mlir { +namespace TFR { + +namespace { + +struct UnrollSCFForOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(scf::ForOp for_op, + PatternRewriter &rewriter) const override { + Location loc = for_op.getLoc(); + APInt lower_bound, upper_bound, step; + if (!matchPattern(for_op.lowerBound(), m_ConstantInt(&lower_bound)) || + !matchPattern(for_op.upperBound(), m_ConstantInt(&upper_bound)) || + !matchPattern(for_op.step(), m_ConstantInt(&step))) { + return failure(); + } + uint64_t trip_count = (upper_bound - lower_bound).sdiv(step).getZExtValue(); + if (trip_count <= 0) return failure(); + + // TODO(fengliuai): use loopUnrollByFactor once the iter_arg is supported + + Block *single_block = for_op.getBody(); + BlockAndValueMapping mapping; + Value iv = for_op.getInductionVar(); + for (auto iter_op : + llvm::zip(for_op.getRegionIterArgs(), for_op.initArgs())) { + mapping.map(std::get<0>(iter_op), std::get<1>(iter_op)); + } + mapping.map(iv, for_op.lowerBound()); + for (auto i = 0; i < trip_count; ++i) { + if (!iv.use_empty()) { + // iv' = iv + step * i; + Value iter = rewriter.create(loc, i); + Value step_cst = + rewriter.create(loc, step.getSExtValue()); + Value stride = rewriter.create(loc, step_cst, iter); + Value iv_unroll = + rewriter.create(loc, mapping.lookup(iv), stride); + mapping.map(iv, iv_unroll); + } + + Operation *terminator_op; + for (auto it = single_block->begin(); it != single_block->end(); ++it) { + terminator_op = rewriter.clone(*it, mapping); + } + // Map the block arguments to the yield results. + for (auto iter_op : llvm::zip(for_op.getRegionIterArgs(), + terminator_op->getOperands())) { + mapping.map(std::get<0>(iter_op), std::get<1>(iter_op)); + } + rewriter.eraseOp(terminator_op); + } + SmallVector returned; + for (Value arg : for_op.getRegionIterArgs()) { + returned.push_back(mapping.lookup(arg)); + } + rewriter.replaceOp(for_op, returned); + return success(); + } +}; + +// TODO(fengliuai): up stream this pattern. 
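+// Folds an scf.if whose condition is a known constant: the taken region is
+// inlined at the scf.if and the op is erased; a false condition with no else
+// region simply removes the op. Non-constant conditions are left untouched.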
+struct SimplifySCFIfOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(scf::IfOp if_op, + PatternRewriter &rewriter) const override { + // Then branch + if (matchPattern(if_op.condition(), m_NonZero())) { + return InlineRegion(if_op.getLoc(), rewriter, if_op, &if_op.thenRegion()); + } + + // Else branch + if (matchPattern(if_op.condition(), m_Zero())) { + if (if_op.elseRegion().empty()) { + // Remove the op + rewriter.eraseOp(if_op); + return success(); + } else { + return InlineRegion(if_op.getLoc(), rewriter, if_op, + &if_op.elseRegion()); + } + } + + // Not a constant condition + return failure(); + } + + private: + LogicalResult InlineRegion(Location loc, PatternRewriter &rewriter, + Operation *inline_point, Region *region) const; +}; + +LogicalResult SimplifySCFIfOp::InlineRegion(Location loc, + PatternRewriter &rewriter, + Operation *inline_point, + Region *region) const { + InlinerInterface interface(loc.getContext()); + if (failed(inlineRegion(interface, region, inline_point, {}, + inline_point->getResults(), loc, + /*shouldCloneInlinedRegion=*/true))) { + return failure(); + } + + // If the inlining was successful then erase the scf.if op. + rewriter.eraseOp(inline_point); + return success(); +} + +} // namespace + +void populateSCFOpsCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +} // namespace TFR +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tfr/passes/decompose.cc b/tensorflow/compiler/mlir/tfr/passes/decompose.cc new file mode 100644 index 00000000000..9265437cca9 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/passes/decompose.cc @@ -0,0 +1,280 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "mlir/Transforms/InliningUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h" +#include "tensorflow/compiler/mlir/tfr/ir/tfr_types.h" +#include "tensorflow/compiler/mlir/tfr/passes/passes.h" +#include "tensorflow/compiler/mlir/tfr/utils/utils.h" + +//===----------------------------------------------------------------------===// +// The pass to decompose unregistered TF ops with the TFR compose function. +// +namespace mlir { +namespace TFR { + +namespace { + +// Decompose the TF ops with the registered composition library. +struct DecomposeTFOpsPass + : public PassWrapper { + + explicit DecomposeTFOpsPass(llvm::Optional external_tfr_module) + : external_tfr_module(external_tfr_module) {} + + void runOnFunction() override; + + private: + // Apply canonicalization, mainly constant folding, on the function. + void ApplyCanonicalization(); + + // Rewrite unregistered TF ops to TFR func call ops. Return failure if all the + // ops are registered or the compose function doesn't exist. + LogicalResult RewriteUnregisteredTFOps(); + + // Inline the TFR func call ops. + LogicalResult InlineTFRFuncCalls(); + + // Optional external symbol table to look up the TFR function. + llvm::Optional external_tfr_module; +}; + +void DecomposeTFOpsPass::ApplyCanonicalization() { + OwningRewritePatternList patterns; + + auto* context = &getContext(); + for (auto* op : context->getRegisteredOperations()) { + op->getCanonicalizationPatterns(patterns, context); + } + populateSCFOpsCanonicalizationPatterns(patterns, context); + + applyPatternsAndFoldGreedily(getFunction(), patterns); +} + +LogicalResult DecomposeTFOpsPass::RewriteUnregisteredTFOps() { + FuncOp func = getFunction(); + SymbolTable table(external_tfr_module.hasValue() + ? *external_tfr_module + : func.getParentOfType()); + OpBuilder builder(func); + bool changed = false; + func.walk([&table, &builder, &changed](Operation* op) { + // Only the un-registered ops requires decomposition. The remaining ones + // either will be constant folded or lowered by the rules defined in the + // bridge. 
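+    // Purely illustrative sketch (op and function names are made up): an
+    // unregistered "tf.MyCustomOp"(%x, %y) {act = "relu"} whose composition
+    // is defined by a tfr.func @tf__my_custom_op is rewritten roughly into
+    //   %x_t = "tfr.cast"(%x) ... -> !tfr.tensor
+    //   %y_t = "tfr.cast"(%y) ... -> !tfr.tensor
+    //   %act = ... constant materialized from the "act" attribute ...
+    //   %r   = tfr.call @tf__my_custom_op(%x_t, %y_t, %act)
+    // and the tfr.call is later inlined by InlineTFRFuncCalls.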
+ if (op->isRegistered()) { + return; + } + + // Find out the compose function + auto compose_func_name = GetComposeFuncName(op->getName().getStringRef()); + auto compose_func = table.lookup(compose_func_name); + if (!compose_func || compose_func.isExternal()) { + // There are no decomposition methods defined for this op, skip. + return; + } + + auto compose_func_type = compose_func.getType(); + builder.setInsertionPoint(op); + TFRTensorType unconstrainted_tensor_type = builder.getType(); + + // Create the new operands. This is mapping the operands from the target + // TF ops to the TFR function arguments. If the TFR function argument is + // a tensor_list, a "tfr.build_list" op is used to concat the available + // TF op operands. If the TFR function argument isn't a tensor/tensor_list, + // a constant is created by using the attribute stored in the TF op or the + // default value in the argument attribute. + llvm::SmallVector new_operands; + for (auto arg : llvm::enumerate(compose_func_type.getInputs())) { + if (auto tensor_type = arg.value().dyn_cast()) { + auto casted = builder.create(op->getLoc(), tensor_type, + op->getOperand(arg.index())); + new_operands.push_back(casted); + } else if (auto list_type = arg.value().dyn_cast()) { + llvm::SmallVector variadic_operands; + for (int i = arg.index(); i < op->getNumOperands(); i++) { + auto casted = builder.create( + op->getLoc(), unconstrainted_tensor_type, op->getOperand(i)); + variadic_operands.push_back(casted); + } + auto build_list_op = builder.create( + op->getLoc(), list_type, variadic_operands); + new_operands.push_back(build_list_op.out()); + } else { + auto attr_name = compose_func.getArgAttrOfType( + arg.index(), kAttrArgumentNameAttr); + auto attribute = op->getAttr(attr_name.getValue()); + if (!attribute) { + attribute = + compose_func.getArgAttr(arg.index(), kAttrArgumentDefaultAttr); + } + Value attr_cst; + // Wrap these special attributes as a special TFR constant, so the SSA + // value has a valid type to be used as TFR function argument. These + // attributes are not expected to be manipulated by the lowering passes. + if (attribute.isa() || attribute.isa() || + attribute.isa() || attribute.isa()) { + TFRAttrType output_type = TFRAttrType::get(builder.getContext()); + attr_cst = + builder.create(op->getLoc(), output_type, attribute); + } else { + attr_cst = builder.create(op->getLoc(), attribute); + } + new_operands.push_back(attr_cst); + } + } + + // Create the TFR call op + auto new_op = builder.create( + op->getLoc(), compose_func_type.getResults(), + builder.getSymbolRefAttr(compose_func.getName()), new_operands); + + // Replace the use of the old op. This is mapping the results from the + // target TF ops to the TFR function returns. If the TFR function return is + // a tensor_list, "tfr.get_element" op is used to extract the required TF + // op result. 
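+    // Hedged sketch of the result mapping (signature is hypothetical): for a
+    // compose function returning (!tfr.tensor, !tfr.tensor_list), the first
+    // TF result is taken directly from the first call result, and every
+    // remaining TF result is extracted from the returned list with a
+    // tfr.get_element op at index 0, 1, ..., before a final tfr.cast back to
+    // the original TF result type.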
+ llvm::SmallVector new_results; + for (auto res : llvm::enumerate(compose_func_type.getResults())) { + if (res.value().dyn_cast()) { + new_results.push_back(new_op.getResult(res.index())); + } else if (auto list_type = res.value().dyn_cast()) { + for (int i = res.index(), j = 0; i < op->getNumResults(); i++, j++) { + auto index = + builder.create(op->getLoc(), builder.getIndexAttr(j)); + auto element_op = builder.create( + op->getLoc(), unconstrainted_tensor_type, + new_op.getResult(res.index()), index.getResult()); + new_results.push_back(element_op.out()); + } + } + } + for (auto res : llvm::zip(op->getResults(), new_results)) { + auto casted = builder.create( + op->getLoc(), std::get<0>(res).getType(), std::get<1>(res)); + std::get<0>(res).replaceAllUsesWith(casted.out()); + } + op->erase(); + changed |= true; + }); + + // If `changed` is false, it is considered as a failure, so the recursive + // rewrite will stop. + return success(changed); +} + +LogicalResult DecomposeTFOpsPass::InlineTFRFuncCalls() { + // The Inliner will automatically use the registered dialect inliner. + InlinerInterface inliner(&getContext()); + FuncOp func = getFunction(); + SymbolTable table(external_tfr_module.hasValue() + ? *external_tfr_module + : func.getParentOfType()); + + // The inliner only inlines the TFR call op. + bool changed = false; + auto walk_result = func.walk([&](CallOp call_op) { + auto callee = table.lookup(call_op.callee()); + if (!callee || callee.isExternal()) return WalkResult::advance(); + if (failed(inlineCall(inliner, + cast(call_op.getOperation()), + cast(callee.getOperation()), + callee.getCallableRegion(), + /**shouldCloneInLinedRegion=*/true))) { + // This failure is usually because the decompose function is not defined. + // This call will be raised to TF ops. + return WalkResult::interrupt(); + } + call_op.erase(); + changed |= true; + return WalkResult::advance(); + }); + + if (walk_result.wasInterrupted()) { + signalPassFailure(); + return failure(); + } + + // If `changed` is false, it is considered as a failure, so the recursive + // rewrite will stop. + return success(changed); +} + +void DecomposeTFOpsPass::runOnFunction() { + // Set a maximum iteration threshold in case there are infinite loops in the + // call stack. + int max_iterators = 10; + do { + // canonicalization + ApplyCanonicalization(); + + // rewrite unregistered tf ops. Failed either because no ops can be + // decomposed or the compose function isn't defined. + auto rewrite_status = RewriteUnregisteredTFOps(); + // inline the tfr call op until there are no tfr.call op can be inlined. + auto inline_status = InlineTFRFuncCalls(); + + if (failed(rewrite_status) && failed(inline_status)) { + break; + } + } while (max_iterators-- >= 0); +} + +} // namespace + +// Creates an instance of the pass to decompose the TF ops. +std::unique_ptr> CreateDecomposeTFOpsPass( + llvm::Optional tfr_module) { + return std::make_unique(tfr_module); +} + +static PassRegistration pass( + "tfr-decompose", + "Decompose TF ops with the registered composition library.", + [] { return CreateDecomposeTFOpsPass(); }); + +} // namespace TFR +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tfr/passes/passes.h b/tensorflow/compiler/mlir/tfr/passes/passes.h new file mode 100644 index 00000000000..5c27d81ace8 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/passes/passes.h @@ -0,0 +1,44 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_PASSES_H_ + +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project + +namespace mlir { +namespace TFR { + +void populateSCFOpsCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context); + +// Decompose ops. +std::unique_ptr> CreateDecomposeTFOpsPass( + llvm::Optional tfr_module = llvm::None); + +// Raise to TF ops. +std::unique_ptr> CreateRaiseToTFOpsPass( + llvm::Optional tfr_module = llvm::None, + bool materialize_derived_attrs = false); + +} // namespace TFR +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_PASSES_H_ diff --git a/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc b/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc new file mode 100644 index 00000000000..f3fe9618c62 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc @@ -0,0 +1,474 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "mlir/Transforms/InliningUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h" +#include "tensorflow/compiler/mlir/tfr/ir/tfr_types.h" +#include "tensorflow/compiler/mlir/tfr/passes/passes.h" +#include "tensorflow/compiler/mlir/tfr/utils/utils.h" + +//===----------------------------------------------------------------------===// +// The pass to rewrite the TFR function call ops by TF ops. The callee of the +// TFR function call defines the signatures of the TF ops. +// +namespace mlir { +namespace TFR { + +namespace { + +// This pattern is to rewrite the "tfr.call" op and the "tfr.cast" ops on the +// operands by a TF op with "tfr.cast" ops on the results. The result type of +// the new TF op is an unranked tensor with element type derived. +class RewriteTFRCallOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + public: + explicit RewriteTFRCallOp(MLIRContext* context, const SymbolTable& table, + bool materialize_derived_attrs) + : OpRewritePattern(context), + symbol_table_(table), + materialize_derived_attrs_(materialize_derived_attrs) {} + + LogicalResult matchAndRewrite(CallOp call_op, + PatternRewriter& rewriter) const override; + + private: + // Derives the attribute values for the attributes attached to the + // `input_tfr_type`. These attributes are only for the element type of the + // inputs, and these type information has been collected in the `input_types`. + // The result is stored in `derived_attrs` as the named attributes. Returns + // failure if the attributes stored in the `input_tfr_type` violates the + // assumptions. 
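+  // For example (values are illustrative): a !tfr.tensor<T> input whose
+  // actual element type is f32 contributes {T: f32}; a !tfr.tensor_list<N, T>
+  // input built from three operands contributes {N: 3, T: <type of the first
+  // element>}; a !tfr.tensor_list<Tlist> input contributes
+  // {Tlist: [<element types>]}.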
+ LogicalResult AddDerivedAttrs( + PatternRewriter& rewriter, Type input_tfr_type, + ArrayRef input_types, + llvm::StringMap* derived_attrs) const; + + // Collects the operands and attributes for the TF op. At the same time, it + // collects all the derived attribute values to derive the output types of the + // TF op. + LogicalResult CollectInputsAndAttributes( + PatternRewriter& rewriter, TFRFuncOp signature, CallOp call_op, + SmallVectorImpl* inputs, NamedAttrList* arg_attrs, + llvm::StringMap* derived_attrs) const; + + // Uses the collected attribute values to derive all the output types. + LogicalResult DeriveOutputTypes(FunctionType signature, + const llvm::StringMap& attrs, + SmallVectorImpl* output_types) const; + + // Creates the TF op and also the necessary tfr.cast ops to replace the + // original TFR call op. + LogicalResult CreateAndReplaceOp( + PatternRewriter& rewriter, CallOp call_op, + const SmallVectorImpl& output_types, + const SmallVectorImpl& inputs, const NamedAttrList& attr_list, + const llvm::StringMap& derived_attrs) const; + + // Adds a tf.Cast op if the tfr.tensor attribute indicated a fixed element + // type. + // TODO(fengliuai): This method is required when the operand types are not set + // by the frontend correctly. + Value CastToNonDerivedType(PatternRewriter& rewriter, Location loc, + CastOp cast_op, Type input_tfr_type) const { + auto tensor_type = input_tfr_type.dyn_cast(); + if (!tensor_type) return cast_op.arg(); + + auto attr_names = tensor_type.getAttrKeys(); + if (attr_names.empty() || attr_names.size() > 1) return cast_op.arg(); + StringRef tfr_type_attr = attr_names[0].getValue(); + if (!fixed_elt_type_attrs_.contains(tfr_type_attr)) return cast_op.arg(); + + Type result_elt_type; + if (tfr_type_attr == "i32_") { + result_elt_type = rewriter.getI32Type(); + } else if (tfr_type_attr == "i64_") { + result_elt_type = rewriter.getI64Type(); + } else if (tfr_type_attr == "f32_") { + result_elt_type = rewriter.getF32Type(); + } else if (tfr_type_attr == "i1_") { + result_elt_type = rewriter.getI1Type(); + } else { + return cast_op.arg(); + } + + Type original_input_type = + cast_op.getInputElementType().cast().getValue(); + if (result_elt_type != original_input_type) { + UnrankedTensorType result_type = UnrankedTensorType::get(result_elt_type); + return rewriter.create(loc, result_type, cast_op.arg()); + } + return cast_op.arg(); + } + + // For variadic operands, we have to enforce them to use the same types. + // TODO(fengliuai): This method is required when the operand types are not set + // by the frontend correctly. 
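+  // A hedged example: if the first variadic operand has element type f32 and
+  // a later operand has element type i32, the later value is wrapped in a
+  // tf.Cast to f32 so that all elements of the N*T pack share the first
+  // element's type.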
+ void CastValuesToSameType(PatternRewriter& rewriter, Location loc, + const llvm::SmallVectorImpl& input_types, + llvm::SmallVectorImpl& input_values) const { + if (input_types.size() <= 1) return; + + Type target_input_type = input_types[0].cast().getValue(); + auto result_type = UnrankedTensorType::get(target_input_type); + for (auto i = 1; i < input_types.size(); ++i) { + Type current_input_type = input_types[i].cast().getValue(); + if (current_input_type != target_input_type) { + input_values[i] = + rewriter.create(loc, result_type, input_values[i]); + } + } + } + + const SymbolTable& symbol_table_; + const bool materialize_derived_attrs_; + const llvm::SmallDenseSet fixed_elt_type_attrs_{"i32_", "i64_", + "f32_", "i1_"}; +}; + +LogicalResult RewriteTFRCallOp::AddDerivedAttrs( + PatternRewriter& rewriter, Type input_tfr_type, + ArrayRef input_types, + llvm::StringMap* derived_attrs) const { + // If there is an attribute associated to the input in the signature, we + // store it as an derived attribute. + if (auto tensor_type = input_tfr_type.dyn_cast()) { + auto attr_names = tensor_type.getAttrKeys(); + if (attr_names.empty()) return success(); + + if (attr_names.size() == 1) { + derived_attrs->insert({attr_names[0].getValue(), input_types[0]}); + return success(); + } + } + + // If there is an attribute associated to the input in the signature, + // we store it as an derived attribute. + if (auto list_type = input_tfr_type.dyn_cast()) { + auto attr_names = list_type.getAttrKeys(); + if (attr_names.empty()) return success(); + + // N*T case + if (attr_names.size() == 2) { + derived_attrs->insert({attr_names[0].getValue(), + rewriter.getI32IntegerAttr(input_types.size())}); + // Note that this uses the first element of the list to infer the T value. + // A tf.Cast is required to cast the other inputs to the same type. + derived_attrs->insert({attr_names[1].getValue(), input_types[0]}); + return success(); + } + + // list(dtype) case + if (attr_names.size() == 1) { + derived_attrs->insert( + {attr_names[0].getValue(), rewriter.getArrayAttr(input_types)}); + return success(); + } + } + + return failure(); +} + +LogicalResult RewriteTFRCallOp::CollectInputsAndAttributes( + PatternRewriter& rewriter, TFRFuncOp signature, CallOp call_op, + SmallVectorImpl* inputs, NamedAttrList* arg_attrs, + llvm::StringMap* derived_attrs) const { + for (const auto& operand : llvm::enumerate(signature.getType().getInputs())) { + // If the index is larger than the operand number of the call_op, the + // default value of the operand needs to be used. + if (operand.index() >= call_op.getNumOperands()) { + auto attr_name = signature.getArgAttrOfType( + operand.index(), kAttrArgumentNameAttr); + auto attr_value = + signature.getArgAttr(operand.index(), kAttrArgumentDefaultAttr); + arg_attrs->push_back( + rewriter.getNamedAttr(attr_name.getValue(), attr_value)); + continue; + } + + // The index is valid for the call_op. + Value input = call_op.getOperand(operand.index()); + Operation* input_op = input.getDefiningOp(); + auto input_tfr_type = signature.getType().getInputs()[operand.index()]; + + // There are three cases for the preceding input_op: + + // 1. The preceding op can be a tfr.cast op, which will be fused to the + // current op, so the result op has input with tensor type. 
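+    // (Hedged example: if the operand is produced by %t = tfr.cast %arg0,
+    // then %arg0 itself becomes an operand of the raised TF op and its
+    // element type is recorded as a derived attribute.)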
+ if (auto cast_op = dyn_cast_or_null(input_op)) { + Value input_to_cast = CastToNonDerivedType(rewriter, call_op.getLoc(), + cast_op, input_tfr_type); + inputs->push_back(input_to_cast); + if (failed(AddDerivedAttrs(rewriter, input_tfr_type, + {cast_op.getInputElementType()}, + derived_attrs))) { + return failure(); + } + continue; + } + + // 2. The preceding op is a tfr.build_list op, which collects multiple + // values with tensor types via the tfr.cast ops. These ops will be fused + // to the current op as well, so all the tfr.cast op inputs will be inputs + // to the result op. + if (auto list_op = dyn_cast_or_null(input_op)) { + // Find out all the inputs to the build list op + // TODO(fengliuai): make build_list op only take tensor argument + llvm::SmallVector list_input_types; + llvm::SmallVector list_inputs; + for (auto list_input : list_op.getOperands()) { + auto cast_op = dyn_cast_or_null(list_input.getDefiningOp()); + if (!cast_op) return failure(); + list_inputs.push_back(cast_op.arg()); + list_input_types.push_back(cast_op.getInputElementType()); + } + CastValuesToSameType(rewriter, call_op.getLoc(), list_input_types, + list_inputs); + inputs->append(list_inputs.begin(), list_inputs.end()); + if (failed(AddDerivedAttrs(rewriter, input_tfr_type, list_input_types, + derived_attrs))) { + return failure(); + } + continue; + } + + // 3. The preceding op is a constant, thus the value of this constant is + // used to create an attribute of the result op, according to the signature. + Attribute arg_value; + // A failure indicates the argument isn't a constant value, so we should + // not use it as an attribute. + if (!matchPattern(input, m_Constant(&arg_value))) { + return failure(); + } + auto attr_name = signature.getArgAttrOfType( + operand.index(), kAttrArgumentNameAttr); + arg_attrs->push_back( + rewriter.getNamedAttr(attr_name.getValue(), arg_value)); + } + return success(); +} + +// For each output, uses the attribute name associated to the tfr types to find +// out the attribute value from the collected `attrs` and create the output type +// of the result op by using the attribute value as the element type. +LogicalResult RewriteTFRCallOp::DeriveOutputTypes( + FunctionType signature, const llvm::StringMap& attrs, + SmallVectorImpl* output_types) const { + for (auto res : llvm::enumerate(signature.getResults())) { + if (auto tensor_type = res.value().dyn_cast()) { + // tfr.tensor should only have one attribute attached. 
+ auto attr_key = tensor_type.getAttrKeys().front(); + output_types->push_back(UnrankedTensorType::get( + attrs.lookup(attr_key.getValue()).cast().getValue())); + continue; + } + + if (auto list_type = res.value().dyn_cast()) { + // There are two cases: N*T or list(dtype) + auto attr_keys = list_type.getAttrKeys(); + // N*T case + if (attr_keys.size() == 2) { + // The first one is N, and the second one is T + int list_size = + attrs.lookup(attr_keys[0].getValue()).cast().getInt(); + Type list_type = + attrs.lookup(attr_keys[1].getValue()).cast().getValue(); + for (int i = 0; i < list_size; ++i) { + output_types->push_back(UnrankedTensorType::get(list_type)); + } + continue; + } + // TODO(fengliuai): list(dtype) case + } + return failure(); + } + return success(); +} + +LogicalResult RewriteTFRCallOp::CreateAndReplaceOp( + PatternRewriter& rewriter, CallOp call_op, + const SmallVectorImpl& output_types, + const SmallVectorImpl& inputs, const NamedAttrList& attr_list, + const llvm::StringMap& derived_attrs) const { + // Create the new op + Location loc = call_op.getLoc(); + rewriter.setInsertionPointAfter(call_op); + std::string tf_op_name = GetTFOpName(call_op.callee()); + OperationState new_state(loc, tf_op_name, inputs, output_types, attr_list); + Operation* new_op = rewriter.createOperation(new_state); + if (materialize_derived_attrs_) { + for (const auto& attr : derived_attrs) { + // Add or update the derived attribute with the value. Skip the fixed + // element type attributes, in case they are present in the NodeDef. + if (!fixed_elt_type_attrs_.contains(attr.first())) { + new_op->setAttr(attr.first(), attr.second); + } + } + } + + // Create the tfr.cast ops on the results and replace the uses of the + // original call op. + TFRTensorType unconstrainted_type = rewriter.getType(); + SmallVector new_results; + for (auto res : llvm::enumerate(call_op.getResultTypes())) { + Type res_type = res.value(); + if (res_type.dyn_cast()) { + Value new_res = new_op->getResult(res.index()); + auto casted = rewriter.create(loc, res_type, new_res); + new_results.push_back(casted.out()); + } else if (auto list_type = res.value().dyn_cast()) { + SmallVector tensor_list; + for (int i = res.index(); i < new_op->getNumResults(); i++) { + Value new_res = new_op->getResult(i); + auto casted = + rewriter.create(loc, unconstrainted_type, new_res); + tensor_list.push_back(casted.out()); + } + auto list_op = rewriter.create(loc, res_type, tensor_list); + new_results.push_back(list_op.out()); + } + } + rewriter.replaceOp(call_op, new_results); + return success(); +} + +LogicalResult RewriteTFRCallOp::matchAndRewrite( + CallOp call_op, PatternRewriter& rewriter) const { + // Get the func op and verify that it is external. The type of this external + // func op is used as the signature of the corresponding TF ops. All the + // external func ops have the trailing underscore. + std::string external_callee_name = call_op.callee().str().append("_"); + TFRFuncOp func = symbol_table_.lookup(external_callee_name); + if (!func || !func.isExternal()) return failure(); + // Get the inputs and attributes. The attributes include these from the + // argument list and also these derived from the inputs. + SmallVector inputs; + NamedAttrList argument_attrs; + llvm::StringMap derived_attrs; + if (failed(CollectInputsAndAttributes(rewriter, func, call_op, &inputs, + &argument_attrs, &derived_attrs))) { + return failure(); + } + + // Derive the output types. 
The result type is derived by using the + // attributes attched to the result type of the signature. The attribute + // value should be either in the attribute argument list or the derived + // attribute from the input tensors. All the result type + // are unranked, and shape inference should be applied afterwards. + SmallVector output_types; + + // Merge the attributes from the argument list to the derived ones. + for (auto& attr : argument_attrs) { + derived_attrs.insert({attr.first, attr.second}); + } + + // Derive the output types by using the attributes attached to the tfr + // types. + if (failed(DeriveOutputTypes(func.getType(), derived_attrs, &output_types))) { + return failure(); + } + + // Create the new op and replace the old TFR call op. + return CreateAndReplaceOp(rewriter, call_op, output_types, inputs, + argument_attrs, derived_attrs); +} + +// Raise TFR call ops to the TF ops. +struct RaiseToTFOpsPass : public PassWrapper { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + + explicit RaiseToTFOpsPass(llvm::Optional tfr_module, + bool materialize_derived_attrs) + : external_tfr_module(tfr_module), + materialize_derived_attrs(materialize_derived_attrs) {} + + void runOnFunction() override; + + private: + llvm::Optional external_tfr_module; + const bool materialize_derived_attrs; +}; + +void RaiseToTFOpsPass::runOnFunction() { + FuncOp func = getFunction(); + MLIRContext* ctx = &getContext(); + SymbolTable table(external_tfr_module.hasValue() + ? *external_tfr_module + : func.getParentOfType()); + + OwningRewritePatternList patterns; + patterns.insert(ctx, table, materialize_derived_attrs); + for (auto* op : ctx->getRegisteredOperations()) { + op->getCanonicalizationPatterns(patterns, ctx); + } + + applyPatternsAndFoldGreedily(func, patterns); +} +} // namespace + +// Creates an instance of the pass to raise TFR call ops to the TF ops. +std::unique_ptr> CreateRaiseToTFOpsPass( + llvm::Optional tfr_module, bool materialize_derived_attrs) { + return std::make_unique(tfr_module, + materialize_derived_attrs); +} + +static PassRegistration pass( + "tfr-raise-to-tf", "Raise all the TFR call ops to TF ops.", + [] { return CreateRaiseToTFOpsPass(); }); + +} // namespace TFR +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tfr/passes/tfr_opt.cc b/tensorflow/compiler/mlir/tfr/passes/tfr_opt.cc new file mode 100644 index 00000000000..8f06f278369 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/passes/tfr_opt.cc @@ -0,0 +1,37 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/InitAllDialects.h" // from @llvm-project +#include "mlir/InitAllPasses.h" // from @llvm-project +#include "mlir/Support/MlirOptMain.h" // from @llvm-project +#include "tensorflow/compiler/mlir/init_mlir.h" +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h" + +int main(int argc, char **argv) { + tensorflow::InitMlir y(&argc, &argv); + + mlir::registerAllPasses(); + + mlir::DialectRegistry registry; + registry.insert(); + return failed(mlir::MlirOptMain(argc, argv, "TFR Pass Driver\n", registry)); +} diff --git a/tensorflow/compiler/mlir/tfr/python/composite.py b/tensorflow/compiler/mlir/tfr/python/composite.py new file mode 100644 index 00000000000..7f558ce2fe7 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/python/composite.py @@ -0,0 +1,56 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Op composition registration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +# TODO(fengliuai): add the tf_export decrator +class Composite(object): + """A decorator to register a function as a composition for an TF operator. + + The argument to the decorator must be the name of a TF raw operator the + function composites for. Decorated function must take positional arguments + which corresponds to the input and attributes in OpDef of the TF operation. + # TODO(fengliuai): more documents here. + + Example: + @composite.Composite('AddN') + def _compose_add_n(inputs, N): + if N == 1: + .... + """ + + # TODO(fengliuai): support input_binding and output_binding so the arguments + # are not positional. + def __init__(self, + op_name, + inputs=None, + attrs=None, + derived_attrs=None, + outputs=None): + self._op_name = op_name + self._inputs = inputs + self._attrs = attrs + self._derived_attrs = derived_attrs + self._outputs = outputs + + def __call__(self, compose_fn): + # TODO(fengliuai): more sanity check of the input function and make sure + # the bounded arguments of the function matches the 'inputs' and 'attrs'. + setattr(compose_fn, '_tfr_op_name', self._op_name) + return compose_fn diff --git a/tensorflow/compiler/mlir/tfr/python/op_reg_gen.py b/tensorflow/compiler/mlir/tfr/python/op_reg_gen.py new file mode 100644 index 00000000000..99b2dfdedc4 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/python/op_reg_gen.py @@ -0,0 +1,147 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""op_reg_gen: Generate op registration code from composite op code.""" + +# pylint: disable=invalid-name +# pylint: disable=missing-function-docstring +# pylint: disable=g-direct-tensorflow-import + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast as ast + +from tensorflow.python.autograph.pyct import transformer +from tensorflow.python.autograph.pyct import transpiler +from tensorflow.python.framework import op_def_registry +from tensorflow.python.util import tf_inspect + +_COMPOSITE_ARG_LIST = ['op_name', 'inputs', 'attrs', 'derived_attrs', 'outputs'] + + +class OpRegGenImpl(transformer.CodeGenerator): + """Visit the AST and generate C++ op registration functions.""" + + def __init__(self, ctx): + super(OpRegGenImpl, self).__init__(ctx) + self.ctx = ctx + + def visit_Name(self, node): + return node.id + + def visit_Constant(self, node): + return node.value + + def visit_keyword(self, node): + return node.arg, self.visit(node.value) + + def visit_List(self, node): + return [self.visit(cst) for cst in node.elts] + + def visit_arguments(self, node): + return [self.visit(arg) for arg in node.args] + + def visit_FunctionDef(self, node): + # TODO(fengliuai): create one utility method to match different apis and + # shared it with the tfr_gen.py module. + compose_dec = [] + for dec in node.decorator_list: + if isinstance(dec, ast.Call): + if isinstance(dec.func, ast.Attribute) and dec.func.attr == 'Composite': + compose_dec.append(dec) + if isinstance(dec.func, ast.Name) and dec.func.id == 'Composite': + compose_dec.append(dec) + + if not compose_dec: + # skip a non-composition function + return + elif len(compose_dec) > 1: + raise KeyError('More than one TF ops decomposes for.') + + all_dec_args = {} + for arg_name, arg_value in zip(_COMPOSITE_ARG_LIST, compose_dec[0].args): + all_dec_args[arg_name] = self.visit(arg_value) + + kw_dec_args = dict([self.visit(kw) for kw in compose_dec[0].keywords]) + + if all_dec_args.keys() & kw_dec_args.keys(): + raise KeyError('More arguments than expected.') + + all_dec_args.update(kw_dec_args) + + op_name = all_dec_args['op_name'] + op_def = op_def_registry.get(op_name) + if op_def: + if len(all_dec_args) > 1: + # Op has been registered, so it is a user error to specify op def. + raise ValueError('Op has been registered: ' + op_name) + else: + # Op has been registered, then we don't need to generate register code. + return + + # Validates the function inputs match what are in the decorator. 
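+    # Purely illustrative example (hypothetical op): a decorator such as
+    #   @Composite('MyPad', inputs=['x: T'], attrs=['mode: string'],
+    #              derived_attrs=['T: numbertype'], outputs=['y: T'])
+    # leads the code below to emit roughly
+    #   REGISTER_OP("MyPad")
+    #       .Input("x: T")
+    #       .Attr("mode: string")
+    #       .Attr("T: numbertype")
+    #       .Output("y: T");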
+ inputs = all_dec_args.get('inputs', []) + attrs = all_dec_args.get('attrs', []) + expected_args = [arg.split(':')[0] for arg in inputs + attrs] + all_func_args = self.visit(node.args) + + if len(expected_args) != len(all_func_args): + raise KeyError('Composition arguments do not match the registration.') + + cxx_reg_code = '\nREGISTER_OP("{0}")'.format(op_name) + for input_ in inputs: + cxx_reg_code += '\n .Input("{0}")'.format(input_) + for attr in attrs: + py_str = attr.replace('"', '\'') + cxx_reg_code += '\n .Attr("{0}")'.format(py_str) + for attr in all_dec_args.get('derived_attrs', []): + py_str = attr.replace('"', '\'') + cxx_reg_code += '\n .Attr("{0}")'.format(py_str) + for output_ in all_dec_args.get('outputs', []): + cxx_reg_code += '\n .Output("{0}")'.format(output_) + cxx_reg_code += ';\n' + self.emit(cxx_reg_code) + + +class OpRegGen(transpiler.GenericTranspiler): + """Transforms Python objects into TFR MLIR source code.""" + + def transform_ast(self, node, ctx): + gen = OpRegGenImpl(ctx) + gen.visit(node) + return gen.code_buffer + + +def op_reg_gen(func): + """Parse a function and emit the TFR functions.""" + op_reg_code, _ = OpRegGen().transform(func, None) + return op_reg_code + + +def gen_register_op(source, method_prefix=None): + """Parse a python code and emit the TFR functions from a target class.""" + mlir_funcs = [ + op_reg_gen(func) + for name, func in tf_inspect.getmembers(source, tf_inspect.isfunction) + if not method_prefix or name.startswith(method_prefix) + ] + headers = r""" +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + """ + code = '\n'.join(mlir_funcs) + return headers + code + '} // namespace tensorflow\n' diff --git a/tensorflow/compiler/mlir/tfr/python/op_reg_gen_test.py b/tensorflow/compiler/mlir/tfr/python/op_reg_gen_test.py new file mode 100644 index 00000000000..6392015ba4d --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/python/op_reg_gen_test.py @@ -0,0 +1,81 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for `op_reg_gen` module.""" + +# pylint: disable=missing-function-docstring +# pylint: disable=invalid-name +# pylint: disable=g-direct-tensorflow-import + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +from tensorflow.compiler.mlir.python.mlir_wrapper import filecheck_wrapper as fw +from tensorflow.compiler.mlir.tfr.python import composite +from tensorflow.compiler.mlir.tfr.python.op_reg_gen import gen_register_op +from tensorflow.python.platform import test + + +Composite = composite.Composite + + +@composite.Composite( + 'TestNoOp', derived_attrs=['T: numbertype'], outputs=['o1: T']) +def _composite_no_op(): + pass + + +@Composite( + 'TestCompositeOp', + inputs=['x: T', 'y: T'], + attrs=['act: {"", "relu"}', 'trans: bool = true'], + derived_attrs=['T: numbertype'], + outputs=['o1: T', 'o2: T']) +def _composite_op(x, y, act, trans): + return x + act, y + trans + + +class TFRGenTensorTest(test.TestCase): + """MLIR Generation Tests for MLIR TFR Program.""" + + def test_op_reg_gen(self): + cxx_code = gen_register_op(sys.modules[__name__]) + cxx_code_exp = r""" + CHECK-NEXT: #include "third_party/tensorflow/core/framework/op.h" + CHECK-EMPTY + CHECK-LABEL: namespace tensorflow { + CHECK-EMPTY + CHECK-LABEL: REGISTER_OP("TestNoOp") + CHECK-NEXT: .Attr("T: numbertype") + CHECK-NEXT: .Output("o1: T"); + CHECK-EMPTY + CHECK-LABEL: REGISTER_OP("TestCompositeOp") + CHECK-NEXT: .Input("x: T") + CHECK-NEXT: .Input("y: T") + CHECK-NEXT: .Attr("act: {'', 'relu'}") + CHECK-NEXT: .Attr("trans: bool = true") + CHECK-NEXT: .Attr("T: numbertype") + CHECK-NEXT: .Output("o1: T") + CHECK-NEXT: .Output("o2: T"); + CHECK-EMPTY + CHECK-LABEL: } // namespace tensorflow + """ + self.assertTrue(fw.check(str(cxx_code), cxx_code_exp), str(cxx_code)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/compiler/mlir/tfr/python/test_utils.py b/tensorflow/compiler/mlir/tfr/python/test_utils.py new file mode 100644 index 00000000000..62aa3e39105 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/python/test_utils.py @@ -0,0 +1,48 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test utils for composite op definition.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.eager import backprop +from tensorflow.python.platform import test + + +class OpsDefsTest(test.TestCase): + """Test utils.""" + + def _assertOpAndComposite(self, vars_, compute_op, compute_composite, kwargs, + op_kwargs=None): + if op_kwargs is None: + op_kwargs = kwargs + + # compute with op. + with backprop.GradientTape() as gt: + for var_ in vars_: + gt.watch(var_) + y = compute_op(**op_kwargs) # uses op and decomposites by the graph pass. + grads = gt.gradient(y, vars_) # uses registered gradient function. 
+ + # compute with composition + with backprop.GradientTape() as gt: + for var_ in vars_: + gt.watch(var_) + re_y = compute_composite(**kwargs) # uses composite function. + re_grads = gt.gradient(re_y, vars_) # uses gradients compposite function. + + for v, re_v in zip(y, re_y): + self.assertAllClose(v, re_v) + for g, re_g in zip(grads, re_grads): + self.assertAllClose(g, re_g) diff --git a/tensorflow/compiler/mlir/tfr/python/tfr_gen.py b/tensorflow/compiler/mlir/tfr/python/tfr_gen.py new file mode 100644 index 00000000000..3bf89c7a2d5 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/python/tfr_gen.py @@ -0,0 +1,1377 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""tfr_gen: Generate mlir tfr decomposition function from python code.""" + +# pylint: disable=invalid-name +# pylint: disable=missing-function-docstring +# pylint: disable=g-direct-tensorflow-import + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import enum +import os +import re +import types +from typing import List, Tuple +import gast as ast + +from tensorflow.compiler.mlir.tfr import tfr_wrapper as tfr +from tensorflow.core.framework import types_pb2 +from tensorflow.python.autograph.converters import control_flow +from tensorflow.python.autograph.converters import return_statements +from tensorflow.python.autograph.impl import api +from tensorflow.python.autograph.pyct import anno +from tensorflow.python.autograph.pyct import cfg +from tensorflow.python.autograph.pyct import qual_names +from tensorflow.python.autograph.pyct import transformer +from tensorflow.python.autograph.pyct import transpiler +from tensorflow.python.autograph.pyct.static_analysis import activity +from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions +from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs +from tensorflow.python.autograph.pyct.static_analysis import type_inference +from tensorflow.python.framework import load_library +from tensorflow.python.framework import op_def_registry +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import tf_inspect + + +class TFRTypes(enum.Enum): + """All the supported types. + + 1-3: tfr types + 4-99: mlir built-in types + 100-199: TF related translator internal types + 200- : Python related translator internal types + """ + TENSOR = 1 + TENSOR_LIST = 2 + ATTR = 3 + NONE = 4 + SHAPE = 5 # shape -> !shape.shape + I1 = 21 + I32 = 22 + I64 = 23 + F32 = 24 + INDEX = 25 + AG_UNDEFINED_VAL = 100 + AG_BUILTIN_FUNC = 101 + TF_RAW_OP = 102 + TF_REGION = 103 + TF_TENSOR_SHAPE_FUNC = 104 # shape.as_list + TF_TENSOR_SHAPE_LIST = 105 # shape.as_list() + PY_BUILTIN_FUNC = 200 + + # As these are not real types, __getattribute__ helps them appear more like + # actual types (i.e. class definitions). 
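+  # For example, `TFRTypes.TENSOR.shape` resolves to TFRTypes.SHAPE and
+  # `TFRTypes.SHAPE.as_list` resolves to TFRTypes.TF_TENSOR_SHAPE_FUNC, which
+  # lets the type inference follow attribute chains like `x.shape.as_list()`.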
+ def __getattribute__(self, name): + if name == 'shape' and object.__getattribute__(self, 'value') == 1: + return TFRTypes.SHAPE + if name == 'as_list' and object.__getattribute__(self, 'value') == 5: + return TFRTypes.TF_TENSOR_SHAPE_FUNC + return object.__getattribute__(self, name) + + def __str__(self): + if self.value < 4: # pylint: disable=comparison-with-callable + return '!tfr.' + self.name.lower() + elif self.value < 10: # pylint: disable=comparison-with-callable + return '!shape.' + self.name.lower() + else: + return self.name.lower() + + +_attribute_types = [ + TFRTypes.I1, TFRTypes.I32, TFRTypes.I64, TFRTypes.F32, TFRTypes.INDEX, + TFRTypes.ATTR +] + + +def _get_type_from_proto(arg_def=None, attr_def=None): + if not arg_def: + if attr_def.type == 'bool': + return TFRTypes.I1 + elif attr_def.type == 'int32': + return TFRTypes.I32 + elif attr_def.type == 'int' or attr_def.type == 'int64': + return TFRTypes.I64 + elif attr_def.type == 'float': + return TFRTypes.F32 + else: + return TFRTypes.ATTR + + if arg_def.number_attr or arg_def.type_list_attr: + return TFRTypes.TENSOR_LIST + else: + return TFRTypes.TENSOR + + +def _get_type_info_from_proto(arg_def=None, attr_def=None): + attr_type = _get_type_from_proto(arg_def, attr_def) + if not arg_def: + return '{}{{tfr.name="{}"}}'.format(attr_type, attr_def.name) + else: + attr_names = [] + if arg_def.number_attr: + attr_names.append(arg_def.number_attr) + if arg_def.type_attr: + attr_names.append(arg_def.type_attr) + if arg_def.type_list_attr: + attr_names.append(arg_def.type_list_attr) + + # TODO(fengliuai): currently we don't support backward type inference, so we + # have to store these non-derivable type in the signatures, and then they + # can be used to cast the values when raising to tf ops. + if arg_def.type == types_pb2.DT_FLOAT: + attr_names.append('f32_') + elif arg_def.type == types_pb2.DT_INT32: + attr_names.append('i32_') + elif arg_def.type == types_pb2.DT_INT64: + attr_names.append('i64_') + elif arg_def.type == types_pb2.DT_BOOL: + attr_names.append('i1_') + + if not attr_names: + return str(attr_type) + else: + return '{}<{}>'.format(attr_type, ','.join(attr_names)) + + +def _get_val_from_proto(attr_type, attr_val): + if attr_type == TFRTypes.I1: + return 'true' if attr_val.b else 'false' + elif attr_type == TFRTypes.I32 or attr_type == TFRTypes.I64: + return attr_val.i + elif attr_type == TFRTypes.F32: + return attr_val.f + elif attr_type == TFRTypes.ATTR: + # string + if attr_val.HasField('s'): + return '"{}"'.format(attr_val.s.decode()) + # type + if attr_val.HasField('type'): + if attr_val.type == types_pb2.DT_FLOAT: + return 'f32' + elif attr_val.type == types_pb2.DT_INT32: + return 'i32' + elif attr_val.type == types_pb2.DT_INT64: + return 'i64' + elif attr_val.type == types_pb2.DT_BOOL: + return 'i1' + # list + if attr_val.HasField('list'): + if attr_val.list.f: + elt_ty = TFRTypes.F32 + values = attr_val.list.f + elif attr_val.list.i: + elt_ty = TFRTypes.I64 + values = attr_val.list.i + else: + elt_ty = TFRTypes.NONE + values = [] + array_attr_elts = ['{}:{}'.format(val, elt_ty) for val in values] + return '[{}]'.format(','.join(array_attr_elts)) + raise NotImplementedError( + 'Proto AttrValue not recoganized. 
type: {}, value: {}'.format( + attr_type, attr_val)) + + +def _collect_derived_attrs_from_proto(op_def): + derived_attrs = set() + for arg in op_def.input_arg: + if arg.type_attr: + derived_attrs.add(arg.type_attr) + if arg.number_attr: + derived_attrs.add(arg.number_attr) + if arg.type_list_attr: + derived_attrs.add(arg.type_list_attr) + + # TODO(fengliuai): currently we don't support backward type inference, so we + # have to store these non-derivable type in the signatures, and then they + # can be used to cast the values when raising to tf ops. + if arg.type == types_pb2.DT_FLOAT: + derived_attrs.add('f32_') + elif arg.type == types_pb2.DT_INT32: + derived_attrs.add('i32_') + elif arg.type == types_pb2.DT_INT64: + derived_attrs.add('i64_') + elif arg.type == types_pb2.DT_BOOL: + derived_attrs.add('i1_') + return derived_attrs + + +def _require_tensor_list(arg_def): + return arg_def.type_list_attr or arg_def.number_attr + + +def _camel_to_snake(name): + s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) + return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower() + + +class OpDefCache(object): + """A Dict to cache the OpDef for the Python function name.""" + + def __init__(self): + self._op_defs = {} + + def lookup(self, f_name, func_def=None, optional=False): + if f_name in self._op_defs.keys(): + return self._op_defs[f_name] + + if isinstance(func_def, types.FunctionType): + if not hasattr(func_def, '_tfr_op_name'): + # skip a non-composition function + if optional: + return (None, None) + else: + raise KeyError('OpDef does not exist: ' + f_name) + op_name = getattr(func_def, '_tfr_op_name') + elif not func_def: + op_name = f_name + else: + # TODO(fengliuai): create one utility method to match different apis. + compose_dec = [] + for dec in func_def.decorator_list: + if isinstance(dec, ast.Call): + if isinstance(dec.func, + ast.Attribute) and dec.func.attr == 'Composite': + compose_dec.append(dec) + if isinstance(dec.func, ast.Name) and dec.func.id == 'Composite': + compose_dec.append(dec) + + if not compose_dec: + # skip a non-composition function + if optional: + return (None, None) + else: + raise KeyError('OpDef does not exist: ' + f_name) + elif len(compose_dec) > 1: + raise KeyError('More than one TF ops decomposes for.') + else: + op_name = compose_dec[0].args[0].value + + op_def = op_def_registry.get(op_name) + if not op_def: + raise ValueError('Not a registered op: ' + op_name) + derived_attrs = _collect_derived_attrs_from_proto(op_def) + self._op_defs[f_name] = (op_def, derived_attrs) + return (op_def, derived_attrs) + + def mlir_external_funcs(self): + tfr_funcs = [] + for _, (op_def, derived_attrs) in sorted(self._op_defs.items()): + tfr_func = '\ntfr.func @tf__{}_('.format(_camel_to_snake(op_def.name)) + + # tensor inputs + inputs = [ + _get_type_info_from_proto(arg_def) for arg_def in op_def.input_arg + ] + + # attribute inputs. The attribute with default values are moved backwards. 
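+      # Illustrative sketch (hypothetical op): one tensor input typed by the
+      # derived attr T, a required attr N, and an attr act with a default
+      # value produce a signature roughly of the form
+      #   tfr.func @tf__my_op_(!tfr.tensor<T>, i64{tfr.name="N"},
+      #       !tfr.attr{tfr.name="act"}) -> (!tfr.tensor<T>) attributes {N,T,act}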
+ non_derived_attrs = [ + attr for attr in op_def.attr if attr.name not in derived_attrs + ] + attrs_no_default = [ + attr for attr in non_derived_attrs + if not attr.HasField('default_value') + ] + attrs_with_default = [ + attr for attr in non_derived_attrs if attr.HasField('default_value') + ] + attr_names = set() + for attr_def in attrs_no_default + attrs_with_default: + inputs.append(_get_type_info_from_proto(None, attr_def)) + attr_names.add(attr_def.name) + + # tensor outputs + outputs = [ + _get_type_info_from_proto(arg_def) for arg_def in op_def.output_arg + ] + + inputs = ','.join(inputs) + outputs = ','.join(outputs) + attrs = ','.join(sorted(derived_attrs.union(attr_names))) + tfr_funcs.append('{}{}) -> ({}) attributes {{{}}}'.format( + tfr_func, inputs, outputs, attrs)) + return tfr_funcs + + +_PY_TYPE_TO_TFR = { + bool: TFRTypes.I1, + int: TFRTypes.I64, + float: TFRTypes.F32, +} + +_AG_FIXED_RETURN_TYPE = { + 'for_stmt': type(None), + 'if_stmt': type(None), + 'Undefined': TFRTypes.AG_UNDEFINED_VAL, +} + +QN = qual_names.QN + +# TODO(mdan): Fix this with an importable module. +AG_MODULE = api._TRANSPILER._extra_locals['ag__'] # pylint:disable=protected-access + + +class TFRTypeResolver(type_inference.Resolver): + """Resolve types for the external names, calls and arguments.""" + + def __init__(self, op_defs): + super(TFRTypeResolver, self).__init__() + self._op_defs = op_defs + + # This pattern matching mechanism works with the functional form generated + # by autograph: + # + # for i in data: + # print(i) + # + # generates: + # + # def loop_body(itr): + # i = itr + # print(i) + # ag__.for_stmt(target) + # + # The mechanism lets us infer the type of the itr argument based on that of + # target. + self._for_loop_target_types = {} # Maps body function name to iterated. + self._for_loop_body_fns = {} # Used only to avoid collisions. + + def res_name(self, ns, types_ns, name): + name_str = str(name) + if name_str in ns: + ns_val = ns[name_str] + return {type(ns_val)}, ns_val + if name_str in __builtins__: + return {TFRTypes.PY_BUILTIN_FUNC}, __builtins__[name_str] + # This name is not in the namespace because the autograph transformation + # is not backloaded into Python. + if name_str == 'ag__': + return {type(AG_MODULE)}, AG_MODULE + + return None, None + + def res_value(self, ns, value): + if value is None: + return {TFRTypes.NONE} + if value in (TFRTypes.SHAPE, TFRTypes.TF_TENSOR_SHAPE_FUNC): + # See TFRTypes.__getattrbute__. + # TODO(mdan): Replacing the enum with classes would avoid this overlap. + return {value} + # TODO(mdan): Index more efficiently. Could do a name check instead. + if any(v is value for v in AG_MODULE.__dict__.values()): + return {TFRTypes.AG_BUILTIN_FUNC} + if getattr(value, '__name__', None) == 'tensorflow.raw_ops': + return {types.ModuleType} + if hasattr(value, '__module__'): + # All the imported operations, which are not autograph built-ins, are + # considered to be TF raw ops. + # TODO(fengliuai): refine the condition so we only matche tensorflow + # ops here. + return {TFRTypes.TF_RAW_OP} + # TODO(mdan): Is ATTR equivalent to string? + return {_PY_TYPE_TO_TFR.get(type(value), TFRTypes.ATTR)} + + def res_call(self, ns, types_ns, node, f_type, args, keywords): + name = anno.Basic.QN.of(node.func) + if f_type == (TFRTypes.AG_BUILTIN_FUNC,): + + if name == QN(QN('ag__'), attr='if_stmt'): + nouts = node.args[6].value + # TODO(mdan): Look at the actual types out of if_body. 
+ side_effects = { + qual_names.QN(n.value): {TFRTypes.TENSOR} + for n in node.args[5].elts[:nouts] + } + return {type(None)}, side_effects + + if name == QN(QN('ag__'), attr='for_stmt'): + assert isinstance(node.args[2], ast.Name) + body_fn_name = str(anno.Basic.QN.of(node.args[2])) + assert body_fn_name not in self._for_loop_body_fns, ( + 'Previously used here: {}. Are you reusing the Resolver across ' + 'transformations?').format(self._for_loop_body_fns[body_fn_name]) + self._for_loop_body_fns[body_fn_name] = anno.Basic.ORIGIN.of(node) + + iterated_type = args[0] + assert iterated_type & { + TFRTypes.TENSOR_LIST, TFRTypes.TENSOR, List[int] + }, ( + iterated_type) + self._for_loop_target_types[body_fn_name] = iterated_type + + return {type(None)}, None + + # TODO(mdan): Actually resolve the type here instead. + ret_type = _AG_FIXED_RETURN_TYPE.get(name.qn[1], None) + if ret_type is not None: + return {ret_type}, None + raise NotImplementedError('return type of {}'.format(name)) + + elif f_type == (TFRTypes.TF_RAW_OP,): + op_name = name.qn[1] + op_def, _ = self._op_defs.lookup(op_name) + if len(op_def.output_arg) == 1: + return {_get_type_from_proto(op_def.output_arg[0])}, None + return ({tuple(_get_type_from_proto(arg) for arg in op_def.output_arg)}, + None) + + elif f_type == (TFRTypes.PY_BUILTIN_FUNC,): + assert name.is_simple() + if name == QN('range'): + return {List[int]}, None + + if name == QN('len'): + return {TFRTypes.INDEX}, None + + elif f_type == (TFRTypes.TF_TENSOR_SHAPE_FUNC,): + return {TFRTypes.TF_TENSOR_SHAPE_LIST}, None + + raise NotImplementedError('Function:', name, f_type) + + def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local): + if f_is_local: + f_name_str = str(f_name) + if f_name_str in self._for_loop_target_types: + # See autograph/converters/control_flow.py - the function has a single + # argument, the iterate before any expansion. + assert self._for_loop_target_types[f_name_str] & {List[int]} + # Assume all loops are TF loops. Then the iterates are autoboxed into + # Tensors. + return {TFRTypes.INDEX} + else: + return None + + func = ns[f_name] + + op_def, derived_attrs = self._op_defs.lookup(f_name, func) + if op_def is None: + return None + pos = tf_inspect.getfullargspec(func).args.index(str(name)) + + if pos < len(op_def.input_arg): + arg_def = op_def.input_arg[pos] + return {_get_type_from_proto(arg_def)} + elif pos < len(op_def.input_arg) + len(op_def.attr) - len(derived_attrs): + non_derived_attr_pos = pos - len(op_def.input_arg) + for attr_def in op_def.attr: + # derived attribute, skip this one and continue to the next one. + if attr_def.name in derived_attrs: + continue + if non_derived_attr_pos == 0: + return {_get_type_from_proto(None, attr_def)} + non_derived_attr_pos -= 1 + + raise ValueError('Argument is not defined in OpDef: ' + str(name)) + + def res_subscript(self, ns, types_ns, node_or_slice, value, slice_): + assert len(value) == 1 + value, = tuple(value) + if value == TFRTypes.TF_TENSOR_SHAPE_LIST: + # TODO(mdan): This is not entirely correct for multi-element slices. + return {int} + elif value in (TFRTypes.TENSOR_LIST, TFRTypes.TENSOR): + # TODO(mdan): This is not entirely correct for multi-element slices. 
+ return {TFRTypes.TENSOR} + raise NotImplementedError('slice of {}'.format(value)) + + def res_compare(self, ns, types_ns, node, left, right): + # TODO(fengliuai): make sure left and right are compatible + return {TFRTypes.I1} + + def res_binop(self, ns, types_ns, node, left, right): + # TODO(fengliuai): make sure left and right are compatible + return left + + +class SymbolTable(object): + """Symbol Table for python code.""" + + def __init__(self): + self.symbols = [] + self.enter_scope() + self.scf_scope = 0 + # reserved key words + self.insert_symbol('len', 'len', TFRTypes.PY_BUILTIN_FUNC) + + def enter_scope(self, scf_scope=False): + """Enter a new scope - at function level.""" + self.symbols.append({'types': {}, 'symbols': {}}) + self.curr_table = self.symbols[len(self.symbols) - 1] + if scf_scope: + self.scf_scope += 1 + + def insert_symbol(self, name, value, type_): + self.curr_table['symbols'][name] = (value, type_) + # TODO(mdan): Use the inferred type rather than tracking it here. + # The following field is decrepcated. + self.curr_table['types'][name] = type_ + return value + + def exit_scope(self): + self.symbols.pop() + self.curr_table = self.symbols[len(self.symbols) - 1] + if self.scf_scope > 0: + self.scf_scope -= 1 + + def in_scf_scope(self): + return self.scf_scope > 0 + + def lookup(self, name): + curr_idx = len(self.symbols) - 1 + while curr_idx >= 0 and (name not in self.symbols[curr_idx]['symbols']): + curr_idx -= 1 + if curr_idx < 0: + return None + return self.symbols[curr_idx]['symbols'][name] + + +class TFRGen(transformer.CodeGenerator): + """Visit the AST and generate MLIR TFR functions.""" + + def __init__(self, ctx, op_defs): + super(TFRGen, self).__init__(ctx) + self.ctx = ctx + self.symbol_table = SymbolTable() + self._op_defs = op_defs + + def _create_mlir_loc(self, loc): + """Creates mlir location from autograph ORIGIN value. + + Args: + loc: OriginInfo + + Returns: + A serialized mlir location string. + """ + if loc is not None and loc.loc.filename: + file_name = os.path.basename(loc.loc.filename) + return 'loc("{}":{}:{})'.format(file_name, loc.loc.lineno, + loc.loc.col_offset) + else: + return 'loc(unknown)' + + def _emit_with_loc(self, op_str, node=None): + """Emit the mlir operation with the location associated with the node. + + Args: + op_str: The mlir operation string to be emitted. + node: The node of the AST tree, the mlir operation translated from. + """ + loc = '' + if node: + loc = self._create_mlir_loc( + anno.getanno(node, anno.Basic.ORIGIN, default=None)) + self.emit(op_str + ' ' + loc) + + def _get_inferred_type(self, node, default=None): + types_ = anno.getanno(node, anno.Static.TYPES, None) + if not types_: + print('WARN: no Static.TYPES annotation. Fix the type inference pass: ') + self.debug_print(node) + return default + if types_ and len(types_) > 1: + raise ValueError('ambiguous inferred type for "{}": {}'.format( + node, types_)) + + type_, = types_ + # TODO(fengliuai): Tuple is added here to make return tuple work. + if type_ is list or type_ is Tuple: + # TODO(fengliuai): Seems like we need to move the followed list handling + # to the type inference and we shouldn't just put 'list' there. Otherwise + # we couldn't find out the right type for the Name node. 
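+      # Illustrative examples: a list such as `[x, x]` whose elements are
+      # tensors is resolved to a tfr.tensor_list below, while a literal list
+      # like `[0, 2, -1]` resolves to an attribute.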
+ if not isinstance(node, ast.List): + return default + all_types = [ + anno.getanno(elt, anno.Static.TYPES, None) for elt in node.elts + ] + if (TFRTypes.TENSOR,) in all_types: + # For the elt which is not tfr.tensor, tfr.constant_tensor needs to be + # use to cast it to a tfr.tensor. + return TFRTypes.TENSOR_LIST + else: + return TFRTypes.ATTR + + if default is not None and type_ != default: + print('WARN: type annotation {}({}) does not match {}({})'.format( + type_, type(type_), default, type(default))) + self.debug_print(node) + + return type_ + + def _pack_tensor_list(self, value): + # This is packing a list of tensors, then the axis is 0. + axis = self._ssa_name('zero') + self._emit_with_loc('\n{} = constant 0 : i64'.format(axis)) + casted = self._ssa_name('pack') + self.emit('\n{} = tfr.call @tf__pack({}, {})'.format(casted, value, axis)) + self._emit_with_loc(' : (!tfr.tensor_list, i64) -> !tfr.tensor') + # load the op def of tf.Pack + self._op_defs.lookup('Pack') + return casted, TFRTypes.TENSOR + + def _index_to_I64(self, value, ty): + if ty == TFRTypes.INDEX: + casted = self._ssa_name('casted') + self._emit_with_loc('\n{} = index_cast {} : index to i64'.format( + casted, value)) + return casted, TFRTypes.I64 + else: + return value, ty + + def _value_to_tensor(self, value, ty, node): + value, ty = self._index_to_I64(value, ty) + cst_tensor = self._ssa_name('cst') + self.emit('\n{} = "tfr.constant_tensor"({})'.format(cst_tensor, value)) + self._emit_with_loc(' : ({}) -> !tfr.tensor'.format(ty), node) + return cst_tensor, TFRTypes.TENSOR + + def _ssa_name(self, prefix): + if isinstance(prefix, qual_names.QN): + assert prefix.is_simple(), 'ANF transform should have cleaned this up' + prefix = prefix.ssf() + return '%' + self.ctx.namer.new_symbol(prefix, set()) + + def _op_def(self, op_name): + return op_def_registry.get(op_name) + + def visit_block(self, block): + return [self.visit(item) for item in block] + + def visit_Pass(self, node): + if self.symbol_table.in_scf_scope(): + self._emit_with_loc('\nscf.yield', node) + else: + self._emit_with_loc('\ntfr.return', node) + + def visit_Attribute(self, node): + node_type = self._get_inferred_type(node, None) + if isinstance(node.value, ast.Name): + if node.value.id == 'ag__': + # some variables are assigned with 'ag__.xxx' method, we should handle + # them following the autograph convensions. 
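+        # (e.g. `ag__.if_stmt`, `ag__.for_stmt` and `ag__.Undefined`; the
+        # actual dispatch on these names happens in visit_Call below.)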
+ return (node.attr, TFRTypes.AG_BUILTIN_FUNC) + + if node_type == TFRTypes.TF_RAW_OP: + # This branch is used when it is inside tensorflow + return (node.attr, TFRTypes.TF_RAW_OP) + + value, _ = self.visit(node.value) + tensor_type = self._get_inferred_type(node.value, None) + # TODO(fengliuai): use node_type once it + if node_type == TFRTypes.SHAPE: + print('TODO: use "node_type"') + if node.attr == 'shape' and tensor_type == TFRTypes.TENSOR: + ssa_value = self._ssa_name('shape') + self._emit_with_loc( + '\n{} = tfr.get_shape {} -> !shape.shape'.format(ssa_value, value), + node) + return (ssa_value, TFRTypes.SHAPE) + + if isinstance(node.value, ast.Attribute): + if isinstance(node.value.value, ast.Name): + if node.value.value.id == 'tf' and node.value.attr == 'raw_ops': + # This branch is used when it is outside tensorflow + return (node.attr, TFRTypes.TF_RAW_OP) + + value, ty = self.visit(node.value) + # TODO(fengliuai): use node_type once it + if node_type == TFRTypes.TF_TENSOR_SHAPE_FUNC: + print('TODO: use "node_type"') + if ty == TFRTypes.SHAPE and node.attr == 'as_list': + return (value, TFRTypes.TF_TENSOR_SHAPE_FUNC) + + raise NotImplementedError('Attribute kind not recoganized.') + + def visit_Assign(self, node): + values = self.visit(node.value) + if isinstance(node.targets[0], ast.Tuple): + targets = [elt.id for elt in node.targets[0].elts] + elif isinstance(node.targets[0], ast.Name): + targets = [node.targets[0].id] + else: + raise NotImplementedError('Assignment target type not recoganized.') + + if isinstance(values, list): + if len(targets) == len(values): + for key, value in zip(targets, values): + ssa_value, ty_ = value + ty = self._get_inferred_type(node.value, ty_) + self.symbol_table.insert_symbol(key, ssa_value, ty) + elif len(values) == 1: + n, ty = values[0] + assert ty == TFRTypes.TENSOR_LIST + # assign a tensor_list to multiple variables + for idx, key in enumerate(targets): + idx_name = self._ssa_name('idx') + self._emit_with_loc( + '\n{} = constant {} : index'.format(idx_name, idx), node) + elt_name = self._ssa_name('elt') + self.emit('\n{} = tfr.get_element {}[{}]'.format( + elt_name, n, idx_name)) + self._emit_with_loc(' : (!tfr.tensor_list, index) -> !tfr.tensor', + node) + self.symbol_table.insert_symbol(key, elt_name, TFRTypes.TENSOR) + elif len(targets) == 1: + ssa_names = [n for n, _ in values] + tys = [t for _, t in values] + self.symbol_table.insert_symbol(targets[0], ssa_names, tys) + else: + self.symbol_table.insert_symbol(targets[0], values[0], values[1]) + + def _emit_binary_op(self, op, lhs, lhs_ty, rhs, rhs_ty): + assert lhs_ty, rhs_ty + if isinstance(op, ast.Sub): + code = 'sub' + elif isinstance(op, ast.Add): + code = 'add' + else: + raise NotImplementedError('BinOp operator not recognized' + op) + + if lhs_ty == TFRTypes.I64: + suffix = 'i' + elif lhs_ty == TFRTypes.F32: + suffix = 'f' + else: + raise NotImplementedError('BinOp operand type not recognized' + op) + + ret = self._ssa_name(code) + self._emit_with_loc( + '\n{} = {}{} {}, {} : {}'.format(ret, code, suffix, lhs, rhs, lhs_ty), + op) + return ret, lhs_ty + + def visit_AugAssign(self, node): + lhs, lhs_ty = self.visit(node.target) + rhs, rhs_ty = self.visit(node.value) + ret, ret_ty = self._emit_binary_op(node.op, lhs, lhs_ty, rhs, rhs_ty) + self.symbol_table.insert_symbol(node.target.id, ret, ret_ty) + + def visit_BinOp(self, node): + lhs, lhs_ty = self.visit(node.left) + rhs, rhs_ty = self.visit(node.right) + return self._emit_binary_op(node.op, lhs, lhs_ty, rhs, rhs_ty) + + def 
visit_BoolOp(self, node): + values = [self.visit(value) for value in node.values] + # TODO(fengliuai): Handle more ast node types. + if isinstance(node.op, ast.Or): + raise NotImplementedError('Or operator not recognized') + elif isinstance(node.op, ast.And): + raise NotImplementedError('And operator not recognized') + + def visit_Call(self, node): + func_name, func_type = self.visit(node.func) + _ = self._get_inferred_type(node.func, func_type) + if func_type == TFRTypes.AG_BUILTIN_FUNC: + if func_name == 'if_stmt': + cond, _ = self.visit(node.args[0]) + body, _ = self.visit(node.args[1]) + orelse, _ = self.visit(node.args[2]) + get_state, _ = self.visit(node.args[3]) + nouts = int(node.args[6].value) + out_symbols = [] + # The out symbols are just a Tuple of names + for out in node.args[5].elts[:nouts]: + val, ty = self.symbol_table.lookup(out.value) + if ty != TFRTypes.AG_UNDEFINED_VAL: + raise ValueError('if stmt out symbol is not defined.') + out_symbols.append(out.value) + return self._visit_if_stmt(cond, body, orelse, get_state, out_symbols, + node) + elif func_name == 'for_stmt': + range_ = self._visit_iter(node.args[0]) + body, _ = self.visit(node.args[2]) + get_state, _ = self.visit(node.args[3]) + loop_carried = [out.value for out in node.args[5].elts] + # TODO(fengliuai): opt is not used here. + return self._visit_for_stmt(range_, body, get_state, loop_carried, node) + elif func_name == 'Undefined': + val = self._ssa_name(node.args[0].value) + return (val, TFRTypes.AG_UNDEFINED_VAL) + elif func_name == 'UndefinedReturnValue': + val = self._ssa_name('return_val') + return (val, TFRTypes.AG_UNDEFINED_VAL) + + if func_type == TFRTypes.TF_RAW_OP: + return self._visit_tf_op(func_name, node.args, node.keywords, node) + + if func_type == TFRTypes.TF_TENSOR_SHAPE_FUNC: + return (func_name, TFRTypes.TF_TENSOR_SHAPE_LIST) + + if func_type == TFRTypes.PY_BUILTIN_FUNC: + if func_name == 'len': + arg, ty = self.visit(node.args[0]) + ty = self._get_inferred_type(node.args[0], ty) + assert ty == TFRTypes.TF_TENSOR_SHAPE_LIST, ty + len_value = self._ssa_name('len') + self._emit_with_loc( + '\n{} = shape.rank {} : !shape.shape -> !shape.size'.format( + len_value, arg), node) + size_value = self._ssa_name('len_size') + self._emit_with_loc( + '\n{} = shape.size_to_index {} : !shape.size'.format( + size_value, len_value), node) + return (size_value, TFRTypes.INDEX) + + raise NotImplementedError('call operator not recognized: {} {}'.format( + func_name, func_type)) + + def visit_Compare(self, node): + lhs, lhs_ty = self.visit(node.left) + for op, right in zip(node.ops, node.comparators): + rhs, _ = self.visit(right) + if isinstance(op, ast.Eq): + pred = 'eq' + elif isinstance(op, ast.Lt): + pred = 'ult' + elif isinstance(op, ast.LtE): + pred = 'ule' + elif isinstance(op, ast.Gt): + pred = 'ugt' + elif isinstance(op, ast.GtE): + pred = 'uge' + elif isinstance(op, ast.NotEq): + pred = 'ne' + else: + raise NotImplementedError('Compare operator not recognized') + + ret = self._ssa_name(pred) + if lhs_ty == TFRTypes.ATTR: + self._emit_with_loc( + '\n{} = tfr.equal {}, {} -> i1'.format(ret, lhs, rhs), node) + else: + if lhs_ty == TFRTypes.I64: + code = 'cmpi' + elif lhs_ty == TFRTypes.F32: + code = 'cmpf' + else: + raise NotImplementedError('Compare operand type not recognized') + self._emit_with_loc( + '\n{} = {} "{}", {}, {} : {}'.format(ret, code, pred, lhs, rhs, + lhs_ty), node) + + return ret, TFRTypes.I1 + + def visit_Constant(self, node): + cst_name = self._ssa_name('cst') + if node.value is None: 
+ cst_ty = TFRTypes.NONE + elif isinstance(node.value, bool): + cst_ty = self._get_inferred_type(node) + cst_val = str(node.value).lower() + self._emit_with_loc('\n{} = constant {}'.format(cst_name, cst_val), node) + else: + cst_ty = self._get_inferred_type(node) + cst_val = node.value + if cst_ty == TFRTypes.ATTR: + self._emit_with_loc( + '\n{} = tfr.constant "{}" -> {}'.format(cst_name, cst_val, cst_ty), + node) + else: + self._emit_with_loc( + '\n{} = constant {} : {}'.format(cst_name, cst_val, cst_ty), node) + return cst_name, cst_ty + + def visit_FunctionDef(self, node): + op_def, derived_attrs = self._op_defs.lookup(node.name, node, True) + if op_def is None: + # Nested function. Insert it to symbol table for looking up later. + self.symbol_table.insert_symbol(node.name, node, None) + return + op_name = op_def.name + if self.symbol_table.lookup(op_name): + raise LookupError('Composition has not been registered for op: ' + + op_name) + else: + self.symbol_table.insert_symbol(node.name, None, None) + + self.symbol_table.enter_scope() + self.emit('\ntfr.func @tf__{0}('.format(_camel_to_snake(op_name))) + + arg_list = [] + idx = 0 + max_idx = len(op_def.input_arg) + len(op_def.attr) + for arg in node.args.args: + arg_name = self._ssa_name(anno.getanno(arg, anno.Basic.QN)) + arg_type = anno.getanno(arg, anno.Static.TYPES)[0] + + arg_attr = '' + if idx >= len(op_def.input_arg): + attr_def = op_def.attr[idx - len(op_def.input_arg)] + # skip the derived attributes + while attr_def.name in derived_attrs and (idx + 1) < max_idx: + idx += 1 + attr_def = op_def.attr[idx - len(op_def.input_arg)] + if idx >= max_idx: + raise ValueError('Argument is not defined in OpDef: ' + arg_name) + + arg_attr += '{{tfr.name="{}"'.format(attr_def.name) + if attr_def.HasField('default_value'): + default_val = _get_val_from_proto(arg_type, attr_def.default_value) + arg_attr += ',tfr.default={}'.format(default_val) + arg_attr += '}' + + idx += 1 + arg_str = '{}: {}{}'.format(arg_name, arg_type, arg_attr) + arg_list.append(arg_str) + self.symbol_table.insert_symbol(arg.id, arg_name, arg_type) + + ret_type_list = [] + for ret_def in op_def.output_arg: + if ret_def.number_attr or ret_def.type_list_attr: + ret_type_list.append(str(TFRTypes.TENSOR_LIST)) + else: + ret_type_list.append(str(TFRTypes.TENSOR)) + + self.emit('{}) -> ({}) {{'.format(', '.join(arg_list), + ', '.join(ret_type_list))) + self.visit_block(node.body) + self._emit_with_loc('\n}', node) + self.symbol_table.exit_scope() + + def visit_arguments(self, node): + # TODO(fengliuai): return ordered the types and names. + # We need to order the arguments to match the assumption in the TFR dialect. + raise NotImplementedError('arguments not supported.') + + def visit_Lambda(self, node): + raise NotImplementedError('Lambda not supported.') + + def _get_mlir_ssa_values(self, name_prefix, out_types): + """Create MLIR convention SSA values.""" + out_ssa_values = [] + if not out_types: + return '', out_ssa_values + + out_name = self._ssa_name(name_prefix) + if len(out_types) == 1: + out_name_suffix = '' + out_ssa_values.append(out_name) + else: + # For multiple returns, MLIR uses '%s:i' when they are defined and + # '%s#i' when they are used. 
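+      # For example (illustrative names): a two-result op is printed as
+      # `%Split:2 = ...` where it is defined and referenced as `%Split#0`
+      # and `%Split#1` afterwards.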
+ out_name_suffix = ':{}'.format(len(out_types)) + for idx, _ in enumerate(out_types): + out_ssa_values.append('{}#{}'.format(out_name, idx)) + + return '{}{}'.format(out_name, out_name_suffix), out_ssa_values + + def _visit_if_stmt(self, cond, body_def, orelse_def, get_state, out_symbols, + node): + self.emit('\n') + ret_str, ret_ssa_values = self._get_mlir_ssa_values( + 'if_stmt', [TFRTypes.TENSOR] * len(out_symbols)) + if ret_ssa_values: + self.emit(ret_str + ' = ') + + # add ssa values to the symbol table + out_types = [] + for symbol, ssa_value in zip(out_symbols, ret_ssa_values): + self.symbol_table.insert_symbol(symbol, ssa_value, TFRTypes.TENSOR) + out_types.append(str(TFRTypes.TENSOR)) + + self.emit('scf.if {} -> ({}) {{'.format(cond, ', '.join(out_types))) + # Create a new scope in case the local variables are leaked. + self.symbol_table.enter_scope(scf_scope=True) + self.visit_block(body_def.body) + self.visit_block(get_state.body) + self.symbol_table.exit_scope() + + self.emit('\n} else {') + + # Create a new scope in case the local variables are leaked. + self.symbol_table.enter_scope(scf_scope=True) + self.visit_block(orelse_def.body) + self.visit_block(get_state.body) + self.symbol_table.exit_scope() + + self._emit_with_loc('\n}', node) + return list(zip(ret_ssa_values, out_types)) + + def _visit_iter(self, node): + if isinstance(node, ast.Call): + f_name = anno.getanno(node.func, anno.Basic.QN) + if f_name == QN('range'): + args = [self.visit(arg) for arg in node.args] + begin = None + step = None + end = None + if len(args) == 1: + end, end_ty = args[0] + elif len(args) == 2: + begin, begin_ty = args[0] + end, end_ty = args[1] + elif len(args) == 3: + begin, begin_ty = args[0] + end, end_ty = args[1] + step, step_ty = args[2] + + if begin is None: + begin = self._ssa_name('begin') + self._emit_with_loc('\n{} = constant 0 : index'.format(begin), node) + elif begin_ty != TFRTypes.INDEX: + begin_ = self._ssa_name('begin') + self._emit_with_loc( + '\n{} = index_cast {} : {} to index'.format( + begin_, begin, begin_ty), node) + begin = begin_ + + if end_ty != TFRTypes.INDEX: + end_ = self._ssa_name('end') + self._emit_with_loc( + '\n{} = index_cast {} : {} to index'.format(end_, end, end_ty), + node) + end = end_ + + if step is None: + step = self._ssa_name('step') + self._emit_with_loc('\n{} = constant 1 : index'.format(step), node) + elif step_ty != TFRTypes.INDEX: + step_ = self._ssa_name('step') + self._emit_with_loc( + '\n{} = index_cast {} : {} to index'.format(step_, step, step_ty), + node) + step = step_ + + return begin, end, step + + raise NotImplementedError('Iterator entity not supported.' + node) + + def _visit_for_stmt(self, range_, body_def, get_state, loop_carried, node): + self.emit('\n') + ret_str, ret_ssa_values = self._get_mlir_ssa_values( + 'for_stmt', [TFRTypes.TENSOR] * len(loop_carried)) + if ret_ssa_values: + self.emit(ret_str + ' = ') + + # Before enter the loop, we use the original ssa values as the initial + # values to the loop iteration arguments. We also create new ssa values as + # the returns of the scf for statements. The symbol table needs to be + # updated to these new ssa values before it enters the scope of the loop. + out_types = [] + init_values = [] + for symbol, ssa_value in zip(loop_carried, ret_ssa_values): + init, ty = self.symbol_table.lookup(symbol) + self.symbol_table.insert_symbol(symbol, ssa_value, ty) + out_types.append(str(ty)) + init_values.append((init, ty)) + + # Create a new scope in case the local variables are leaked. 
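+    # The loop emitted below has roughly this shape (illustrative SSA names):
+    #   %for_stmt = scf.for %i = %begin to %end step %step
+    #       iter_args(%it_arg = %init) -> (!tfr.tensor) {
+    #     ...
+    #     scf.yield %result : !tfr.tensor
+    #   }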
+ self.symbol_table.enter_scope(scf_scope=True) + + # Create the iteration variable with index type + assert len(body_def.args.args) == 1 + it_name = body_def.args.args[0].id + it = self._ssa_name(it_name) + self.symbol_table.insert_symbol(it_name, it, TFRTypes.INDEX) + + self.emit('scf.for {} = {} to {} step {} '.format(it, range_[0], range_[1], + range_[2])) + if loop_carried: + iter_args = [] + for symbol, init in zip(loop_carried, init_values): + # create new ssa values for the loop carried variables + it_arg = self._ssa_name('it_arg') + self.symbol_table.insert_symbol(symbol, it_arg, init[1]) + iter_args.append('{} = {}'.format(it_arg, init[0])) + self.emit('iter_args({}) '.format(', '.join(iter_args))) + self.emit('-> ({}) {{'.format(', '.join(out_types))) + else: + self.emit(' {') + self.visit_block(body_def.body) + self.visit_block(get_state.body) + self.symbol_table.exit_scope() + self._emit_with_loc('\n}', node) + return list(zip(ret_ssa_values, out_types)) + + def _emit_default_constant_from_proto(self, attr_def): + """emit mlir constant statement from default value of the ArgDef proto.""" + name = self._ssa_name('cst') + cst_ty = _get_type_from_proto(None, attr_def) + cst_val = _get_val_from_proto(cst_ty, attr_def.default_value) + if cst_ty == TFRTypes.ATTR: + self._emit_with_loc('\n{} = tfr.constant {} -> {}'.format( + name, cst_val, cst_ty)) + elif cst_ty == TFRTypes.I1: + self._emit_with_loc('\n{} = constant {}'.format(name, cst_val)) + else: + self._emit_with_loc('\n{} = constant {} : {}'.format( + name, cst_val, cst_ty)) + return name, cst_ty + + def visit_keyword(self, node): + return node.arg, self.visit(node.value) + + def _visit_tf_op(self, op_name, args, keywords, node): + op_def, derived_attrs = self._op_defs.lookup(op_name) + ret_tys = [_get_type_from_proto(arg) for arg in op_def.output_arg] + + ret_str, ret_ssa_values = self._get_mlir_ssa_values(op_name, ret_tys) + + arg_strs = [] + ty_strs = [] + for arg in args: + value, ty = self.visit(arg) + arg_strs.append(value) + ty_strs.append(str(ty)) + + input_args = [arg for arg in op_def.input_arg] + attrs_no_default = [ + attr for attr in op_def.attr + if not attr.HasField('default_value') and attr.name not in derived_attrs + ] + attrs_with_default = [ + attr for attr in op_def.attr + if attr.HasField('default_value') and attr.name not in derived_attrs + ] + + kw_args = {} + for arg in keywords: + value, (ssa_name, ty) = self.visit(arg) + ty = self._get_inferred_type(arg.value, ty) + + # TODO(fengliuai): implement the "rename_to" for the customization in + # tensorflow/core/api_def/base_api/* + if value == 'axis': + value = 'split_dim' + + kw_args[value] = (ssa_name, ty) + + # tensor arguments and attribute arguments + ordered_args = input_args + attrs_no_default + attrs_with_default + for attr_def in ordered_args[len(args):]: + if attr_def.name in kw_args: + value, ty = kw_args[attr_def.name] + if attr_def in input_args: + if ty in _attribute_types: + # the argument shouldn't be used as tf op calls directly. 
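+            # For example, a Python literal such as `axis=0` that maps onto a
+            # tensor input of the op is first wrapped with
+            # "tfr.constant_tensor" below (illustrative; see _value_to_tensor).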
+ value, ty = self._value_to_tensor(value, ty, node) + if ty is TFRTypes.TENSOR_LIST and not _require_tensor_list(attr_def): + value, ty = self._pack_tensor_list(value) + else: + value, ty = self._emit_default_constant_from_proto(attr_def) + arg_strs.append(value) + ty_strs.append(str(ty)) + + if ret_ssa_values: + self.emit('\n{} = '.format(ret_str)) + + self.emit('tfr.call @tf__{}('.format(_camel_to_snake(op_name))) + arg_str = ', '.join(arg_strs) + arg_ty_str = ', '.join(ty_strs) + ret_ty_str = ', '.join([str(ty) for ty in ret_tys]) + self._emit_with_loc( + '{}) : ({}) -> ({})'.format(arg_str, arg_ty_str, ret_ty_str), node) + return list(zip(ret_ssa_values, ret_tys)) + + def visit_If(self, node): + raise NotImplementedError('If not supported.') + + def visit_Name(self, node): + val, lookup_type = self.symbol_table.lookup(node.id) + type_ = self._get_inferred_type(node, lookup_type) + return val, type_ + + def visit_Return(self, node): + values = self.visit(node.value) + if self.symbol_table.in_scf_scope(): + self.emit('\nscf.yield ') + else: + self.emit('\ntfr.return ') + if not values: + return + + if isinstance(values, list): + vals, tys = zip(*values) + else: + vals = values[0] + tys = values[1] + + if isinstance(tys, list) or isinstance(tys, tuple): + tys = [str(t) for t in tys] + self._emit_with_loc('{} : {}'.format(', '.join(vals), ', '.join(tys)), + node) + elif tys != TFRTypes.NONE: + # TODO(fengliuai): scf region yield uses this branch. Fix it. + self._emit_with_loc('{} : {}'.format(vals, tys), node) + + def visit_Subscript(self, node): + val, ty = self.visit(node.value) + type_ = self._get_inferred_type(node.value, ty) + + # TODO(fengliuai): Here we hardcode the node.slice here to get the index + # type. Use the visit method once the type inference is done. + # slice_val, slice_ty = self.visit(node.slice) + if isinstance(node.slice, ast.Index): + if isinstance(node.slice.value, ast.Constant): + # TODO(fengliuai): promote to an assignment + idx_val = self._ssa_name('cst') + self._emit_with_loc( + '\n{} = constant {} : index'.format(idx_val, + node.slice.value.value), node) + else: + idx_val, _ = self.visit(node.slice.value) + else: + raise NotImplementedError('non-index slice not supported.') + + elt = self._ssa_name('elt') + if type_ == TFRTypes.TENSOR_LIST: + self.emit('\n{} = tfr.get_element {}[{}] '.format(elt, val, idx_val)) + self._emit_with_loc(': (!tfr.tensor_list, index) -> !tfr.tensor', node) + return (elt, TFRTypes.TENSOR) + elif type_ == TFRTypes.TF_TENSOR_SHAPE_LIST: + size_ = self._ssa_name('size') + self.emit('\n{} = shape.get_extent {}, {}'.format(size_, val, idx_val)) + self._emit_with_loc(': !shape.shape, index -> !shape.size', node) + self._emit_with_loc( + '\n{} = shape.size_to_index {} : !shape.size'.format(elt, size_), + node) + return (elt, TFRTypes.INDEX) + + def visit_List(self, node): + out_type = self._get_inferred_type(node) + vals = [] + tys = [] + for elt in node.elts: + val, ty = self.visit(elt) + if ty in _attribute_types and out_type == TFRTypes.TENSOR_LIST: + # This list is a tensor list, then cast all the input values to tensors. + val, ty = self._value_to_tensor(val, ty, node) + else: + # We shouldn't use index type to build the list because list will be use + # as attribute. 
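+        # (index values are widened to i64 below so the resulting
+        # "tfr.build_list" can be materialized as a !tfr.attr)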
+ val, ty = self._index_to_I64(val, ty) + vals.append(val) + tys.append(str(ty)) + + list_val = self._ssa_name('list') + self.emit('\n{} = "tfr.build_list"({})'.format(list_val, ', '.join(vals))) + self._emit_with_loc(' : ({}) -> {}'.format(', '.join(tys), out_type), node) + return (list_val, out_type) + + def visit_Tuple(self, node): + return [self.visit(elt) for elt in node.elts] + + def visit_UnaryOp(self, node): + value, ty = self.visit(node.operand) + if isinstance(node.op, ast.USub): + zero_value = self._ssa_name('zero') + self._emit_with_loc('\n{} = constant 0 : {}'.format(zero_value, ty), node) + ssa_value = self._ssa_name('cst') + if ty == TFRTypes.I32 or ty == TFRTypes.I64: + self._emit_with_loc( + '\n{} = subi {}, {} : {}'.format(ssa_value, zero_value, value, ty), + node) + elif ty == TFRTypes.F32: + self._emit_with_loc( + '\n{} = subf {}, {} : {}'.format(ssa_value, zero_value, value, ty), + node) + else: + raise NotImplementedError('USub type not recognized: ' + str(ty)) + return ssa_value, ty + raise NotImplementedError('USub operator not recognized') + + def visit_For(self, node): + raise NotImplementedError('For operator not recognized') + + def visit_While(self, node): + raise NotImplementedError('While operator not recognized') + + def visit_Try(self, node): + # Only handles the body of the try statement. + self.visit_block(node.body) + + +def _apply_py_to_tf_passes(node, ctx): + """Apply transformations from PyToTF to match tf.function tracing.""" + # TODO(fengliuai): we don't know which passes are required, thus we evalute + # each one when the corresponding node is handled. + # copied from PyToTF.transform_ast + node = return_statements.transform(node, ctx, False) + node = control_flow.transform(node, ctx) + return node + + +class TfrGen(transpiler.GenericTranspiler): + """Transforms Python objects into TFR MLIR source code.""" + + def __init__(self, op_defs): + self._op_defs = op_defs + + def transform_ast(self, node, ctx): + node = _apply_py_to_tf_passes(node, ctx) + # TODO(mdan): Enable this. + # node = anf.transform(node, ctx) + + graphs = cfg.build(node) + node = qual_names.resolve(node) + node = activity.resolve(node, ctx) + node = reaching_definitions.resolve(node, ctx, graphs) + node = reaching_fndefs.resolve(node, ctx, graphs) + node = type_inference.resolve(node, ctx, graphs, + TFRTypeResolver(self._op_defs)) + + mlir_generator = TFRGen(ctx, self._op_defs) + mlir_generator.visit(node) + return mlir_generator.code_buffer + + +def tfr_gen(func, op_defs): + """Parse a function and emit the TFR functions.""" + mlir_code, _ = TfrGen(op_defs).transform(func, None) + assert tfr.verify(mlir_code), 'mlir code not verified: {}'.format(mlir_code) + return mlir_code + + +def tfr_gen_from_module(source, method_prefix=None, op_libraries=None): + """Parse the input source module and emit the TFR functions.""" + op_defs = OpDefCache() + + # Load the op library so the op is added to the op registry. This is + # required when the op cc_library couldn't be statically linked in open + # source. + # This is a no op if the op shared library couldn't be found in the same + # directory of the op Python API. + # TODO(fengliuai): make the .so file path configurable. 
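+  # Naming convention assumed by the lookup below: for a generated wrapper
+  # module `gen_foo_ops.py`, the matching kernel library is expected to be a
+  # sibling file `foo_ops.so` (the `gen_` prefix is stripped and `.py` is
+  # replaced with `.so`).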
+ if op_libraries: + prefix_len = len('gen_') + for m in op_libraries: + lib_dir = os.path.dirname(m.__file__) + lib_name = os.path.basename(m.__file__)[prefix_len:].replace('.py', '.so') + lib_path = os.path.join(lib_dir, lib_name) + if os.path.exists(lib_path): + logging.info('load file: ' + lib_path) + load_library.load_op_library(lib_path) + else: + # The op library is generated from the source module, then we load all the + # .so file in the directory + lib_dir = os.path.dirname(source.__file__) + for lib_name in os.listdir(lib_dir): + if lib_name.endswith('.so'): + lib_path = os.path.join(lib_dir, lib_name) + logging.info('load file: ' + lib_path) + load_library.load_op_library(lib_path) + + mlir_funcs = [ + tfr_gen(func, op_defs) + for name, func in tf_inspect.getmembers(source, tf_inspect.isfunction) + if not method_prefix or name.startswith(method_prefix) + ] + + return '\n'.join(mlir_funcs + op_defs.mlir_external_funcs()) diff --git a/tensorflow/compiler/mlir/tfr/python/tfr_gen_test.py b/tensorflow/compiler/mlir/tfr/python/tfr_gen_test.py new file mode 100644 index 00000000000..88696490c4a --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/python/tfr_gen_test.py @@ -0,0 +1,563 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for `tfr_gen` module.""" + +# pylint: disable=missing-function-docstring +# pylint: disable=invalid-name +# pylint: disable=g-direct-tensorflow-import + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +from tensorflow.compiler.mlir.python.mlir_wrapper import filecheck_wrapper as fw +from tensorflow.compiler.mlir.tfr.python import composite +from tensorflow.compiler.mlir.tfr.python.tfr_gen import tfr_gen_from_module as tfr_gen +from tensorflow.compiler.mlir.tfr.resources import gen_test_ops as test_ops +from tensorflow.python.ops import gen_array_ops as array_ops +from tensorflow.python.ops import gen_math_ops as math_ops +from tensorflow.python.platform import test + + +Composite = composite.Composite + +#--- test fn for mlir location --- + + +@Composite('TestInputNOp') +def _tfr_loc_test(x): + n = 10 + x_sum = x[0] + for i in range(1, n): + x_sum = math_ops.Add(x_sum, x[i]) + return x_sum + + +#--- test fn for tfr tensors --- + + +@composite.Composite('TestNoOp') +def _tfr_tensor_empty_arg(): + pass + + +@composite.Composite('TestIdentityOp') +def _tfr_tensor_tensor(x): + return x + + +@composite.Composite('TestIdentityNOp') +def _tfr_tensor_tensor_list(x): + return x + + +@composite.Composite('TestInputNOp') +def _tfr_tensor_tensor_list_get_elt(x): + return x[1] + + +@composite.Composite('TestOutputNOp') +def _tfr_tensor_tensor_list_output(x): + return [x, x] + + +@composite.Composite('TestTwoInputsOp') +def _tfr_tensor_tensor_list_split(x, y, pred): + z, _ = array_ops.Split(axis=0, value=x, num_split=2) + (y, pred) # pylint: disable=pointless-statement + return z + + +@composite.Composite('TestNumAttrsOp') +def _tfr_tensor_tensor_with_cst(x1, y1, x2, y2): + x = array_ops.OneHot( + indices=[0, 2, -1, x1], depth=y1, on_value=True, off_value=False) + (x, x2, y2) # pylint: disable=pointless-statement + return + + +@composite.Composite('TestTwoOutputsOp') +def _tfr_tensor_two_output(x): + z = array_ops.Split(axis=0, value=x, num_split=2) + return z[0], z[1] + + +#--- test fn for scf control flow --- + + +@composite.Composite('TestTwoInputsOp') +def _tfr_control_flow_if(x, y, pred): + if pred: + return x + else: + return y + + +@composite.Composite('TestThreeInputsOp') +def _tfr_control_flow_nested_if(x, y, z, select): + if select == 'x': + return x + elif select == 'y': + return y + else: + return z + + +@composite.Composite('TestInputNOp') +def _tfr_control_flow_range_for(x): + # TODO(fengliuai): use len(x) instead + n = 10 + x_sum = x[0] + for i in range(1, n): + x_sum = math_ops.Add(x_sum, x[i]) + return x_sum + + +#--- test fn for tf ops --- + + +@composite.Composite('TestComplexTFOp') +def _tfr_tf_ops_complex(lhs, rhs): + left_padding, _ = array_ops.SplitV( + value=lhs, size_splits=[rhs, -1], axis=0, num_split=2) + _, right_padding = array_ops.SplitV( + value=lhs, size_splits=[rhs, rhs], axis=1, num_split=2) + return [left_padding, right_padding] + + +@composite.Composite('TestIdentityOp') +def _tfr_tf_ops_tensor(x): + return array_ops.Identity(x) + + +@composite.Composite('TestTwoInputsOp') +def _tfr_tf_ops_tensors(x, y, pred): + if pred: + return math_ops.Add(x, y) + else: + return array_ops.Concat(0, [x, y]) + + +@composite.Composite('TestInputNOp') +def _tfr_tf_ops_with_defaults(ins): + return test_ops.TestTwoInputsOp(ins[0], ins[1]) + + +#--- test fn for tfr attributes --- + + 
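+# Note: the tests below pick up compositions by name prefix, e.g.
+# tfr_gen(sys.modules[__name__], '_tfr_attrs', [test_ops]) only collects the
+# `_tfr_attrs_*` functions in this section.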
+@composite.Composite('TestNumAttrsOp') +def _tfr_attrs_num_type(x, y, x1, y1): + # int + z0 = [x, y] + z1 = x == y + z2 = x < y + z3 = x <= y + z4 = x > y + z5 = x >= y + z6 = x != y + z7 = x + y + z8 = x - y + z8 += x + z8 += 1 + (z0, z1, z2, z3, z4, z5, z6, z7, z8) # pylint: disable=pointless-statement + + # float + z9 = x1 > y1 + z10 = x1 + y1 + z11 = [x1, y1] + (z9, z10, z11) # pylint: disable=pointless-statement + return + + +@composite.Composite('TestNonNumAttrsOp') +def _tfr_attrs_tfr_type(x, y, z): + z1 = x == y + z2 = x == 'test' + z3 = y == z + (z1, z2, z3) # pylint: disable=pointless-statement + return + + +#--- test fn for shapes --- + + +@composite.Composite('TestIdentityOp') +def _tfr_shapes(x): + s1 = x.shape + s3 = x.shape.as_list() + + for i in range(len(s3)): + s3[i] # pylint: disable=pointless-statement + + for i in range(1, len(s3), 2): + s3[i] # pylint: disable=pointless-statement + + s5 = array_ops.Shape(x) + (s1, s3, s5) # pylint: disable=pointless-statement + return x + + +class TFRGenTestBase(test.TestCase): + + def _check_code(self, tfr_code, exp_tfr_code): + return self.assertTrue(fw.check(str(tfr_code), exp_tfr_code), str(tfr_code)) + + +class TFRGenTensorTest(TFRGenTestBase): + """MLIR Generation Tests for MLIR TFR Program.""" + + def test_tfr_loc(self): + mlir_code = tfr_gen(sys.modules[__name__], '_tfr_loc', [test_ops]) + mlir_code_exp = r""" + CHECK-LABEL: tfr.func @tf__test_input_n_op(%x: !tfr.tensor_list) -> (!tfr.tensor) { + CHECK-NEXT: %[[n:.*]] = constant 10 : i64 + CHECK-SAME loc("tfr_gen_test.py":%{{.*}}:6) + CHECK-NEXT: %[[cst:.*]] = constant 0 : index + CHECK-SAME loc("tfr_gen_test.py":%[[sum_line:.*]]:10) + CHECK-NEXT: %[[elt:.*]] = tfr.get_element %x[%[[cst]]] : (!tfr.tensor_list, index) -> !tfr.tensor + CHECK-SAME loc("tfr_gen_test.py":%[[sum_line]]:10) + CHECK-NEXT: %[[cst_1:.*]] = constant 1 : i64 + CHECK-SAME loc("tfr_gen_test.py":%[[for_line:.*]]:2) + CHECK-NEXT: %[[begin:.*]] = index_cast %[[cst_1]] : i64 to index + CHECK-SAME loc("tfr_gen_test.py":%[[for_line]]:2) + CHECK-NEXT: %[[end:.*]] = index_cast %[[n]] : i64 to index + CHECK-SAME loc("tfr_gen_test.py":%[[for_line]]:2) + CHECK-NEXT: %[[step:.*]] = constant 1 : index + CHECK-SAME loc("tfr_gen_test.py":%[[for_line]]:2) + CHECK-NEXT: %[[for_stmt:.*]] = scf.for %[[itr_1:.*]] = %[[begin]] to %[[end]] step %[[step]] + CHECK-SAME: iter_args(%[[it_arg:.*]] = %[[elt]]) -> (!tfr.tensor) { + CHECK-NEXT: %[[elt_1:.*]] = tfr.get_element %x[%itr_1] : (!tfr.tensor_list, index) -> !tfr.tensor + CHECK-SAME loc("tfr_gen_test.py":%[[add_line:.*]]:34) + CHECK-NEXT: %[[Add:.*]] = tfr.call @tf__add(%[[it_arg]], %[[elt_1]]) : (!tfr.tensor, !tfr.tensor) -> (!tfr.tensor) + CHECK-SAME loc("tfr_gen_test.py":%[[add_line]]:12) + CHECK-NEXT: scf.yield %[[Add]] : !tfr.tensor + CHECK-SAME loc(unknown) + CHECK-NEXT: } + CHECK-SAME loc("tfr_gen_test.py":%[[for_line]]:2) + CHECK-NEXT: %{{.*}} = constant true + CHECK-SAME loc(unknown) + CHECK-NEXT: tfr.return %[[for_stmt]] : !tfr.tensor + CHECK-SAME loc(unknown) + CHECK-NEXT: } + CHECK-SAME loc("tfr_gen_test.py":%{{def_line:.*}}:0) + """ + self._check_code(mlir_code, mlir_code_exp) + + def test_tfr_tensors(self): + mlir_code = tfr_gen(sys.modules[__name__], '_tfr_tensor', [test_ops]) + mlir_code_exp = r""" + CHECK-LABEL: tfr.func @tf__test_no_op() -> () { + CHECK-NEXT: tfr.return + CHECK-NEXT: } + + CHECK-LABEL: tfr.func @tf__test_identity_op(%x: !tfr.tensor) -> (!tfr.tensor) { + CHECK-NEXT: constant true + CHECK-NEXT: tfr.return %x : !tfr.tensor + CHECK-NEXT: } + + 
CHECK-LABEL: tfr.func @tf__test_identity_n_op(%x: !tfr.tensor_list) -> (!tfr.tensor_list) { + CHECK-NEXT: constant true + CHECK-NEXT: tfr.return %x : !tfr.tensor_list + CHECK-NEXT: } + + CHECK-LABEL: tfr.func @tf__test_input_n_op(%x: !tfr.tensor_list) -> (!tfr.tensor) { + CHECK-NEXT: constant true + CHECK-NEXT: %[[index:.*]] = constant 1 : index + CHECK-NEXT: %[[sub:.*]] = tfr.get_element %x[%cst_1] : (!tfr.tensor_list, index) -> !tfr.tensor + CHECK-NEXT: tfr.return %[[sub]] : !tfr.tensor + CHECK-NEXT: } + + CHECK-LABEL: tfr.func @tf__test_output_n_op(%x: !tfr.tensor) -> (!tfr.tensor_list) { + CHECK-NEXT: constant true + CHECK-NEXT: %[[list:.*]] = "tfr.build_list"(%x, %x) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list + CHECK-NEXT: tfr.return %[[list]] : !tfr.tensor_list + CHECK-NEXT: } + + CHECK-LABEL: tfr.func @tf__test_two_inputs_op(%x: !tfr.tensor, %y: !tfr.tensor, %pred: i1{tfr.name="pred",tfr.default=false}) -> (!tfr.tensor) { + CHECK-NEXT: %[[cst:.*]] = constant 0 : i64 + CHECK-NEXT: %[[cst_1:.*]] = constant 2 : i64 + CHECK-NEXT: %[[cst_2:.*]] = "tfr.constant_tensor"(%[[cst]]) : (i64) -> !tfr.tensor + CHECK-NEXT: %[[Split:.*]] = tfr.call @tf__split(%[[cst_2]], %x, %[[cst_1]]) : (!tfr.tensor, !tfr.tensor, i64) -> (!tfr.tensor_list) + CHECK-NEXT: %[[cst_4:.*]] = constant 0 : index + CHECK-NEXT: %[[elt:.*]] = tfr.get_element %[[Split]][%idx] : (!tfr.tensor_list, index) -> !tfr.tensor + CHECK-NEXT: %[[cst_5:.*]] = constant 1 : index + CHECK-NEXT: %[[elt_1:.*]] = tfr.get_element %[[Split]][%idx_1] : (!tfr.tensor_list, index) -> !tfr.tensor + CHECK-NEXT: constant true + CHECK-NEXT: tfr.return %[[elt]] : !tfr.tensor + CHECK-NEXT: } + + CHECK-LABEL: tfr.func @tf__test_num_attrs_op(%x1: i64{tfr.name="x1",tfr.default=-10}, %y1: i64{tfr.name="y1",tfr.default=1}, %x2: f32{tfr.name="x2",tfr.default=0.0}, %y2: f32{tfr.name="y2",tfr.default=-3.0}) -> () { + CHECK-NEXT: %[[cst:.*]] = constant 0 : i64 + CHECK-NEXT: %[[cst_1:.*]] = constant 2 : i64 + CHECK-NEXT: %[[cst_2:.*]] = constant 1 : i64 + CHECK-NEXT: %[[zero:.*]] = constant 0 : i64 + CHECK-NEXT: %[[cst_3:.*]] = subi %zero, %cst_2 : i64 + CHECK-NEXT: %[[list:.*]] = "tfr.build_list"(%[[cst]], %[[cst_1]], %[[cst_3]], %x1) : (i64, i64, i64, i64) -> !tfr.attr + CHECK-NEXT: %[[cst_4:.*]] = constant true + CHECK-NEXT: %[[cst_5:.*]] = constant false + CHECK-NEXT: %[[cst_6:.*]] = "tfr.constant_tensor"(%[[list]]) : (!tfr.attr) -> !tfr.tensor + CHECK-NEXT: %[[cst_7:.*]] = "tfr.constant_tensor"(%y1) : (i64) -> !tfr.tensor + CHECK-NEXT: %[[cst_8:.*]] = "tfr.constant_tensor"(%[[cst_4]]) : (i1) -> !tfr.tensor + CHECK-NEXT: %[[cst_9:.*]] = "tfr.constant_tensor"(%[[cst_5]]) : (i1) -> !tfr.tensor + CHECK-NEXT: %[[cst_10:.*]] = constant -1 : i64 + CHECK-NEXT: %[[OneHot:.*]] = tfr.call @tf__one_hot(%[[cst_6]], %[[cst_7]], %[[cst_8]], %[[cst_9]], %[[cst_10]]) + CHECK-SAME: (!tfr.tensor, !tfr.tensor, !tfr.tensor, !tfr.tensor, i64) -> (!tfr.tensor) + CHECK-NEXT: constant true + CHECK-NEXT: tfr.return + CHECK-NEXT: } + + CHECK-LABEL: tfr.func @tf__test_two_outputs_op(%x: !tfr.tensor) -> (!tfr.tensor, !tfr.tensor) { + CHECK-NEXT: %[[cst:.*]] = constant 0 : i64 + CHECK-NEXT: %[[cst_1:.*]] = constant 2 : i64 + CHECK-NEXT: %[[cst_2:.*]] = "tfr.constant_tensor"(%[[cst]]) : (i64) -> !tfr.tensor + CHECK-NEXT: %[[Split:.*]] = tfr.call @tf__split(%[[cst_2]], %x, %[[cst_1]]) : (!tfr.tensor, !tfr.tensor, i64) -> (!tfr.tensor_list) + CHECK-NEXT: constant true + CHECK-NEXT: %[[cst_4:.*]] = constant 0 : index + CHECK-NEXT: %[[elt:.*]] = tfr.get_element %[[Split]][%cst_4] : 
(!tfr.tensor_list, index) -> !tfr.tensor + CHECK-NEXT: %[[cst_5:.*]] = constant 1 : index + CHECK-NEXT: %[[elt_1:.*]] = tfr.get_element %[[Split]][%cst_5] : (!tfr.tensor_list, index) -> !tfr.tensor + CHECK-NEXT: tfr.return %[[elt]], %[[elt_1]] : !tfr.tensor, !tfr.tensor + CHECK-NEXT: } + """ + self._check_code(mlir_code, mlir_code_exp) + + def test_tfr_control_flow(self): + mlir_code = tfr_gen(sys.modules[__name__], '_tfr_control_flow', [test_ops]) + mlir_code_exp = r""" + CHECK-LABEL: tfr.func @tf__test_two_inputs_op(%x: !tfr.tensor, %y: !tfr.tensor, + CHECK-SAME: %pred: i1{tfr.name="pred",tfr.default=false}) -> (!tfr.tensor) { + CHECK-NEXT: %[[if:.*]] = scf.if %pred -> (!tfr.tensor) { + CHECK-NEXT: constant true + CHECK-NEXT: scf.yield %x : !tfr.tensor + CHECK-NEXT: } else { + CHECK-NEXT: constant true + CHECK-NEXT: scf.yield %y : !tfr.tensor + CHECK-NEXT: } + CHECK-NEXT: tfr.return %if_stmt : !tfr.tensor + CHECK-NEXT: } + + CHECK-LABEL: tfr.func @tf__test_three_inputs_op(%x: !tfr.tensor, %y: !tfr.tensor, %z: !tfr.tensor, + CHECK-SAME: %select: !tfr.attr{tfr.name="act",tfr.default="z"}) -> (!tfr.tensor) { + CHECK-NEXT: %[[cst:.*]] = tfr.constant "x" -> !tfr.attr + CHECK-NEXT: %[[eq:.*]] = tfr.equal %select, %[[cst]] -> i1 + CHECK-NEXT: %[[if_stmt:.*]] = scf.if %[[eq]] -> (!tfr.tensor) { + CHECK-NEXT: %[[cst_1:.*]] = constant true + CHECK-NEXT: scf.yield %x : !tfr.tensor + CHECK-NEXT: } else { + CHECK-NEXT: %[[cst_2:.*]] = tfr.constant "y" -> !tfr.attr + CHECK-NEXT: %[[eq_1:.*]] = tfr.equal %select, %[[cst_2]] -> i1 + CHECK-NEXT: %[[if_stmt1:.*]] = scf.if %[[eq_1]] -> (!tfr.tensor) { + CHECK-NEXT: %[[cst_3:.*]] = constant true + CHECK-NEXT: scf.yield %y : !tfr.tensor + CHECK-NEXT: } else { + CHECK-NEXT: %[[cst_4:.*]] = constant true + CHECK-NEXT: scf.yield %z : !tfr.tensor + CHECK-NEXT: } + CHECK-NEXT: scf.yield %[[if_stmt1]] : !tfr.tensor + CHECK-NEXT: } + CHECK-NEXT: tfr.return %[[if_stmt]] : !tfr.tensor + CHECK-NEXT: } + + CHECK-LABEL: tfr.func @tf__test_input_n_op(%x: !tfr.tensor_list) -> (!tfr.tensor) { + CHECK-NEXT: %[[n:.*]] = constant 10 : i64 + CHECK-NEXT: %[[cst:.*]] = constant 0 : index + CHECK-NEXT: %[[elt:.*]] = tfr.get_element %x[%[[cst]]] : (!tfr.tensor_list, index) -> !tfr.tensor + CHECK-NEXT: %[[cst_1:.*]] = constant 1 : i64 + CHECK-NEXT: %[[begin:.*]] = index_cast %[[cst_1]] : i64 to index + CHECK-NEXT: %[[end:.*]] = index_cast %[[n]] : i64 to index + CHECK-NEXT: %[[step:.*]] = constant 1 : index + CHECK-NEXT: %[[for_stmt:.*]] = scf.for %[[itr_1:.*]] = %[[begin]] to %[[end]] step %[[step]] + CHECK-SAME: iter_args(%[[it_arg:.*]] = %[[elt]]) -> (!tfr.tensor) { + CHECK-NEXT: %[[elt_1:.*]] = tfr.get_element %x[%itr_1] : (!tfr.tensor_list, index) -> !tfr.tensor + CHECK-NEXT: %[[Add:.*]] = tfr.call @tf__add(%[[it_arg]], %[[elt_1]]) : (!tfr.tensor, !tfr.tensor) -> (!tfr.tensor) + CHECK-NEXT: scf.yield %[[Add]] : !tfr.tensor + CHECK-NEXT: } + CHECK-NEXT: %{{.*}} = constant true + CHECK-NEXT: tfr.return %[[for_stmt]] : !tfr.tensor + CHECK-NEXT: } + """ + self._check_code(mlir_code, mlir_code_exp) + + def test_tfr_tf_ops(self): + mlir_code = tfr_gen(sys.modules[__name__], '_tfr_tf_ops', [test_ops]) + mlir_code_exp = r""" + CHECK-LABEL: tfr.func @tf__test_complex_tf_op(%lhs: !tfr.tensor, %rhs: !tfr.tensor) -> (!tfr.tensor_list) { + CHECK-NEXT: %[[cst:.*]] = constant 1 : i64 + CHECK-NEXT: %[[zero:.*]] = constant 0 : i64 + CHECK-NEXT: %[[cst_1:.*]] = subi %[[zero]], %cst : i64 + CHECK-NEXT: %[[cst_2:.*]] = "tfr.constant_tensor"(%[[cst_1]]) : (i64) -> !tfr.tensor + CHECK-NEXT: 
%[[list:.*]] = "tfr.build_list"(%rhs, %[[cst_2]]) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list + CHECK-NEXT: %[[cst_3:.*]] = constant 0 : i64 + CHECK-NEXT: %[[cst_4:.*]] = constant 2 : i64 + CHECK-NEXT: %[[zero_1:.*]] = constant 0 : i64 + CHECK-NEXT: %[[pack:.*]] = tfr.call @tf__pack(%[[list]], %[[zero_1]]) : (!tfr.tensor_list, i64) -> !tfr.tensor + CHECK-NEXT: %[[cst_5:.*]] = "tfr.constant_tensor"(%[[cst_3]]) : (i64) -> !tfr.tensor + CHECK-NEXT: %[[SplitV:.*]] = tfr.call @tf__split_v(%lhs, %[[pack]], %[[cst_5]], %[[cst_4]]) + CHECK-NEXT: %[[idx:.*]] = constant 0 : index + CHECK-NEXT: %[[elt:.*]] = tfr.get_element %SplitV[%idx] : (!tfr.tensor_list, index) -> !tfr.tensor + CHECK-NEXT: %[[idx_1:.*]] = constant 1 : index + CHECK-NEXT: %[[elt_1:.*]] = tfr.get_element %SplitV[%idx_1] : (!tfr.tensor_list, index) -> !tfr.tensor + CHECK-NEXT: %[[list_1:.*]] = "tfr.build_list"(%rhs, %rhs) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list + CHECK-NEXT: %[[cst_6:.*]] = constant 1 : i64 + CHECK-NEXT: %[[cst_7:.*]] = constant 2 : i64 + CHECK-NEXT: %[[zero_2:.*]] = constant 0 : i64 + CHECK-NEXT: %[[pack_1:.*]] = tfr.call @tf__pack(%[[list_1]], %[[zero_2]]) : (!tfr.tensor_list, i64) -> !tfr.tensor + CHECK-NEXT: %[[cst_8:.*]] = "tfr.constant_tensor"(%[[cst_6]]) : (i64) -> !tfr.tensor + CHECK-NEXT: %[[SplitV_1:.*]] = tfr.call @tf__split_v(%lhs, %[[pack_1]], %[[cst_8]], %[[cst_7]]) + CHECK-NEXT: %[[idx_2:.*]] = constant 0 : index + CHECK-NEXT: %[[elt_2:.*]] = tfr.get_element %SplitV_1[%idx_2] : (!tfr.tensor_list, index) -> !tfr.tensor + CHECK-NEXT: %[[idx_3:.*]] = constant 1 : index + CHECK-NEXT: %[[elt_3:.*]] = tfr.get_element %SplitV_1[%idx_3] : (!tfr.tensor_list, index) -> !tfr.tensor + CHECK-NEXT: %[[cst_9:.*]] = constant true + CHECK-NEXT: %[[list_2:.*]] = "tfr.build_list"(%[[elt]], %[[elt_3]]) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list + CHECK-NEXT: tfr.return %[[list_2]] : !tfr.tensor_list + CHECK-NEXT: } + + CHECK-LABEL: tfr.func @tf__test_identity_op(%x: !tfr.tensor) -> (!tfr.tensor) { + CHECK-NEXT: %cst = constant true + CHECK-NEXT: %[[Id:.*]] = tfr.call @tf__identity(%x) : (!tfr.tensor) -> (!tfr.tensor) + CHECK-NEXT: tfr.return %[[Id]] : !tfr.tensor + CHECK-NEXT: } + + CHECK-LABEL: tfr.func @tf__test_two_inputs_op(%x: !tfr.tensor, %y: !tfr.tensor, + CHECK-SAME: %pred: i1{tfr.name="pred",tfr.default=false}) -> (!tfr.tensor) { + CHECK-NEXT: %[[if_stmt:.*]] = scf.if %pred -> (!tfr.tensor) { + CHECK-NEXT: %cst = constant true + CHECK-NEXT: %[[Add:.*]] = tfr.call @tf__add(%x, %y) : (!tfr.tensor, !tfr.tensor) -> (!tfr.tensor) + CHECK-NEXT: scf.yield %[[Add]] : !tfr.tensor + CHECK-NEXT: } else { + CHECK-NEXT: %cst_1 = constant true + CHECK-NEXT: %[[cst_2:.*]] = constant 0 : i64 + CHECK-NEXT: %[[list:.*]] = "tfr.build_list"(%x, %y) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list + CHECK-NEXT: %[[Concat:.*]] = tfr.call @tf__concat(%[[cst_2]], %[[list]]) : (i64, !tfr.tensor_list) -> (!tfr.tensor) + CHECK-NEXT: scf.yield %[[Concat]] : !tfr.tensor + CHECK-NEXT: } + CHECK-NEXT: tfr.return %[[if_stmt]] : !tfr.tensor + CHECK-NEXT: } + + CHECK-LABEL: tfr.func @tf__test_input_n_op(%ins: !tfr.tensor_list) -> (!tfr.tensor) { + CHECK-NEXT: %cst = constant true + CHECK-NEXT: %[[cst_1:.*]] = constant 0 : index + CHECK-NEXT: %[[elt:.*]] = tfr.get_element %ins[%cst_1] : (!tfr.tensor_list, index) -> !tfr.tensor + CHECK-NEXT: %[[cst_2:.*]] = constant 1 : index + CHECK-NEXT: %[[elt_1:.*]] = tfr.get_element %ins[%cst_2] : (!tfr.tensor_list, index) -> !tfr.tensor + CHECK-NEXT: %[[cst_3:.*]] = constant false + 
CHECK-NEXT: %[[call:.*]] = tfr.call @tf__test_two_inputs_op( + CHECK-SAME: %[[elt]], %[[elt_1]], %[[cst_3]]) : (!tfr.tensor, !tfr.tensor, i1) -> (!tfr.tensor) + CHECK-NEXT: tfr.return %[[call]] : !tfr.tensor + CHECK-NEXT: } + + CHECK-LABEL: tfr.func @tf__add_(!tfr.tensor,!tfr.tensor) -> (!tfr.tensor) attributes {T} + + CHECK-LABEL: tfr.func @tf__concat_(!tfr.tensor,!tfr.tensor_list) -> (!tfr.tensor) attributes {N,T,i32_} + + CHECK-LABEL: tfr.func @tf__identity_(!tfr.tensor) -> (!tfr.tensor) attributes {T} + + CHECK-LABEL: tfr.func @tf__pack_(!tfr.tensor_list,i64{tfr.name="axis"}) -> (!tfr.tensor) attributes {N,T,axis} + + CHECK-LABEL: tfr.func @tf__split_v_(!tfr.tensor,!tfr.tensor,!tfr.tensor,i64{tfr.name="num_split"}) -> (!tfr.tensor_list) attributes {T,Tlen,i32_,num_split} + + CHECK-LABEL: tfr.func @tf__test_two_inputs_op_(!tfr.tensor,!tfr.tensor,i1{tfr.name="pred"}) -> (!tfr.tensor) attributes {T,pred} + + CHECK-LABEL: tfr.func @tf__test_complex_tf_op_(!tfr.tensor,!tfr.tensor,i64{tfr.name="N"}) -> (!tfr.tensor_list) attributes {N,T,Tlen} + + CHECK-LABEL: tfr.func @tf__test_identity_op_(!tfr.tensor) -> (!tfr.tensor) attributes {T} + + CHECK-LABEL: tfr.func @tf__test_two_inputs_op_(!tfr.tensor,!tfr.tensor,i1{tfr.name="pred"}) -> (!tfr.tensor) attributes {T,pred} + + CHECK-LABEL: tfr.func @tf__test_input_n_op_(!tfr.tensor_list) -> (!tfr.tensor) attributes {N,T} + """ + self._check_code(mlir_code, mlir_code_exp) + + def test_tfr_attrs(self): + mlir_code = tfr_gen(sys.modules[__name__], '_tfr_attrs', [test_ops]) + mlir_code_exp = r""" + CHECK-LABEL: tfr.func @tf__test_num_attrs_op( + CHECK-SAME: %x: i64{tfr.name="x1",tfr.default=-10}, + CHECK-SAME: %y: i64{tfr.name="y1",tfr.default=1}, + CHECK-SAME: %x1: f32{tfr.name="x2",tfr.default=0.0}, + CHECK-SAME: %y1: f32{tfr.name="y2",tfr.default=-3.0}) -> () { + CHECK-NEXT: %{{.*}} = "tfr.build_list"(%x, %y) : (i64, i64) -> !tfr.attr + CHECK-NEXT: %{{.*}} = cmpi "eq", %x, %y : i64 + CHECK-NEXT: %{{.*}} = cmpi "ult", %x, %y : i64 + CHECK-NEXT: %{{.*}} = cmpi "ule", %x, %y : i64 + CHECK-NEXT: %{{.*}} = cmpi "ugt", %x, %y : i64 + CHECK-NEXT: %{{.*}} = cmpi "uge", %x, %y : i64 + CHECK-NEXT: %{{.*}} = cmpi "ne", %x, %y : i64 + CHECK-NEXT: %{{.*}} = addi %x, %y : i64 + CHECK-NEXT: %{{.*}} = subi %x, %y : i64 + CHECK-NEXT: %[[add_1:.*]] = addi %sub, %x : i64 + CHECK-NEXT: %[[cst:.*]] = constant 1 : i64 + CHECK-NEXT: %{{.*}} = addi %[[add_1]], %[[cst]] : i64 + CHECK-NEXT: %{{.*}} = cmpf "ugt", %x1, %y1 : f32 + CHECK-NEXT: %{{.*}} = addf %x1, %y1 : f32 + CHECK-NEXT: %{{.*}} = "tfr.build_list"(%x1, %y1) : (f32, f32) -> !tfr.attr + CHECK-NEXT: %{{.*}} = constant true + CHECK-NEXT: tfr.return + CHECK-NEXT: } + + CHECK-LABEL: tfr.func @tf__test_non_num_attrs_op( + CHECK-SAME: %x: !tfr.attr{tfr.name="z"}, + CHECK-SAME: %y: !tfr.attr{tfr.name="x",tfr.default="hello"}, + CHECK-SAME: %z: !tfr.attr{tfr.name="y",tfr.default=f32}) -> () { + CHECK-NEXT: %{{.*}} = tfr.equal %x, %y -> i1 + CHECK-NEXT: %[[cst:.*]] = tfr.constant "test" -> !tfr.attr + CHECK-NEXT: %{{.*}} = tfr.equal %x, %[[cst]] -> i1 + CHECK-NEXT: %{{.*}} = tfr.equal %y, %z -> i1 + CHECK-NEXT: %{{.*}} = constant true + CHECK-NEXT: tfr.return + CHECK-NEXT: } + """ + self._check_code(mlir_code, mlir_code_exp) + + def test_tf_tensor_shape(self): + mlir_code = tfr_gen(sys.modules[__name__], '_tfr_shapes', [test_ops]) + mlir_code_exp = r""" + CHECK-LABEL: tfr.func @tf__test_identity_op(%x: !tfr.tensor) -> (!tfr.tensor) { + CHECK-NEXT: %[[shape:.*]] = tfr.get_shape %x -> !shape.shape + + CHECK-NEXT: 
%[[shape_1:.*]] = tfr.get_shape %x -> !shape.shape + CHECK-NEXT: %[[len:.*]] = shape.rank %[[shape_1]] : !shape.shape -> !shape.size + CHECK-NEXT: %[[index:.*]] = shape.size_to_index %[[len]] : !shape.size + CHECK-NEXT: %[[begin:.*]] = constant 0 : index + CHECK-NEXT: %[[step:.*]] = constant 1 : index + CHECK-NEXT: scf.for %[[itr_1:.*]] = %[[begin]] to %[[index]] step %[[step]] { + CHECK-NEXT: %[[size:.*]] = shape.get_extent %[[shape_1]], %[[itr_1]]: !shape.shape, index -> !shape.size + CHECK-NEXT: %[[elt:.*]] = shape.size_to_index %[[size]] : !shape.size + CHECK-NEXT: scf.yield + CHECK-NEXT: } + + CHECK-NEXT: %[[cst:.*]] = constant 1 : i64 + CHECK-NEXT: %[[len_1:.*]] = shape.rank %shape_1 : !shape.shape -> !shape.size + CHECK-NEXT: %[[len_size_1:.*]] = shape.size_to_index %[[len_1]] : !shape.size + CHECK-NEXT: %[[cst_1:.*]] = constant 2 : i64 + CHECK-NEXT: %[[begin_1:.*]] = index_cast %[[cst]] : i64 to index + CHECK-NEXT: %[[step_1:.*]] = index_cast %[[cst_1]] : i64 to index + CHECK-NEXT: scf.for %[[itr_3:.*]] = %[[begin_1]] to %[[len_size_1]] step %[[step_1]] + + CHECK: %[[cst:.*]] = tfr.constant i32 -> !tfr.attr + CHECK-NEXT: %[[Shape:.*]] = tfr.call @tf__shape(%x, %[[cst]]) : (!tfr.tensor, !tfr.attr) -> (!tfr.tensor) + CHECK-NEXT: %{{.*}} = constant true + CHECK-NEXT: tfr.return %x : !tfr.tensor + CHECK-NEXT: } + """ + self._check_code(mlir_code, mlir_code_exp) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/compiler/mlir/tfr/python/tfr_wrapper.cc b/tensorflow/compiler/mlir/tfr/python/tfr_wrapper.cc new file mode 100644 index 00000000000..b7372cffe2d --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/python/tfr_wrapper.cc @@ -0,0 +1,58 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/AsmState.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Verifier.h" // from @llvm-project +#include "mlir/Parser.h" // from @llvm-project +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h" +#include "tensorflow/python/lib/core/pybind11_lib.h" +#include "tensorflow/python/lib/core/pybind11_status.h" + +PYBIND11_MODULE(tfr_wrapper, m) { + m.def("verify", [](std::string input) { + mlir::MLIRContext ctx(/*loadAllDialects=*/true); + auto& registry = ctx.getDialectRegistry(); + registry.insert(); + ctx.getDialectRegistry().loadAll(&ctx); + + llvm::SourceMgr source_mgr = llvm::SourceMgr(); + source_mgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(input), + llvm::SMLoc()); + auto module = mlir::parseSourceFile(source_mgr, &ctx); + if (!module) { + return false; + } + + mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &ctx); + if (failed(mlir::verify(*module))) { + module->emitError("Invalid MLIR module: failed verification."); + return false; + } + return true; + }); +} diff --git a/tensorflow/compiler/mlir/tfr/resources/BUILD b/tensorflow/compiler/mlir/tfr/resources/BUILD new file mode 100644 index 00000000000..62ca65c5b57 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/resources/BUILD @@ -0,0 +1,97 @@ +load("//tensorflow:tensorflow.bzl", "tf_custom_op_library", "tf_gen_op_wrapper_py") +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") + +package( + default_visibility = [ + ":friends", + ], + licenses = ["notice"], # Apache 2.0 +) + +package_group( + name = "friends", + includes = ["//third_party/mlir:subpackages"], + packages = [ + "//learning/brain/experimental/mlir/tfr/...", + "//tensorflow/compiler/mlir/...", + ], +) + +filegroup( + name = "decomposition_lib", + srcs = ["decomposition_lib.mlir"], +) + +cc_library( + name = "composite_ops_cc", + srcs = ["composite_ops.cc"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], + alwayslink = 1, +) + +tf_custom_op_library( + name = "composite_ops.so", + srcs = [ + "composite_ops.cc", + ], +) + +tf_gen_op_wrapper_py( + name = "gen_composite_ops", + out = "gen_composite_ops.py", + deps = [ + ":composite_ops_cc", + ], +) + +tf_custom_op_py_library( + name = "composite_ops", + dso = [":composite_ops.so"], + kernels = [":composite_ops_cc"], + visibility = ["//visibility:public"], + deps = [ + ":gen_composite_ops", + ], +) + +cc_library( + name = "test_ops_cc", + srcs = ["test_ops.cc"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], + alwayslink = 1, +) + +tf_custom_op_library( + name = "test_ops.so", + srcs = [ + "test_ops.cc", + ], +) + +tf_gen_op_wrapper_py( + name = "gen_test_ops", + out = "gen_test_ops.py", + deps = [ + ":test_ops_cc", + ], +) + +tf_custom_op_py_library( + name = "test_ops", + dso = ["test_ops.so"], + 
kernels = [ + ":test_ops_cc", + ], + srcs_version = "PY2AND3", + deps = [ + ":gen_test_ops", + ], +) diff --git a/tensorflow/c/eager/parallel_device/parallel_device_ops.cc b/tensorflow/compiler/mlir/tfr/resources/composite_ops.cc similarity index 58% rename from tensorflow/c/eager/parallel_device/parallel_device_ops.cc rename to tensorflow/compiler/mlir/tfr/resources/composite_ops.cc index 1decffca047..8120625bc89 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_ops.cc +++ b/tensorflow/compiler/mlir/tfr/resources/composite_ops.cc @@ -14,13 +14,26 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" -// TODO(allenl): Figure out if we need this op, and if so whether we should move -// it to core TF. Right now the eager C API does some checking of op -// registrations before calling into custom devices, but we may be able to avoid -// that. -REGISTER_OP("DeviceID") - .Output("device_id: int64") - .SetIsStateful() - .SetShapeFn(tensorflow::shape_inference::ScalarShape); +namespace tensorflow { + +REGISTER_OP("MyAddN") + .Input("inputs: N * T") + .Output("sum: T") + .Attr("N: int >= 1") + .Attr("T: {numbertype, variant}") + .SetIsCommutative() + .SetIsAggregate(); + +REGISTER_OP("MyBiasedDense") + .Input("input: T") + .Input("weight: T") + .Input("bias: T") + .Output("out: T") + .Attr("T: {float, int8}") + .Attr("act: {'', 'relu', 'relu6'} = ''"); + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfr/resources/decomposition_lib.mlir b/tensorflow/compiler/mlir/tfr/resources/decomposition_lib.mlir new file mode 100644 index 00000000000..f67d24c9fec --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/resources/decomposition_lib.mlir @@ -0,0 +1,109 @@ +// A test resource file which contains some pre-defined internal tfr.functions +// for decomposition and external tfr.functions for raising the decomposition +// result to the ops in the TF dialect. +// +// All the tfr.func functions are supposed to be translated from the Python +// function with tf.composite annotation. +// All the external tfr.func functions modeles the op signature defined by +// OpDefs. + +tfr.func @tf__my_add_n(%values: !tfr.tensor_list, + %n: i64 {tfr.name="N"}) -> !tfr.tensor { + %index = constant 0 : index + %cst = constant 1 : i64 + %eq = cmpi "eq", %n, %cst : i64 + %v1 = tfr.get_element %values[%index] : (!tfr.tensor_list, index) -> !tfr.tensor + %res = scf.if %eq -> !tfr.tensor { + scf.yield %v1 : !tfr.tensor + } else { + %step = index_cast %cst : i64 to index + %end = index_cast %n : i64 to index + %reduce = scf.for %i = %step to %end step %step iter_args(%reduce_iter=%v1) -> !tfr.tensor { + %v = tfr.get_element %values[%i] : (!tfr.tensor_list, index) -> !tfr.tensor + %reduce_next = tfr.call @tf__add(%reduce_iter, %v) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor + scf.yield %reduce_next : !tfr.tensor + } + scf.yield %reduce : !tfr.tensor + } + tfr.return %res : !tfr.tensor +} + +// Translated from tf.compose Python function. 
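+//
+// Roughly, such a Python composition looks like the sketch below. The
+// `Composite` decorator name is an assumption used only for illustration;
+// the body mirrors the semantics of the tfr.func that follows.
+//
+//   @Composite('MyBiasedDense')
+//   def my_biased_dense(input_, weight, bias, act=''):
+//     res = tf.raw_ops.MatMul(a=input_, b=weight) + bias
+//     if act == 'relu':
+//       return tf.raw_ops.Relu(features=res)
+//     if act == 'relu6':
+//       return tf.raw_ops.Relu6(features=res)
+//     return res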
+tfr.func @tf__my_biased_dense(%input: !tfr.tensor, %weight: !tfr.tensor,
+    %bias: !tfr.tensor,
+    %act: !tfr.attr{tfr.name="act", tfr.default=""}) -> !tfr.tensor {
+  %dot = tfr.call @tf__mat_mul(%input, %weight) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor
+  %add = tfr.call @tf__add(%dot, %bias) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor
+
+  %relu = tfr.constant "relu" -> !tfr.attr
+  %relu6 = tfr.constant "relu6" -> !tfr.attr
+
+  %is_relu = tfr.equal %act, %relu -> i1
+  %res = scf.if %is_relu -> !tfr.tensor {
+    %applied_relu = tfr.call @tf__relu(%add) : (!tfr.tensor) -> !tfr.tensor
+    scf.yield %applied_relu : !tfr.tensor
+  } else {
+    %is_relu6 = tfr.equal %act, %relu6 -> i1
+    %res1 = scf.if %is_relu6 -> !tfr.tensor {
+      %applied_relu6 = tfr.call @tf__relu6(%add) : (!tfr.tensor) -> !tfr.tensor
+      scf.yield %applied_relu6 : !tfr.tensor
+    } else {
+      scf.yield %add : !tfr.tensor
+    }
+    scf.yield %res1 : !tfr.tensor
+  }
+  tfr.return %res : !tfr.tensor
+}
+
+// This is a wrong decomposition, used to verify that tf.Elu isn't decomposed
+// since its kernel has been registered.
+tfr.func @tf__elu_(%input: !tfr.tensor) -> !tfr.tensor {
+  tfr.return %input : !tfr.tensor
+}
+
+// Translated from:
+//
+// REGISTER_OP("Add")
+//     .Input("x: T")
+//     .Input("y: T")
+//     .Output("z: T")
+//     .Attr(
+//         "T: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, "
+//         "complex64, complex128, string}")
+tfr.func @tf__add_(!tfr.tensor, !tfr.tensor)
+    -> !tfr.tensor attributes{T}
+
+// Translated from:
+//
+// REGISTER_OP("MatMul")
+//     .Input("a: T")
+//     .Input("b: T")
+//     .Output("product: T")
+//     .Attr("transpose_a: bool = false")
+//     .Attr("transpose_b: bool = false")
+//     .Attr("T: {bfloat16, half, float, double, int32, int64, complex64, complex128}")
+// T is a derived attribute.
+// transpose_a and transpose_b are materialized attributes.
tfr.func @tf__mat_mul_(!tfr.tensor, !tfr.tensor,
+    i1 {tfr.name="transpose_a", tfr.default=false},
+    i1 {tfr.name="transpose_b", tfr.default=false})
+    -> !tfr.tensor attributes{T}
+
+// Translated from:
+//
+// REGISTER_OP("Relu")
+//     .Input("features: T")
+//     .Output("activations: T")
+//     .Attr("T: {realnumbertype, qint8}")
+// T is a derived attribute.
+tfr.func @tf__relu_(!tfr.tensor) -> !tfr.tensor attributes{T}
+
+
+// Translated from:
+//
+// REGISTER_OP("Relu6")
+//     .Input("features: T")
+//     .Output("activations: T")
+//     .Attr("T: {realnumbertype}")
+// T is a derived attribute.
+tfr.func @tf__relu6_(!tfr.tensor) -> !tfr.tensor attributes{T}
diff --git a/tensorflow/compiler/mlir/tfr/resources/test_ops.cc b/tensorflow/compiler/mlir/tfr/resources/test_ops.cc
new file mode 100644
index 00000000000..3aaa0850805
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfr/resources/test_ops.cc
@@ -0,0 +1,86 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("TestNoOp"); + +REGISTER_OP("TestIdentityOp") + .Input("input: T") + .Output("output: T") + .Attr("T: numbertype"); + +REGISTER_OP("TestIdentityNOp") + .Input("input: N * T") + .Output("output: N * T") + .Attr("N: int >= 1") + .Attr("T: numbertype"); + +REGISTER_OP("TestInputNOp") + .Input("input: N * T") + .Output("output: T") + .Attr("N: int >= 1") + .Attr("T: numbertype"); + +REGISTER_OP("TestOutputNOp") + .Input("input: T") + .Output("output: N * T") + .Attr("N: int >= 1") + .Attr("T: numbertype"); + +REGISTER_OP("TestTwoInputsOp") + .Input("lhs: T") + .Input("rhs: T") + .Output("output: T") + .Attr("T: numbertype") + .Attr("pred: bool = false"); + +REGISTER_OP("TestComplexTFOp") + .Input("lhs: T") + .Input("rhs: Tlen") + .Output("output: N * T") + .Attr("N: int >= 1") + .Attr("T: numbertype") + .Attr("Tlen: {int32, int64} = DT_INT64"); + +REGISTER_OP("TestNumAttrsOp") + .Attr("x1: int = -10") + .Attr("y1: int = 1") + .Attr("x2: float = 0.0") + .Attr("y2: float = -3.0"); + +REGISTER_OP("TestNonNumAttrsOp") + .Attr("z: shape") + .Attr("x: string = 'hello'") + .Attr("y: type = DT_FLOAT"); + +REGISTER_OP("TestThreeInputsOp") + .Input("x: T") + .Input("y: T") + .Input("z: T") + .Output("output: T") + .Attr("T: numbertype") + .Attr("act: {'x', 'y', 'z'} = 'z'"); + +REGISTER_OP("TestTwoOutputsOp") + .Input("input: T") + .Output("output1: T") + .Output("output2: T") + .Attr("T: numbertype"); + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfr/tests/control_flow.mlir b/tensorflow/compiler/mlir/tfr/tests/control_flow.mlir new file mode 100644 index 00000000000..8dacd57653f --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/tests/control_flow.mlir @@ -0,0 +1,57 @@ +// RUN: tfr-opt %s -tfr-decompose -verify-diagnostics -split-input-file | FileCheck %s + +tfr.func @tf__my_pack(%values: !tfr.tensor_list, + %n: i32 {tfr.name="N"}, + %axis: i32 {tfr.name="axis"}) -> !tfr.tensor { + %index = constant 0 : index + %cst = constant 1 : i32 + %eq = cmpi "eq", %n, %cst : i32 + %v1 = tfr.get_element %values[%index] : (!tfr.tensor_list, index) -> !tfr.tensor + %temp = tfr.call @tf__expand_dims(%v1, %axis) : (!tfr.tensor, i32) -> !tfr.tensor + %res = scf.if %eq -> !tfr.tensor { + scf.yield %temp : !tfr.tensor + } else { + %step = index_cast %cst : i32 to index + %end = index_cast %n : i32 to index + %reduce = scf.for %i = %step to %end step %step iter_args(%reduce_iter=%temp) -> !tfr.tensor { + %v = tfr.get_element %values[%i] : (!tfr.tensor_list, index) -> !tfr.tensor + %temp1 = tfr.call @tf__expand_dims(%v, %axis) : (!tfr.tensor, i32) -> !tfr.tensor + %reduce_next = tfr.call @tf__risc_concat(%reduce_iter, %temp1, %axis) : (!tfr.tensor, !tfr.tensor, i32) -> !tfr.tensor + scf.yield %reduce_next : !tfr.tensor + } + scf.yield %reduce : !tfr.tensor + } + tfr.return %res : !tfr.tensor +} + +// CHECK-LABEL: pack_one +func @pack_one(%arg0: tensor<2x3xf32>) -> tensor<1x2x3xf32> { + %0 = "tf.MyPack"(%arg0) {N=1:i32, axis=0:i32} : (tensor<2x3xf32>) -> tensor<1x2x3xf32> + return %0 : tensor<1x2x3xf32> + +// CHECK-NEXT: %[[AXIS:.*]] = constant 0 : i32 +// CHECK-NEXT: %[[CAST:.*]] = "tfr.cast"(%arg0) : (tensor<2x3xf32>) -> !tfr.tensor +// CHECK-NEXT: %[[ED:.*]] = tfr.call @tf__expand_dims(%[[CAST]], %[[AXIS]]) : (!tfr.tensor, i32) -> !tfr.tensor +// CHECK-NEXT: %[[BACK:.*]] = 
"tfr.cast"(%[[ED]]) : (!tfr.tensor) -> tensor<1x2x3xf32> +// CHECK-NEXT: return %[[BACK]] : tensor<1x2x3xf32> +} + +// CHECK-LABEL: pack_multiple +func @pack_multiple(%arg0: tensor<2x3xf32>, + %arg1: tensor<2x3xf32>, + %arg2: tensor<2x3xf32>) -> tensor<3x2x3xf32> { + %0 = "tf.MyPack"(%arg0, %arg1, %arg2) {N=3:i32, axis=0:i32} : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<3x2x3xf32> + return %0 : tensor<3x2x3xf32> + +// CHECK-NEXT: %[[AXIS:.*]] = constant 0 : i32 +// CHECK-NEXT: %[[CAST0:.*]] = "tfr.cast"(%arg0) : (tensor<2x3xf32>) -> !tfr.tensor +// CHECK-NEXT: %[[CAST1:.*]] = "tfr.cast"(%arg1) : (tensor<2x3xf32>) -> !tfr.tensor +// CHECK-NEXT: %[[CAST2:.*]] = "tfr.cast"(%arg2) : (tensor<2x3xf32>) -> !tfr.tensor +// CHECK-NEXT: %[[EX0:.*]] = tfr.call @tf__expand_dims(%[[CAST0]], %[[AXIS]]) : (!tfr.tensor, i32) -> !tfr.tensor +// CHECK-NEXT: %[[EX1:.*]] = tfr.call @tf__expand_dims(%[[CAST1]], %[[AXIS]]) : (!tfr.tensor, i32) -> !tfr.tensor +// CHECK-NEXT: %[[CONCAT1:.*]] = tfr.call @tf__risc_concat(%[[EX0]], %[[EX1]], %c0_i32) : (!tfr.tensor, !tfr.tensor, i32) -> !tfr.tensor +// CHECK-NEXT: %[[EX2:.*]] = tfr.call @tf__expand_dims(%[[CAST2]], %[[AXIS]]) : (!tfr.tensor, i32) -> !tfr.tensor +// CHECK-NEXT: %[[CONCAT2:.*]] = tfr.call @tf__risc_concat(%[[CONCAT1]], %[[EX2]], %[[AXIS]]) : (!tfr.tensor, !tfr.tensor, i32) -> !tfr.tensor +// CHECK-NEXT: %[[BACK:.*]] = "tfr.cast"(%[[CONCAT2]]) : (!tfr.tensor) -> tensor<3x2x3xf32> +// CHECK-NEXT: return %[[BACK]] : tensor<3x2x3xf32> +} diff --git a/tensorflow/compiler/mlir/tfr/tests/decompose.mlir b/tensorflow/compiler/mlir/tfr/tests/decompose.mlir new file mode 100644 index 00000000000..97f12c9fedb --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/tests/decompose.mlir @@ -0,0 +1,84 @@ +// RUN: tfr-opt %s -tfr-decompose -verify-diagnostics -split-input-file | FileCheck %s + +// CHECK-LABEL: @tf__fake_no_op +tfr.func @tf__fake_no_op(%arg0: !tfr.tensor) -> !tfr.tensor { + tfr.return %arg0 : !tfr.tensor + +// CHECK-NEXT: tfr.return %arg0 : !tfr.tensor +} + +// CHECK-LABEL: @tf__intermediate +tfr.func @tf__intermediate(%arg0: !tfr.tensor) -> !tfr.tensor { + %0 = tfr.call @tf__risc(%arg0) : (!tfr.tensor) -> !tfr.tensor + tfr.return %0 : !tfr.tensor + +// CHECK-NEXT: %[[id:.*]] = tfr.call @tf__risc(%arg0) : (!tfr.tensor) -> !tfr.tensor +// CHECK-NEXT: tfr.return %[[id]] : !tfr.tensor +} + +// CHECK-LABEL: @tf__fused_n +tfr.func @tf__fused_n( + %arg0: !tfr.tensor, + %arg1: !tfr.tensor_list, + %arg2: index {tfr.name="A",tfr.default=1:index}) + -> !tfr.tensor_list { + %0 = tfr.call @tf__intermediate(%arg0) : (!tfr.tensor) -> !tfr.tensor + %1 = tfr.get_element %arg1[%arg2] : (!tfr.tensor_list, index) -> !tfr.tensor + %2 = tfr.call @tf__intermediate(%1) : (!tfr.tensor) -> !tfr.tensor + %3 = "tfr.build_list"(%0, %2) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list + tfr.return %3 : !tfr.tensor_list + +// CHECK-NEXT: %[[id1:.*]] = tfr.call @tf__intermediate(%arg0) : (!tfr.tensor) -> !tfr.tensor +// CHECK-NEXT: %[[ge:.*]] = tfr.get_element %arg1[%arg2] : (!tfr.tensor_list, index) -> !tfr.tensor +// CHECK-NEXT: %[[id2:.*]] = tfr.call @tf__intermediate(%[[ge]]) : (!tfr.tensor) -> !tfr.tensor +// CHECK-NEXT: %[[bl:.*]] = "tfr.build_list"(%[[id1]], %[[id2]]) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list +// CHECK-NEXT: tfr.return %[[bl]] : !tfr.tensor_list +} + +//------------------------ + +// CHECK-LABEL: decompose_tf_no_op +func @decompose_tf_no_op(%arg0: tensor<1x2x3x4x!tf.string>) -> tensor<1x2x3x4x!tf.string> { + %0 = "tf.FakeNoOp"(%arg0) : 
(tensor<1x2x3x4x!tf.string>) -> tensor<1x2x3x4x!tf.string> + return %0 : tensor<1x2x3x4x!tf.string> + +// CHECK-NEXT: return %arg0 +} + +// CHECK-LABEL: decompose_tf_intermediate +func @decompose_tf_intermediate(%arg0: tensor<1x2x3x4x!tf.string>) -> tensor<1x2x3x4x!tf.string> { + %0 = "tf.Intermediate"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> tensor<1x2x3x4x!tf.string> + return %0 : tensor<1x2x3x4x!tf.string> + +// CHECK-NEXT: %[[casted:.*]] = "tfr.cast"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> !tfr.tensor +// CHECK-NEXT: %[[id:.*]] = tfr.call @tf__risc(%[[casted]]) : (!tfr.tensor) -> !tfr.tensor +// CHECK-NEXT: %[[back:.*]] = "tfr.cast"(%[[id]]) : (!tfr.tensor) -> tensor<1x2x3x4x!tf.string> +// CHECK-NEXT: return %[[back]] +} + +// CHECK-LABEL: decompose_fused_n_default +func @decompose_fused_n_default(%arg0: tensor<1x2x3x4x!tf.string>, %arg1: tensor, %arg2: tensor) -> tensor { + %0:2 = "tf.FusedN"(%arg0, %arg1, %arg2) : (tensor<1x2x3x4x!tf.string>, tensor, tensor) -> (tensor<1x2x3x4x!tf.string>, tensor) + return %0#1 : tensor + +// CHECK-NEXT: %[[in0:.*]] = "tfr.cast"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> !tfr.tensor +// CHECK-NEXT: %[[in2:.*]] = "tfr.cast"(%arg2) : (tensor) -> !tfr.tensor +// CHECK-NEXT: %[[id0:.*]] = tfr.call @tf__risc(%[[in0]]) : (!tfr.tensor) -> !tfr.tensor +// CHECK-NEXT: %[[id2:.*]] = tfr.call @tf__risc(%[[in2]]) : (!tfr.tensor) -> !tfr.tensor +// CHECK-NEXT: %[[back:.*]] = "tfr.cast"(%[[id2]]) : (!tfr.tensor) -> tensor +// CHECK-NEXT: return %[[back]] : tensor +} + +// CHECK-LABEL: decompose_fused_n +func @decompose_fused_n(%arg0: tensor<1x2x3x4x!tf.string>, %arg1: tensor, %arg2: tensor) -> tensor { + %0:2 = "tf.FusedN"(%arg0, %arg1, %arg2) {A=0:index} : (tensor<1x2x3x4x!tf.string>, tensor, tensor) -> (tensor<1x2x3x4x!tf.string>, tensor) + return %0#1 : tensor + +// CHECK-NEXT: %[[in0:.*]] = "tfr.cast"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> !tfr.tensor +// CHECK-NEXT: %[[in1:.*]] = "tfr.cast"(%arg1) : (tensor) -> !tfr.tensor +// CHECK-NEXT: %[[id0:.*]] = tfr.call @tf__risc(%[[in0]]) : (!tfr.tensor) -> !tfr.tensor +// CHECK-NEXT: %[[id1:.*]] = tfr.call @tf__risc(%[[in1]]) : (!tfr.tensor) -> !tfr.tensor +// CHECK-NEXT: %[[back:.*]] = "tfr.cast"(%[[id1]]) : (!tfr.tensor) -> tensor +// CHECK-NEXT: return %[[back]] : tensor +} + diff --git a/tensorflow/compiler/mlir/tfr/tests/end2end.mlir b/tensorflow/compiler/mlir/tfr/tests/end2end.mlir new file mode 100644 index 00000000000..5738020ccdb --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/tests/end2end.mlir @@ -0,0 +1,235 @@ +// RUN: tfr-opt %s -tfr-decompose -tfr-raise-to-tf -canonicalize -verify-diagnostics -split-input-file | FileCheck %s + +//=================> User models, from GraphDef <==================== + +// CHECK-LABEL: my_identity +func @my_identity(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { + %0 = "tf.MyIdentity"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> + return %0 : tensor<2x3xf32> + +// CHECK-NEXT: return %arg0 : tensor<2x3xf32> +} + +// CHECK-LABEL: my_rsqrt +func @my_rsqrt(%arg0: tensor<2x3xf32>) -> tensor<3x2x3xf32> { + %0 = "tf.MyRsqrt"(%arg0) : (tensor<2x3xf32>) -> tensor<3x2x3xf32> + return %0 : tensor<3x2x3xf32> + +// CHECK-NEXT: %[[RE:.*]] = "tf.RiscReciprocal"(%arg0) : (tensor<2x3xf32>) -> tensor<*xf32> +// CHECK-NEXT: %[[SQRT:.*]] = "tf.RiscSqrt"(%[[RE]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[SQRT]]) {shape = #tf.shape<3x2x3>} : (tensor<*xf32>) -> tensor<3x2x3xf32> +// CHECK-NEXT: return %[[ES]] : tensor<3x2x3xf32> +} + +// CHECK-LABEL: 
my_leaky_relu +func @my_leaky_relu(%arg0: tensor<2x3xf32>) -> tensor<3x2x3xf32> { + %0 = "tf.MyLeakyRelu"(%arg0) {alpha=3.0 : f32} : (tensor<2x3xf32>) -> tensor<3x2x3xf32> + return %0 : tensor<3x2x3xf32> + +// CHECK-NEXT: %[[ALPHA:.*]] = "tf.Const"() {value = dense<3.000000e+00> : tensor} : () -> tensor +// CHECK-NEXT: %[[SHAPE:.*]] = "tf.RiscShape"(%arg0) {T = i32} : (tensor<2x3xf32>) -> tensor<*xi32> +// CHECK-NEXT: %[[ALPHA1:.*]] = "tf.RiscBroadcast"(%[[ALPHA]], %[[SHAPE]]) : (tensor, tensor<*xi32>) -> tensor<*xf32> +// CHECK-NEXT: %[[MAX:.*]] = "tf.RiscMaximum"(%arg0, %[[ALPHA1]]) : (tensor<2x3xf32>, tensor<*xf32>) -> tensor<*xf32> +// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[MAX]]) {shape = #tf.shape<3x2x3>} : (tensor<*xf32>) -> tensor<3x2x3xf32> +// CHECK-NEXT: return %[[ES]] : tensor<3x2x3xf32> +} + +// CHECK-LABEL: my_leaky_relu_with_default +func @my_leaky_relu_with_default(%arg0: tensor<2x3xf32>) -> tensor<3x2x3xf32> { + %0 = "tf.MyLeakyRelu"(%arg0) : (tensor<2x3xf32>) -> tensor<3x2x3xf32> + return %0 : tensor<3x2x3xf32> + +// CHECK-NEXT: %[[ALPHA:.*]] = "tf.Const"() {value = dense<2.000000e-01> : tensor} : () -> tensor +// CHECK-NEXT: %[[SHAPE:.*]] = "tf.RiscShape"(%arg0) {T = i32} : (tensor<2x3xf32>) -> tensor<*xi32> +// CHECK-NEXT: %[[ALPHA1:.*]] = "tf.RiscBroadcast"(%[[ALPHA]], %[[SHAPE]]) : (tensor, tensor<*xi32>) -> tensor<*xf32> +// CHECK-NEXT: %[[MAX:.*]] = "tf.RiscMaximum"(%arg0, %[[ALPHA1]]) : (tensor<2x3xf32>, tensor<*xf32>) -> tensor<*xf32> +// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[MAX]]) {shape = #tf.shape<3x2x3>} : (tensor<*xf32>) -> tensor<3x2x3xf32> +// CHECK-NEXT: return %[[ES]] : tensor<3x2x3xf32> +} + +// CHECK-LABEL: my_cast +func @my_cast(%arg0: tensor<2x3xf32>) -> tensor<2x3xi32> { + %0 = "tf.MyCast"(%arg0) {Tout=i32} : (tensor<2x3xf32>) -> tensor<2x3xi32> + return %0 : tensor<2x3xi32> + +// CHECK-NEXT: %[[CAST:.*]] = "tf.RiscCast"(%arg0) {Tout = i32} : (tensor<2x3xf32>) -> tensor<*xi32> +// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[CAST]]) {shape = #tf.shape<2x3>} : (tensor<*xi32>) -> tensor<2x3xi32> +// CHECK-NEXT: return %[[ES]] : tensor<2x3xi32> +} + +// CHECK-LABEL: my_pack_single_input +func @my_pack_single_input(%arg0: tensor<2x3xf32>) -> tensor<3x2x3xf32> { + %0 = "tf.MyPack"(%arg0) {N=1:i32, axis=0:i32} : (tensor<2x3xf32>) -> tensor<3x2x3xf32> + return %0 : tensor<3x2x3xf32> + +// CHECK-NEXT: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK-NEXT: %[[ED:.*]] = "tf.ExpandDims"(%arg0, %[[AXIS]]) : (tensor<2x3xf32>, tensor) -> tensor<*xf32> +// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[ED]]) {shape = #tf.shape<3x2x3>} : (tensor<*xf32>) -> tensor<3x2x3xf32> +// CHECK-NEXT: return %[[ES]] : tensor<3x2x3xf32> +} + +// CHECK-LABEL: my_pack_multiple_inputs +func @my_pack_multiple_inputs(%arg0: tensor<2x3xf32>, %arg1: tensor<2x3xf32>, %arg2: tensor<2x3xf32>) -> tensor<3x2x3xf32> { + %0 = "tf.MyPack"(%arg0, %arg1, %arg2) {N=3:i32, axis=0:i32} : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<3x2x3xf32> + return %0 : tensor<3x2x3xf32> + +// CHECK-NEXT: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK-NEXT: %[[ED0:.*]] = "tf.ExpandDims"(%arg0, %[[AXIS]]) : (tensor<2x3xf32>, tensor) -> tensor<*xf32> +// CHECK-NEXT: %[[ED1:.*]] = "tf.ExpandDims"(%arg1, %[[AXIS]]) : (tensor<2x3xf32>, tensor) -> tensor<*xf32> +// CHECK-NEXT: %[[CC0:.*]] = "tf.RiscConcat"(%[[ED0]], %[[ED1]]) {axis = 0 : i32} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> +// CHECK-NEXT: %[[ED2:.*]] = 
"tf.ExpandDims"(%arg2, %[[AXIS]]) : (tensor<2x3xf32>, tensor) -> tensor<*xf32> +// CHECK-NEXT: %[[CC1:.*]] = "tf.RiscConcat"(%[[CC0]], %[[ED2]]) {axis = 0 : i32} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> +// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[CC1]]) {shape = #tf.shape<3x2x3>} : (tensor<*xf32>) -> tensor<3x2x3xf32> +// CHECK-NEXT: return %[[ES]] : tensor<3x2x3xf32> +} + +// CHECK-LABEL: my_add_n_single_input +func @my_add_n_single_input(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { + %0 = "tf.MyAddN"(%arg0) {N=1:i32} : (tensor<2x3xf32>) -> tensor<2x3xf32> + return %0 : tensor<2x3xf32> + +// CHECK-NEXT: return %arg0 : tensor<2x3xf32> +} + +// CHECK-LABEL: my_add_n_multiple_inputs +func @my_add_n_multiple_inputs(%arg0: tensor<2x3xf32>, %arg1: tensor<2x3xf32>, %arg2: tensor<2x3xf32>) -> tensor<2x3xf32> { + %0 = "tf.MyAddN"(%arg0, %arg1, %arg2) {N=3:i32} : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> + return %0 : tensor<2x3xf32> + +// CHECK-NEXT: %[[ADD0:.*]] = "tf.RiscAdd"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<*xf32> +// CHECK-NEXT: %[[ADD1:.*]] = "tf.RiscAdd"(%[[ADD0]], %arg2) : (tensor<*xf32>, tensor<2x3xf32>) -> tensor<*xf32> +// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[ADD1]]) {shape = #tf.shape<2x3>} : (tensor<*xf32>) -> tensor<2x3xf32> +// CHECK-NEXT: return %[[ES]] : tensor<2x3xf32> +} + +// CHECK-LABEL: my_map_and_batch_dataset +func @my_map_and_batch_dataset(%input: tensor<*x!tf.variant>, + %other1: tensor<*xf32>, + %other2: tensor<*xi32>) -> tensor<*x!tf.variant> { + %0 = "tf.MyMapAndBatchDataset"(%input, %other1, %other2) + {batch_size=1000 : i64, num_parallel_calls = 8 : i64, drop_remainder = 0 : i1, + func = @"__some_func", output_types = [f32], output_shapes = [#tf.shape<>], preserve_cardinality = true} + : (tensor<*x!tf.variant>, tensor<*xf32>, tensor<*xi32>) -> tensor<*x!tf.variant> + return %0 : tensor<*x!tf.variant> + +// CHECK-NEXT: %[[BATCH:.*]] = "tf.Const"() {value = dense<1000> : tensor} : () -> tensor +// CHECK-NEXT: %[[PARAL:.*]] = "tf.Const"() {value = dense<8> : tensor} : () -> tensor +// CHECK-NEXT: %[[KEEP:.*]] = "tf.Const"() {value = dense : tensor} : () -> tensor +// CHECK-NEXT: %[[CAST:.*]] = "tf.Cast"(%arg2) {Truncate = false} : (tensor<*xi32>) -> tensor<*xf32> +// CHECK-NEXT: %[[RET:.*]] = "tf.MapAndBatchDatasetV0"(%arg0, %[[BATCH]], %[[PARAL]], %[[KEEP]], %arg1, %[[CAST]]) +// CHECK-SAME: {f = @__some_func, output_shapes = [#tf.shape<>], output_types = [f32], preserve_cardinality = true} : (tensor<*x!tf.variant>, tensor, tensor, tensor, tensor<*xf32>, tensor<*xf32>) -> tensor<*x!tf.variant> +// CHECK-NEXT: return %[[RET]] : tensor<*x!tf.variant> +} + +//=================> decomposition functions, translated from tf.compose api <==================== +tfr.func @tf__my_identity(%value: !tfr.tensor) -> !tfr.tensor { + tfr.return %value : !tfr.tensor +} + +tfr.func @tf__my_cast(%value: !tfr.tensor, %tout: !tfr.attr{tfr.name="Tout"}) -> !tfr.tensor { + %0 = tfr.call @tf__risc_cast(%value, %tout) : (!tfr.tensor, !tfr.attr) -> !tfr.tensor + tfr.return %0 : !tfr.tensor +} + +tfr.func @tf__my_rsqrt(%value: !tfr.tensor) -> !tfr.tensor { + %1 = tfr.call @tf__risc_reciprocal(%value) : (!tfr.tensor) -> !tfr.tensor + %2 = tfr.call @tf__risc_sqrt(%1) : (!tfr.tensor) -> !tfr.tensor + tfr.return %2 : !tfr.tensor +} + +tfr.func @tf__my_leaky_relu(%value: !tfr.tensor, %alpha: f32 {tfr.name="alpha", tfr.default=0.2:f32}) -> !tfr.tensor { + %1 = tfr.call @tf__risc_shape(%value) : (!tfr.tensor) -> !tfr.tensor + 
%2 = "tfr.constant_tensor"(%alpha) : (f32) -> tensor + %t = "tfr.cast"(%2) : (tensor) -> !tfr.tensor + %3 = tfr.call @tf__risc_broadcast(%t, %1) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor + %4 = tfr.call @tf__risc_maximum(%value, %3) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor + tfr.return %4 : !tfr.tensor +} + +// TODO(fengliuai): use shape dialect to manipulate the shape then this can be decomposed further. +tfr.func @tf__my_expand_dims(%value: !tfr.tensor, %axis: i32 {tfr.name="axis"}) -> !tfr.tensor { + %axis_cst = "tfr.constant_tensor"(%axis) : (i32) -> tensor + %dim = "tfr.cast"(%axis_cst) : (tensor) -> !tfr.tensor + %0 = tfr.call @tf__expand_dims(%value, %dim) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor + tfr.return %0 : !tfr.tensor +} + +tfr.func @tf__my_pack(%values: !tfr.tensor_list, + %n: i32 {tfr.name="N"}, + %axis: i32 {tfr.name="axis"}) -> !tfr.tensor { + %index = constant 0 : index + %cst = constant 1 : i32 + %eq = cmpi "eq", %n, %cst : i32 + %v1 = tfr.get_element %values[%index] : (!tfr.tensor_list, index) -> !tfr.tensor + %temp = tfr.call @tf__my_expand_dims(%v1, %axis) : (!tfr.tensor, i32) -> !tfr.tensor + %res = scf.if %eq -> !tfr.tensor { + scf.yield %temp : !tfr.tensor + } else { + %step = index_cast %cst : i32 to index + %end = index_cast %n : i32 to index + %reduce = scf.for %i = %step to %end step %step iter_args(%reduce_iter=%temp) -> !tfr.tensor { + %v = tfr.get_element %values[%i] : (!tfr.tensor_list, index) -> !tfr.tensor + %temp1 = tfr.call @tf__my_expand_dims(%v, %axis) : (!tfr.tensor, i32) -> !tfr.tensor + %reduce_next = tfr.call @tf__risc_concat(%reduce_iter, %temp1, %axis) : (!tfr.tensor, !tfr.tensor, i32) -> !tfr.tensor + scf.yield %reduce_next : !tfr.tensor + } + scf.yield %reduce : !tfr.tensor + } + tfr.return %res : !tfr.tensor +} + +tfr.func @tf__my_add_n(%values: !tfr.tensor_list, + %n: i32 {tfr.name="N"}) -> !tfr.tensor { + %index = constant 0 : index + %cst = constant 1 : i32 + %eq = cmpi "eq", %n, %cst : i32 + %v1 = tfr.get_element %values[%index] : (!tfr.tensor_list, index) -> !tfr.tensor + %res = scf.if %eq -> !tfr.tensor { + scf.yield %v1 : !tfr.tensor + } else { + %step = index_cast %cst : i32 to index + %end = index_cast %n : i32 to index + %reduce = scf.for %i = %step to %end step %step iter_args(%reduce_iter=%v1) -> !tfr.tensor { + %v = tfr.get_element %values[%i] : (!tfr.tensor_list, index) -> !tfr.tensor + %reduce_next = tfr.call @tf__risc_add(%reduce_iter, %v) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor + scf.yield %reduce_next : !tfr.tensor + } + scf.yield %reduce : !tfr.tensor + } + tfr.return %res : !tfr.tensor +} + +tfr.func @tf__my_map_and_batch_dataset( + %input_dataset: !tfr.tensor, + %other_arguments: !tfr.tensor_list, + %batch_size: i64 {tfr.name="batch_size"}, + %num_parallel_calls: i64 {tfr.name="num_parallel_calls"}, + %drop_remainder: i1 {tfr.name="drop_remainder"}, + %f: !tfr.attr {tfr.name="func"}, + %output_types: !tfr.attr {tfr.name="output_types"}, + %output_shapes: !tfr.attr {tfr.name="output_shapes"}, + %preserve_cardinality: i1 {tfr.name="preserve_cardinality", tfr.default=false}) -> !tfr.tensor { + %batch = "tfr.constant_tensor"(%batch_size) : (i64) -> tensor + %batch1 = "tfr.cast"(%batch) : (tensor) -> !tfr.tensor + %calls = "tfr.constant_tensor"(%num_parallel_calls) : (i64) -> tensor + %calls1 = "tfr.cast"(%calls) : (tensor) -> !tfr.tensor + %drop = "tfr.constant_tensor"(%drop_remainder) : (i1) -> tensor + %drop1 = "tfr.cast"(%drop) : (tensor) -> !tfr.tensor + %ret = tfr.call 
@tf__map_and_batch_dataset_v0(%input_dataset, %batch1, %calls1, %drop1, %other_arguments, %f, %output_types, %output_shapes, %preserve_cardinality) + : (!tfr.tensor, !tfr.tensor, !tfr.tensor, !tfr.tensor, !tfr.tensor_list, !tfr.attr, !tfr.attr, !tfr.attr, i1) -> !tfr.tensor + tfr.return %ret : !tfr.tensor +} + +//=================> signatures of the primitive ops with kernels, modeled as external TFR function <== +tfr.func @tf__risc_cast_(!tfr.tensor, !tfr.attr{tfr.name="Tout"}) -> !tfr.tensor attributes{Tout} +tfr.func @tf__risc_add_(!tfr.tensor, !tfr.tensor) -> !tfr.tensor attributes{T} +tfr.func @tf__risc_concat_(!tfr.tensor, !tfr.tensor, i32{tfr.name="axis"}) -> !tfr.tensor attributes{T} +tfr.func @tf__risc_broadcast_(!tfr.tensor, !tfr.tensor) -> !tfr.tensor attributes{T, Tidx} +tfr.func @tf__risc_reciprocal_(!tfr.tensor) -> !tfr.tensor attributes{T} +tfr.func @tf__risc_sqrt_(!tfr.tensor) -> !tfr.tensor attributes{T} +tfr.func @tf__risc_shape_(!tfr.tensor, !tfr.attr{tfr.name="T", tfr.default=i32}) -> !tfr.tensor attributes{T} +tfr.func @tf__risc_maximum_(!tfr.tensor, !tfr.tensor) -> !tfr.tensor attributes{T} +tfr.func @tf__expand_dims_(!tfr.tensor, !tfr.tensor) -> !tfr.tensor attributes{T, Tdim} +tfr.func @tf__map_and_batch_dataset_v0_(!tfr.tensor, !tfr.tensor, !tfr.tensor, !tfr.tensor, !tfr.tensor_list, + !tfr.attr{tfr.name="f"}, !tfr.attr{tfr.name="output_types"}, !tfr.attr{tfr.name="output_shapes"}, i1{tfr.name="preserve_cardinality"}) + -> !tfr.tensor attributes{T, Targuments} diff --git a/tensorflow/compiler/mlir/tfr/tests/ops.mlir b/tensorflow/compiler/mlir/tfr/tests/ops.mlir new file mode 100644 index 00000000000..b074985c591 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/tests/ops.mlir @@ -0,0 +1,381 @@ +// RUN: tfr-opt %s -verify-diagnostics -split-input-file | tfr-opt | FileCheck %s +// RUN: tfr-opt %s -canonicalize -verify-diagnostics -split-input-file | FileCheck %s -check-prefix=CANON + +// Tests for types, ops with custom constraints, verifiers, printer or parser +// methods. 
+ +// CHECK-LABEL: tensor_type_noconstraint +func @tensor_type_noconstraint() -> !tfr.tensor + +// ----- + +// CHECK-LABEL: tensor_type +func @tensor_type() -> !tfr.tensor + +// ----- + +// CHECK-LABEL: tensor_list_type_noconstraint +func @tensor_list_type_noconstraint() -> !tfr.tensor_list + +// ----- + +// CHECK-LABEL: tensor_list_type_array_like +func @tensor_list_type_array_like() -> !tfr.tensor_list<[N, T]> + +// ----- + +// CHECK-LABEL: tensor_list_type_tuple_like +func @tensor_list_type_tuple_like() -> !tfr.tensor_list + +// ----- + +// expected-error@+1 {{unbalanced '>' character in pretty dialect name}} +func @tensor_invalid_1() -> !tfr.tensor<[N, T> + +// ----- + +// expected-error@+1 {{unexpected nul or EOF in pretty dialect name}} +func @tensor_invalid_2() -> !tfr.tensor<[N, T] + +// ----- + +// CHECK-LABEL: call_op +func @call_op(%arg0: !tfr.tensor, %arg1: !tfr.tensor_list, %arg2: i32) -> !tfr.tensor { + %0 = tfr.call @Foo(%arg0, %arg1, %arg2) : (!tfr.tensor, !tfr.tensor_list, i32) -> !tfr.tensor + return %0 : !tfr.tensor +} + +// ----- + +// CHECK-LABEL: call_op_arg_attr(%arg0: i32) -> !tfr.tensor +func @call_op_arg_attr(%arg0: i32) -> !tfr.tensor { + %0 = tfr.call @Bar(%arg0) : (i32) -> !tfr.tensor + return %0 : !tfr.tensor +} + +// ----- + +func @call_op_invalid_1(%arg0: tensor) -> !tfr.tensor { + // expected-error@+1 {{got 'tensor'}} + %0 = tfr.call @Huu(%arg0) : (tensor) -> !tfr.tensor + return %0 : !tfr.tensor +} + +// ----- + +// CHECK-LABEL: get_shape +func @get_shape(%arg0: !tfr.tensor) -> (!shape.shape, !shape.shape) { + %0 = tfr.get_shape %arg0 -> !shape.shape + %1 = "tfr.get_shape"(%arg0) : (!tfr.tensor) -> !shape.shape + return %0, %1 : !shape.shape, !shape.shape +} + +// ----- + +// CHECK-LABEL: get_real_shape +// CANON-LABEL: get_real_shape +func @get_real_shape(%arg0: tensor<1x2xf32>) -> tensor<1xindex> { + %0 = "tfr.cast"(%arg0) : (tensor<1x2xf32>) -> !tfr.tensor + %1 = tfr.get_shape %0 -> !shape.shape + %2 = shape.to_extent_tensor %1 : !shape.shape -> tensor<1xindex> + return %2 : tensor<1xindex> + +// CANON-NEXT: %[[s:.*]] = shape.const_shape [1, 2] : tensor +// CANON-NEXT: %[[e:.*]] = shape.to_extent_tensor %[[s]] : tensor -> tensor<1xindex> +// CANON-NEXT: return %[[e]] : tensor<1xindex> +} + +// ----- + +func @get_element_type(%arg0: !tfr.tensor) -> (!tfr.attr, !tfr.attr) { + %0 = tfr.get_element_type %arg0 -> !tfr.attr + %1 = "tfr.get_element_type"(%arg0) : (!tfr.tensor) -> !tfr.attr + return %0, %1 : !tfr.attr, !tfr.attr +} + +// ----- + +// CHECK-LABEL: from_tf_tensor +func @from_tf_tensor(%arg0: tensor) -> !tfr.tensor { + %0 = "tfr.cast"(%arg0) : (tensor) -> !tfr.tensor + return %0 : !tfr.tensor +} + +// ----- + +// CHECK-LABEL: to_tf_tensor +func @to_tf_tensor(%arg0: !tfr.tensor) -> tensor { + %0 = "tfr.cast"(%arg0) : (!tfr.tensor) -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: constant +func @constant() -> (!tfr.attr, !tfr.attr, !tfr.attr, !tfr.attr) { + %0 = tfr.constant f32 -> !tfr.attr + %1 = tfr.constant [f32, i32] -> !tfr.attr + %2 = "tfr.constant"() {value = f32} : () -> !tfr.attr + %3 = "tfr.constant"() {value = [f32, i32]} : () -> !tfr.attr + return %0, %1, %2, %3 : !tfr.attr, !tfr.attr, !tfr.attr, !tfr.attr +} + +// ----- + +// CHECK-LABEL: equal +// CANON-LABEL: equal +func @equal() -> (i1, i1, i1, i1) { + %0 = tfr.constant f32 -> !tfr.attr + %1 = tfr.constant f32 -> !tfr.attr + %2 = tfr.constant i32 -> !tfr.attr + %same_type = tfr.equal %0,%1 -> i1 + %diff_type = tfr.equal %0,%2 -> i1 + + %3 = tfr.constant "hello" -> 
!tfr.attr + %4 = tfr.constant "hello" -> !tfr.attr + %5 = tfr.constant "how are you" -> !tfr.attr + %same_str = tfr.equal %3,%4 -> i1 + %diff_str = tfr.equal %3,%5 -> i1 + return %same_type, %diff_type, %same_str, %diff_str : i1, i1, i1, i1 + +// CANON-NEXT: %true = constant true +// CANON-NEXT: %false = constant false +// CANON-NEXT: return %true, %false, %true, %false : i1, i1, i1, i1 +} + +// ----- + +// CHECK-LABEL: constant_tensor_scalar +func @constant_tensor_scalar(%arg0: i32) -> tensor { + %0 = "tfr.constant_tensor"(%arg0) : (i32) -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: constant_tensor_vector +func @constant_tensor_vector(%arg0: vector<1x2xi32>) -> tensor<1x2xi32> { + %0 = "tfr.constant_tensor"(%arg0) : (vector<1x2xi32>) -> tensor<1x2xi32> + return %0 : tensor<1x2xi32> +} + +// ----- + +// CHECK-LABEL: constant_tensor_array +// CANON-LABEL: constant_tensor_array +func @constant_tensor_array() -> !tfr.tensor { + %0 = tfr.constant [1, -1, 3] -> !tfr.attr + %1 = "tfr.constant_tensor"(%0) : (!tfr.attr) -> !tfr.tensor + return %1 : !tfr.tensor + +// CANON-NEXT: "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi64>} : () -> tensor<3xi64> +// CANON-NEXT: "tfr.cast"(%0) : (tensor<3xi64>) -> !tfr.tensor +// CANON-NEXT: return +} + +// ----- + +// CHECK-LABEL: constant_tensor_scalar +// CANON-LABEL: constant_tensor_scalar +func @constant_tensor_scalar() -> !tfr.tensor { + %0 = "std.constant"() {value = 42 : i32} : () -> i32 + %1 = "tfr.constant_tensor"(%0) : (i32) -> !tfr.tensor + return %1 : !tfr.tensor + +// CANON-NEXT: "tf.Const"() {value = dense<42> : tensor} : () -> tensor +// CANON-NEXT: "tfr.cast"(%0) : (tensor) -> !tfr.tensor +// CANON-NEXT: return +} + +// ----- + +func @constant_tensor_invalid_0(%arg0: i32) -> tensor { + // expected-error@+1 {{input and output should have the same scalar types.}} + %0 = "tfr.constant_tensor"(%arg0) : (i32) -> tensor + return %0 : tensor +} + +// ----- + +func @constant_tensor_invalid_1(%arg0: vector<1xi32>) -> tensor { + // expected-error@+1 {{output type should be static and ranked}} + %0 = "tfr.constant_tensor"(%arg0) : (vector<1xi32>) -> tensor + return %0 : tensor +} + +// ----- + +func @constant_tensor_invalid_2(%arg0: vector<1xi32>) -> tensor<1xf32> { + // expected-error@+1 {{input and output should have same shape and element type}} + %0 = "tfr.constant_tensor"(%arg0) : (vector<1xi32>) -> tensor<1xf32> + return %0 : tensor<1xf32> +} + +// ----- + +func @constant_tensor_invalid_3(%arg0: vector<1xi32>) -> tensor<1x1xi32> { + // expected-error@+1 {{input and output should have same shape and element type}} + %0 = "tfr.constant_tensor"(%arg0) : (vector<1xi32>) -> tensor<1x1xi32> + return %0 : tensor<1x1xi32> +} + +// ----- + +func @constant_tensor_invalid_4(%arg0: i32) -> tensor<1x1xi32> { + // expected-error@+1 {{input can not be converted to an output tensor}} + %0 = "tfr.constant_tensor"(%arg0) : (i32) -> tensor<1x1xi32> + return %0 : tensor<1x1xi32> +} + +// ----- + +// CHECK-LABEL: get_element +func @get_element(%arg0: !tfr.tensor_list) -> !tfr.tensor { + %cst = "std.constant"() {value = 1 : index} : () -> index + %0 = tfr.get_element %arg0[%cst] : (!tfr.tensor_list, index) -> !tfr.tensor + return %0 : !tfr.tensor +} + +// ----- + +// CHECK-LABEL: build_list +func @build_list(%arg0: !tfr.tensor, %arg1: !tfr.tensor) -> !tfr.tensor_list { + %0 = "tfr.build_list"(%arg0, %arg1) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list + return %0 : !tfr.tensor_list +} + +// ----- + +// CHECK-LABEL: build_const_list +// 
CANON-LABEL: build_const_list +func @build_const_list() -> !tfr.attr { + %0 = "std.constant"() {value = 42 : i32} : () -> i32 + %1 = "std.constant"() {value = 41 : i32} : () -> i32 + %2 = "tfr.build_list"(%0, %1) : (i32, i32) -> !tfr.attr + return %2 : !tfr.attr + +// CANON-NEXT: %[[c:.*]] = tfr.constant [42 : i32, 41 : i32] -> !tfr.attr +// CANON-NEXT: return %[[c]] : !tfr.attr +} + +// ----- + +// CHECK-LABEL: tfr.func +tfr.func @External(%arg0: !tfr.tensor, + %arg1: !tfr.tensor_list, + %arg2: i32 {tfr.name = "A"}, + %arg3: !tfr.attr {tfr.name = "T"}) + -> (!tfr.tensor, !tfr.tensor_list) + attributes {A, C} + +// ----- + +// CHECK-LABEL: tfr.func +tfr.func @Foo(%arg0: !tfr.tensor, + %arg1: !tfr.tensor_list, + %arg2: i32 {tfr.name = "A"}, + %arg3: vector<1xi32> {tfr.name = "C"}) + -> (!tfr.tensor, !tfr.tensor_list) + attributes {A, C} { + tfr.return %arg0, %arg1 : !tfr.tensor, !tfr.tensor_list +} + +// ----- + +// CHECK-LABEL: tfr.func +tfr.func @Bar(%arg0: !tfr.tensor, + %arg2: i32 {tfr.name = "B"}, + %arg3: vector<1xi32> {tfr.name = "C"}) + -> (!tfr.tensor, !tfr.tensor) + attributes {A} { + tfr.return %arg0, %arg0 : !tfr.tensor, !tfr.tensor +} + +// ----- + +// expected-error@+1 {{Undefined attributes are used: A}} +tfr.func @Foo_undefined_attr(%arg0: !tfr.tensor, + %arg1: !tfr.tensor_list, + %arg2: i32 {tfr.name = "A"}, + %arg3: vector<1xi32> {tfr.name = "C"}) -> + (!tfr.tensor, !tfr.tensor_list) { + tfr.return %arg0, %arg1 : !tfr.tensor, !tfr.tensor_list +} + +// ----- + +// expected-error@+1 {{3 attribute argument doesn't have a tfr.name attribute}} +tfr.func @Foo_unnamed_attr(%arg0: !tfr.tensor, + %arg1: !tfr.tensor_list, + %arg2: i32 {tfr.name = "A"}, + %arg3: vector<1xi32>) -> + (!tfr.tensor, !tfr.tensor_list) { + tfr.return %arg0, %arg1 : !tfr.tensor, !tfr.tensor_list +} + +// ----- + +// expected-error@+1 {{tfr.tensor_list argument should be before non tensor arguments}} +tfr.func @Foo_invalid_arg_order(%arg0: !tfr.tensor, + %arg2: i32 {tfr.name = "A"}, + %arg1: !tfr.tensor_list, + %arg3: vector<1xi32> {tfr.name = "C"}) -> + (!tfr.tensor, !tfr.tensor_list) { + tfr.return %arg0, %arg1 : !tfr.tensor, !tfr.tensor_list +} + +// ----- + +// expected-error@+1 {{tfr.tensor argument should be before tfr.tensor_list argument.}} +tfr.func @Foo_invalid_arg_order0( + %arg1: !tfr.tensor_list, + %arg0: !tfr.tensor, + %arg2: i32 {tfr.name = "A"}, + %arg3: vector<1xi32> {tfr.name = "C"}) -> + (!tfr.tensor, !tfr.tensor_list) { + tfr.return %arg0, %arg1 : !tfr.tensor, !tfr.tensor_list +} + +// ----- + +// expected-error@+1 {{tfr.tensor result should be before tfr.tensor_list result}} +tfr.func @Foo_invalid_result_order(%arg0: !tfr.tensor, + %arg1: !tfr.tensor_list, + %arg2: i32 {tfr.name = "A"}, + %arg3: vector<1xi32> {tfr.name = "C"}) -> + (!tfr.tensor_list, !tfr.tensor) { + tfr.return %arg1, %arg0 : !tfr.tensor_list, !tfr.tensor +} + +// ----- + +// expected-error@+1 {{More than one tfr.tensor_list argument isn't allowed}} +tfr.func @Foo_multiple_tensor_list_args(%arg0: !tfr.tensor, + %arg1: !tfr.tensor_list, + %arg2: !tfr.tensor_list, + %arg3: i32 {tfr.name = "A"}, + %arg4: vector<1xi32> {tfr.name = "C"}) -> + (!tfr.tensor, !tfr.tensor_list) { + tfr.return %arg0, %arg1 : !tfr.tensor, !tfr.tensor_list +} + +// ----- + +// expected-error@+1 {{More than one tfr.tensor_list result isn't allowed}} +tfr.func @Foo_multiple_tensor_list_results(%arg0: !tfr.tensor, + %arg1: !tfr.tensor_list, + %arg2: i32 {tfr.name = "A"}, + %arg3: vector<1xi32> {tfr.name = "C"}) -> + (!tfr.tensor_list, !tfr.tensor_list) 
{ + tfr.return %arg1, %arg1 : !tfr.tensor_list, !tfr.tensor_list +} + +// ----- + +// expected-error@+1 {{None tfr.tensor/tfr.tensor_list results aren't allowed as a result}} +tfr.func @Foo_return_attr(%arg0: !tfr.tensor, + %arg1: !tfr.tensor_list, + %arg2: i32 {tfr.name = "A"}, + %arg3: vector<1xi32> {tfr.name = "C"}) -> i32 { + tfr.return %arg2 : i32 +} diff --git a/tensorflow/compiler/mlir/tfr/tests/raise_to_tf.mlir b/tensorflow/compiler/mlir/tfr/tests/raise_to_tf.mlir new file mode 100644 index 00000000000..41d0ee6271d --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/tests/raise_to_tf.mlir @@ -0,0 +1,76 @@ +// RUN: tfr-opt %s -tfr-raise-to-tf -verify-diagnostics -split-input-file | FileCheck %s + +tfr.func @tf__risc_same_(!tfr.tensor) -> !tfr.tensor attributes {T} +tfr.func @tf__risc_concat_(!tfr.tensor_list) -> !tfr.tensor attributes {T, N} +tfr.func @tf__risc_split_(!tfr.tensor, i32 {tfr.name="N"}) -> !tfr.tensor_list attributes {T, N} +tfr.func @tf__risc_cast_(!tfr.tensor, !tfr.attr {tfr.name="K"}) -> !tfr.tensor attributes {T, K} + +// CHECK-LABEL: decompose_tf_same +func @decompose_tf_same(%arg0: tensor<1x2x3x4x!tf.string>) -> tensor<1x2x3x4x!tf.string> { + %0 = "tfr.cast"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> !tfr.tensor + %1 = tfr.call @tf__risc_same(%0) : (!tfr.tensor) -> !tfr.tensor + %2 = "tfr.cast"(%1) : (!tfr.tensor) -> tensor<1x2x3x4x!tf.string> + return %2 : tensor<1x2x3x4x!tf.string> + +// CHECK: %[[id:.*]] = "tf.RiscSame"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> tensor<*x!tf.string> +// CHECK: %[[es:.*]] = "tf.EnsureShape"(%[[id]]) {shape = #tf.shape<1x2x3x4>} : (tensor<*x!tf.string>) -> tensor<1x2x3x4x!tf.string> +// CHECK: return %[[es]] : tensor<1x2x3x4x!tf.string> +} + +// CHECK-LABEL: decompose_tf_consecutive +func @decompose_tf_consecutive(%arg0: tensor<1x2x3x4x!tf.string>, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = "tfr.cast"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> !tfr.tensor + %1 = "tfr.cast"(%arg2) : (tensor) -> !tfr.tensor + %2 = tfr.call @tf__risc_same(%0) : (!tfr.tensor) -> !tfr.tensor + %3 = tfr.call @tf__risc_same(%1) : (!tfr.tensor) -> !tfr.tensor + %4 = "tfr.cast"(%3) : (!tfr.tensor) -> tensor + return %4 : tensor + +// CHECK: %[[id0:.*]] = "tf.RiscSame"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> tensor<*x!tf.string> +// CHECK: %[[id2:.*]] = "tf.RiscSame"(%arg2) : (tensor) -> tensor<*xf32> +// CHECK: %[[es:.*]] = "tf.EnsureShape"(%[[id2]]) {shape = #tf.shape<>} : (tensor<*xf32>) -> tensor +// CHECK: return %[[es]] : tensor +} + +// CHECK-LABEL: decompose_tf_concat_n +func @decompose_tf_concat_n(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<3xf32> { + %0 = "tfr.cast"(%arg0) : (tensor) -> !tfr.tensor + %1 = "tfr.cast"(%arg1) : (tensor) -> !tfr.tensor + %2 = "tfr.cast"(%arg2) : (tensor) -> !tfr.tensor + %3 = "tfr.build_list"(%0, %1, %2) : (!tfr.tensor, !tfr.tensor, !tfr.tensor) -> !tfr.tensor_list + %concat = tfr.call @tf__risc_concat(%3) : (!tfr.tensor_list) -> !tfr.tensor + %4 = "tfr.cast"(%concat) : (!tfr.tensor) -> tensor<3xf32> + return %4 : tensor<3xf32> + +// CHECK: %[[concat:.*]] = "tf.RiscConcat"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor<*xf32> +// CHECK: %[[es:.*]] = "tf.EnsureShape"(%[[concat]]) {shape = #tf.shape<3>} : (tensor<*xf32>) -> tensor<3xf32> +// CHECK: return %[[es]] : tensor<3xf32> +} + +// CHECK-LABEL: decompose_tf_split +func @decompose_tf_split(%arg0: tensor<3xf32>) -> (tensor) { + %0 = "tfr.cast"(%arg0) : (tensor<3xf32>) -> !tfr.tensor + %n = std.constant 3: i32 + %split = tfr.call 
@tf__risc_split(%0, %n) : (!tfr.tensor, i32) -> !tfr.tensor_list + %i0 = std.constant 0: index + %s0 = tfr.get_element %split[%i0] : (!tfr.tensor_list, index) -> !tfr.tensor + %4 = "tfr.cast"(%s0) : (!tfr.tensor) -> tensor + return %4 : tensor + +// CHECK: %[[split:.*]]:3 = "tf.RiscSplit"(%arg0) {N = 3 : i32} : (tensor<3xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) +// CHECK: %[[es:.*]] = "tf.EnsureShape"(%[[split]]#0) {shape = #tf.shape<>} : (tensor<*xf32>) -> tensor +// CHECK: return %[[es]] : tensor +} + +// CHECK-LABEL: decompose_tf_cast +func @decompose_tf_cast(%arg0: tensor) -> tensor { + %0 = "tfr.cast"(%arg0) : (tensor) -> !tfr.tensor + %t = tfr.constant i32 -> !tfr.attr + %concat = tfr.call @tf__risc_cast(%0, %t) : (!tfr.tensor, !tfr.attr) -> !tfr.tensor + %4 = "tfr.cast"(%concat) : (!tfr.tensor) -> tensor + return %4 : tensor + +// CHECK: %[[tfcast:.*]] = "tf.RiscCast"(%arg0) {K = i32} : (tensor) -> tensor<*xi32> +// CHECK: %[[es:.*]] = "tf.EnsureShape"(%[[tfcast]]) {shape = #tf.shape<>} : (tensor<*xi32>) -> tensor +// CHECK: return %[[es]] : tensor +} diff --git a/tensorflow/compiler/mlir/tfr/utils/utils.cc b/tensorflow/compiler/mlir/tfr/utils/utils.cc new file mode 100644 index 00000000000..6c08b682cb0 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/utils/utils.cc @@ -0,0 +1,78 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tfr/utils/utils.h" + +#include "llvm/ADT/StringRef.h" +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir { +namespace TFR { + +std::string GetComposeFuncName(StringRef tf_op_name) { + std::string compose_func_name; + for (int i = 0; i < tf_op_name.size(); ++i) { + if (tf_op_name[i] == '_') { + // The field name must not contain "_"s. "_Arg" and "_RetVal" are special + // op names and we can return empty string to skip the decomposition. + return {}; + } + if (tf_op_name[i] == '.') { + compose_func_name.push_back('_'); + } else if (tf_op_name[i] >= 'A' && tf_op_name[i] <= 'Z') { + compose_func_name.push_back('_'); + compose_func_name.push_back(tf_op_name[i] + 'a' - 'A'); + } else { + compose_func_name.push_back(tf_op_name[i]); + } + } + return compose_func_name; +} + +std::string GetTFOpName(StringRef compose_func_name) { + std::string tf_op_name; + bool after_underscore = false; + for (int i = 0; i < compose_func_name.size(); ++i) { + if (compose_func_name[i] >= 'A' && compose_func_name[i] <= 'Z') { + // The field name must not contain uppercase letters. + return {}; + } + if (after_underscore) { + if (compose_func_name[i] >= 'a' && compose_func_name[i] <= 'z') { + tf_op_name.push_back(compose_func_name[i] + 'A' - 'a'); + after_underscore = false; + } else { + // The character after a "_" must be a lowercase letter. 
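+        // e.g. "tf__foo_1" cannot be mapped back because '1' follows '_'.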
+ return {}; + } + } else if (compose_func_name[i] == '_') { // first time visit '_' + if (i + 1 < compose_func_name.size() && compose_func_name[i + 1] == '_') { + tf_op_name.push_back('.'); + i++; + } + after_underscore = true; + } else { + tf_op_name.push_back(compose_func_name[i]); + } + } + if (after_underscore) { + // Trailing "_". + return {}; + } + return tf_op_name; +} + +} // namespace TFR +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tfr/utils/utils.h b/tensorflow/compiler/mlir/tfr/utils/utils.h new file mode 100644 index 00000000000..26c7250d95a --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/utils/utils.h @@ -0,0 +1,42 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_UTILS_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_UTILS_UTILS_H_ + +#include + +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir { +namespace TFR { + +// This is a hardcoded rule for mapping a TF op name to the corresponding +// TFR function name. Examples: +// tf.Pack => tf__pack +// tf.ConcatV2 => tf__concat_v2 +// TODO(fengliuai): move to an util file. +std::string GetComposeFuncName(StringRef tf_op_name); + +// This is a hardcoded rule for mapping a TFR function op name to the +// corresponding TF opname. 
Examples: +// tf__pack -> tf.Pack +// tf__concat_v2 => tf.ConcatV2 +std::string GetTFOpName(StringRef compose_func_name); + +} // namespace TFR +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_UTILS_UTILS_H_ diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD index 3c88318064b..b4cbf765c79 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD @@ -1,5 +1,17 @@ -load("//tensorflow:tensorflow.bzl", "tf_cc_binary") -load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") +load( + "//tensorflow:tensorflow.bzl", + "get_compatible_with_cloud", + "tf_cc_binary", +) +load( + "//tensorflow/core/platform/default:cuda_build_defs.bzl", + "if_cuda_is_configured", +) +load( + "@local_config_rocm//rocm:build_defs.bzl", + "if_rocm_is_configured", +) package( default_visibility = [":friends"], @@ -9,61 +21,118 @@ package( package_group( name = "friends", includes = ["//third_party/mlir:subpackages"], - packages = ["//tensorflow/compiler/mlir/..."], + packages = [ + "//tensorflow/compiler/mlir/...", + "//tensorflow/core/kernels/mlir_generated/...", + ], ) cc_library( - name = "cubin_creator", - srcs = ["cubin_creator.cc"], - hdrs = ["cubin_creator.h"], - copts = if_cuda(["-DGOOGLE_CUDA=1"]), + name = "kernel_creator", + srcs = ["kernel_creator.cc"], + hdrs = ["kernel_creator.h"], + copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]) + if_rocm_is_configured(["-DTENSORFLOW_USE_ROCM=1"]), deps = [ - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", - "@llvm-project//mlir:GPUDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LLVMDialect", - "@llvm-project//mlir:Parser", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:StandardOps", - "@llvm-project//mlir:TargetNVVMIR", - "@llvm-project//mlir:Transforms", - "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/hlo", + "//tensorflow/compiler/mlir/hlo:all_passes", + "//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo", + "//tensorflow/compiler/mlir/hlo:legalize_to_linalg", + "//tensorflow/compiler/mlir/hlo:legalize_trigonometric_to_approximation", "//tensorflow/compiler/mlir/hlo:lhlo", + "//tensorflow/compiler/mlir/hlo:lhlo_fuse_linalg", + "//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_affine", + "//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_gpu", + "//tensorflow/compiler/mlir/hlo:transform_unranked_hlo", # buildcleaner: keep + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes", + "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", + "//tensorflow/compiler/mlir/tools/kernel_gen/transforms:passes", "//tensorflow/compiler/mlir/xla:xla_legalize_tf", - "//tensorflow/compiler/mlir/hlo:materialize_broadcasts", # buildcleaner: keep - "//tensorflow/compiler/mlir/hlo:unfuse_batch_norm", # buildcleaner: keep "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service/gpu:stream_executor_util", "//tensorflow/compiler/xla/service/gpu:target_constants", "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend", "//tensorflow/compiler/xla/service/mlir_gpu:kernel_lowering", - "//tensorflow/core:cuda_libdevice_path", + 
"//tensorflow/compiler/xla/service/mlir_gpu:passes", "//tensorflow/core:lib", - ] + if_cuda(["//tensorflow/stream_executor/gpu:asm_compiler"]), + "//tensorflow/core/platform:cuda_libdevice_path", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AffineToStandard", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUToGPURuntimeTransforms", + "@llvm-project//mlir:GPUToNVVMTransforms", + "@llvm-project//mlir:GPUTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LinalgOps", + "@llvm-project//mlir:LinalgTransforms", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFToGPUPass", + "@llvm-project//mlir:SCFToStandard", + "@llvm-project//mlir:SCFTransforms", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], ) tf_cc_binary( - name = "tf_to_cubin", - srcs = ["tf_to_cubin.cc"], - visibility = ["//tensorflow/core/kernels/mlir_generated:__pkg__"], + name = "tf_to_gpu_binary", + srcs = ["tf_to_gpu_binary.cc"], + visibility = [ + "//tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary:__pkg__", + "//tensorflow/core/kernels/mlir_generated:__pkg__", + ], deps = [ - ":cubin_creator", + ":kernel_creator", "//tensorflow/compiler/mlir:init_mlir", + "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/core:lib", + "//tensorflow/stream_executor/lib", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", + "@llvm-project//mlir:Pass", + ], +) + +tf_cc_binary( + name = "tf_to_kernel", + srcs = ["tf_to_kernel.cc"], + visibility = [ + "//tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel:__pkg__", + "//tensorflow/core/kernels/mlir_generated:__pkg__", + ], + deps = [ + ":kernel_creator", + "//tensorflow/compiler/mlir:init_mlir", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Analysis", + "@llvm-project//llvm:CodeGen", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//llvm:X86CodeGen", # fixdeps: keep + "@llvm-project//llvm:X86Disassembler", # fixdeps: keep + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:TargetLLVMIR", ], ) tf_cc_binary( name = "kernel-gen-opt", srcs = ["tools/kernel-gen-opt/kernel-gen-opt.cc"], - visibility = ["//tensorflow/compiler/mlir/tools/kernel_gen/tests:__pkg__"], + visibility = ["//tensorflow/compiler/mlir/tools/kernel_gen/tests:__subpackages__"], deps = [ "//tensorflow/compiler/mlir/hlo:all_passes", "//tensorflow/compiler/mlir/hlo:hlo_dialect_registration", @@ -90,3 +159,16 @@ cc_library( "@llvm-project//mlir:mlir_runner_utils", ], ) + +cc_library( + name = "tf_cuda_runtime_wrappers", + srcs = ["tf_cuda_runtime_wrappers.cc"], + compatible_with = get_compatible_with_cloud(), + copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]), + deps = [ + "//tensorflow/core/platform/default/build_config:stream_executor_cuda", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:mlir_c_runner_utils", + "@local_config_cuda//cuda:cuda_headers", + ], +) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc deleted file mode 100644 index 3b6af7f699c..00000000000 --- 
a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc +++ /dev/null @@ -1,306 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -//===- cubin_creator.cc -----------------------------------------*- C++ -*-===// -// -// This file implements the function to compile a TF kernel function to a cubin. -// -//===----------------------------------------------------------------------===// -#include "tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h" - -#include -#include -#include - -#include "absl/memory/memory.h" -#include "absl/strings/escaping.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/None.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/Debug.h" -#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/Dialect.h" // from @llvm-project -#include "mlir/IR/Function.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/Parser.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Target/NVVMIR.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" -#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/xla/transforms/passes.h" -#include "tensorflow/compiler/xla/debug_options_flags.h" -#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" -#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" -#include "tensorflow/compiler/xla/service/gpu/target_constants.h" -#include "tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h" -#include "tensorflow/core/platform/cuda_libdevice_path.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/path.h" -#if GOOGLE_CUDA -#include "tensorflow/stream_executor/gpu/asm_compiler.h" -#endif - -namespace { -using tensorflow::Status; -using xla::InternalError; -using xla::StatusOr; - -StatusOr GetLibdeviceDir( - const xla::HloModuleConfig& hlo_module_config) { - for (const std::string& cuda_root : tensorflow::CandidateCudaRoots( - hlo_module_config.debug_options().xla_gpu_cuda_data_dir())) { - std::string libdevice_dir = - tensorflow::io::JoinPath(cuda_root, "nvvm", "libdevice"); - VLOG(2) << "Looking for libdevice at 
" << libdevice_dir; - if (tensorflow::Env::Default()->IsDirectory(libdevice_dir).ok()) { - VLOG(2) << "Found libdevice dir " << libdevice_dir; - return libdevice_dir; - } - } - return InternalError( - "Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice"); -} - -struct MaterializeBroadcastsPass - : public mlir::PassWrapper { - void runOnFunction() override { - mlir::ConversionTarget conversionTarget(getContext()); - mlir::OwningRewritePatternList conversionPatterns; - - // Consider the mhlo dialect legal for tests. - conversionTarget.addLegalDialect(); - // The conversion uses helpers from the Standard dialect. - conversionTarget.addLegalDialect(); - - mlir::mhlo::SetupMaterializeBroadcastsLegality(&getContext(), - &conversionTarget); - mlir::mhlo::PopulateMaterializeBroadcastsPatterns(&getContext(), - &conversionPatterns); - - if (failed(applyPartialConversion(getFunction(), conversionTarget, - conversionPatterns))) { - return signalPassFailure(); - } - } -}; - -struct UnfuseBatchNormPass - : public mlir::PassWrapper { - void runOnFunction() override { - mlir::OwningRewritePatternList patterns; - mlir::mhlo::PopulateUnfuseBatchNormPatterns(&getContext(), &patterns); - mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); - } -}; - -Status LowerTfOpToLhloWithDynamicShapes(mlir::ModuleOp module) { - mlir::PassManager pm(module.getContext()); - auto enable_if_vlog_is_on = [](mlir::Pass* pass, mlir::Operation* op) { - return VLOG_IS_ON(1); - }; - pm.enableIRPrinting(/*shouldPrintBeforePass=*/{}, - /*shouldPrintAfterPass=*/enable_if_vlog_is_on, - /*printModuleScope=*/false, - /*printAfterOnlyOnChange=*/false, llvm::dbgs()); - pm.addNestedPass(mlir::mhlo::createLegalizeTFPass(false)); - pm.addNestedPass( - absl::make_unique()); - pm.addNestedPass(absl::make_unique()); - pm.addPass(mlir::mhlo::createLegalizeToLhloPass( - /*results_escape_functions=*/true)); - pm.addNestedPass(mlir::lmhlo::createLhloCopyRemovalPass()); - - if (failed(pm.run(module))) { - return InternalError("Lowering TF to LHLO failed."); - } - return Status::OK(); -} - -struct PropagateTensorFlowABIKnowledge - : public mlir::PassWrapper> { - explicit PropagateTensorFlowABIKnowledge(mlir::FunctionType type, - llvm::ArrayRef same_shape_) - : func_type(type), same_shape(same_shape_) {} - - void runOnOperation() override { - // We know due to tensorflow ABI that the offset is always 0 and that the - // innermost stride is always 1. To make this visible to the compiler, - // we insert constants into the code and replace usages accordingly. - // We do not change the signature so that we keep a somewhat stable ABI - // that is easy to undertand by tools. - // We also know that tensorflow aligns all allocated pointers by 16, so - // we pass this on. Furthermore, we know that arguments never alias. More - // precicely, they may only alias (due to reuse) if the kernel does not - // read from a position it previously has written to. We express this with - // the noalias attribute. - mlir::LLVM::LLVMFuncOp func = getOperation(); - - // This only works if the function is local and we can rewrite it. - if (func.isExternal()) return; - - mlir::OpBuilder b(func.getBody()); - // Steal the LLVM representation of the index type from the third argument. 
- auto index_type = func.getArgument(3).getType(); - mlir::Value one = b.create( - func.getLoc(), index_type, b.getIntegerAttr(b.getIndexType(), 1)); - mlir::Value zero = b.create( - func.getLoc(), index_type, b.getIntegerAttr(b.getIndexType(), 0)); - uint32_t arg_pos = 0; - std::vector positions; - // Collect the agument and return types of the surrounding function. - auto arg_types = llvm::to_vector<4>(llvm::concat( - func_type.getInputs(), func_type.getResults())); - for (mlir::Type arg_type : arg_types) { - if (!arg_type.isa()) { - func.emitError() << "argument of surrounding func is not ranked memref"; - signalPassFailure(); - return; - } - positions.push_back(arg_pos); - // Set alignment and aliasing on the pointers. - func.setArgAttr(arg_pos + 1, "llvm.noalias", b.getBoolAttr(true)); - func.setArgAttr(arg_pos + 1, "llvm.align", b.getIndexAttr(16)); - // Replace the offset with zero. Offset is argument number 3. - func.getArgument(arg_pos + 2).replaceAllUsesWith(zero); - // Forward over base_ptr, aligned_ptr, offset, size and stride arguments. - arg_pos += 3 + arg_type.cast().getRank() * 2; - // Replace the last stride with constant 1. - func.getArgument(arg_pos - 1).replaceAllUsesWith(one); - } - - // If we have knowledge that some arguments have the same shape, we - // can use that here. Simply replace usages of the shape parameters within - // the function body to a single shape parameter. - if (!same_shape.empty()) { - auto first = same_shape.front(); - auto first_offset = positions.at(first); - auto first_type = arg_types[first].cast(); - uint32_t rank = first_type.getRank(); - for (auto same : same_shape.drop_front(1)) { - uint32_t same_offset = positions.at(same); - auto same_type = arg_types[same].cast(); - if (same_type.getRank() != rank) { - func.emitOpError() << "same shape constraints on arguments with " - "non-matching shapes: #" - << first << " and #" << same; - signalPassFailure(); - continue; - } - - for (uint32_t i = 0; i < 2 * rank; ++i) { - // Replace uses for second arg data with first arg. - auto same_arg = func.getArgument(same_offset + 3 + i); - auto first_arg = func.getArgument(first_offset + 3 + i); - same_arg.replaceAllUsesWith(first_arg); - } - } - } - } - - mlir::FunctionType func_type; - llvm::ArrayRef same_shape; -}; - -Status PropagateTensorFlowABIKnowledgeToKernel( - mlir::ModuleOp module, llvm::ArrayRef same_shape) { - // Grab the original signature from the single function. 
- auto func = *module.getBody()->op_begin(); - - mlir::PassManager pm(module.getContext()); - auto enable_if_vlog_is_on = [](mlir::Pass*, mlir::Operation*) { - return VLOG_IS_ON(1); - }; - pm.enableIRPrinting(/*shouldPrintBeforePass=*/{}, - /*shouldPrintAfterPass=*/enable_if_vlog_is_on, - /*printModuleScope=*/false, - /*printAfterOnlyOnChange=*/false, llvm::dbgs()); - auto& kernel_pm = pm.nest<::mlir::gpu::GPUModuleOp>(); - kernel_pm.addNestedPass( - absl::make_unique(func.getType(), - same_shape)); - - if (failed(pm.run(module))) { - return InternalError("Static knowledge propagation failed."); - } - return Status::OK(); -} - -} // namespace - -StatusOr> tensorflow::kernel_gen::GenerateCubinForTfCode( - llvm::StringRef tf_code, std::pair compute_capability, - llvm::ArrayRef tile_sizes, llvm::ArrayRef same_shape, - llvm::ArrayRef unroll_factors) { - mlir::MLIRContext context; - mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry()); - mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context); - - TF_RETURN_IF_ERROR(LowerTfOpToLhloWithDynamicShapes(module.get())); - { - xla::mlir_gpu::LowerLHLOToGPUOptions options; - options.tile_sizes = tile_sizes; - options.unroll_factors = unroll_factors; - options.collapse_parallel_loops = false; - options.use_approximations = true; - TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerLHLOToGPU(module.get(), options)); - } - TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get())); - TF_RETURN_IF_ERROR( - PropagateTensorFlowABIKnowledgeToKernel(module.get(), same_shape)); - - mlir::OwningModuleRef kernel_module = - xla::mlir_gpu::ExtractKernelModule(*module).ValueOrDie(); - llvm::LLVMContext llvmContext; - auto llvmModule = mlir::translateModuleToNVVMIR(*kernel_module, llvmContext); - if (!llvmModule) { - return InternalError("Could not translate MLIR module to NVVM"); - } - - llvmModule->setModuleIdentifier("acme"); - llvmModule->setDataLayout(xla::gpu::nvptx::kDataLayout); - - xla::HloModuleConfig config; - config.set_debug_options(xla::GetDebugOptionsFromFlags()); - - auto enable_fusion = [](llvm::TargetMachine* target) { - target->Options.AllowFPOpFusion = llvm::FPOpFusion::FPOpFusionMode::Fast; - }; - - TF_ASSIGN_OR_RETURN(std::string libdevice_dir, GetLibdeviceDir(config)); - TF_ASSIGN_OR_RETURN( - std::string ptx, - xla::gpu::nvptx::CompileToPtx(llvmModule.get(), compute_capability, - config, libdevice_dir, enable_fusion)); - VLOG(1) << ptx; - -#if GOOGLE_CUDA - return tensorflow::se::CompileGpuAsm( - std::get<0>(compute_capability), std::get<1>(compute_capability), - ptx.c_str(), xla::gpu::PtxOptsFromConfig(config)); -#else - return InternalError( - "GOOGLE_CUDA not defined. 
Did you specify --config=cuda ?"); -#endif -} diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD index 29939f227db..2630f97f825 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD @@ -1,4 +1,6 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//third_party/mlir:tblgen.bzl", "gentbl") +load("//tensorflow:tensorflow.bzl", "get_compatible_with_cloud") package( default_visibility = ["//tensorflow/compiler/mlir/tools/kernel_gen:friends"], @@ -7,6 +9,7 @@ package( gentbl( name = "tf_framework_ops_inc_gen", + compatible_with = get_compatible_with_cloud(), tbl_outs = [ ("-gen-op-decls", "tf_framework_ops.h.inc"), ("-gen-op-defs", "tf_framework_ops.cc.inc"), diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc index 8c02a734f1d..b3d92773be4 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc @@ -77,9 +77,9 @@ LogicalResult Verify(AllocRawOp op) { return success(); } -#define GET_OP_CLASSES -#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc.inc" - } // namespace tf_framework } // namespace kernel_gen } // namespace mlir + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc.inc" diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h index d2612a38799..aab090cc5e0 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h @@ -38,12 +38,12 @@ class OpKernelContextType using Base::Base; }; -#define GET_OP_CLASSES -#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_dialect.h.inc" -#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h.inc" - } // namespace tf_framework } // namespace kernel_gen } // namespace mlir +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_dialect.h.inc" +#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h.inc" + #endif // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_IR_TF_FRAMEWORK_OPS_H_ diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td index bc390a5aaa5..e6e29bcbdc2 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td @@ -29,7 +29,7 @@ def TFFramework_Dialect : Dialect { This dialect contains operations and types for that correspond to TensorFlow C++ Framework. 
}]; - let cppNamespace = "kernel_gen::tf_framework"; + let cppNamespace = "::mlir::kernel_gen::tf_framework"; } def TFFramework_OpKernelContextType : DialectType tile_sizes, + llvm::ArrayRef unroll_factors) { + mlir::PassManager pm(module.getContext()); + applyTensorflowAndCLOptions(pm); + + if (gpu_binary_only) { + pm.addPass(mlir::mhlo::createLegalizeTFPass( + /*allow_partial_conversion=*/false, /*legalize_chlo=*/true)); + pm.addNestedPass( + mlir::kernel_gen::transforms::CreateMaterializeBroadcastsPass()); + pm.addNestedPass( + mlir::kernel_gen::transforms::CreateUnfuseBatchNormPass()); + pm.addPass(mlir::mhlo::createLegalizeToLhloPass( + /*results_escape_functions=*/true)); + // Moving `AllocOp`s and inserting missing `DeallocOp`s + pm.addPass(::mlir::createBufferPlacementPass()); + pm.addNestedPass(mlir::createCopyRemovalPass()); + pm.addPass(mlir::kernel_gen::transforms::CreateShapeToDescriptorsPass()); + } else { + pm.addPass(mlir::mhlo::createLegalizeTFPass( + /*allow_partial_conversion=*/false, /*legalize_chlo=*/false)); + pm.addPass(mlir::mhlo::createChloLegalizeToHloPass()); + pm.addPass(mlir::createTransformUnrankedHloPass()); + pm.addPass(mlir::kernel_gen::transforms::CreateShapeToDescriptorsPass()); + // Clean up the IR created above. In particular, operations on descriptors + // are simplified here. + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::kernel_gen::transforms::CreateBufferizePass()); + pm.addPass(mlir::kernel_gen::transforms::CreateParallelLoopsToSequential()); + } + + // Clean up the IR for further processing. + pm.addPass(mlir::createCanonicalizerPass()); + // We have to anticipate later unrolling in tiling to make sure that we get + // the requested tiling after unrolling. Compute the new tiling here if + // needed. + llvm::SmallVector tiling_for_unrolling; + llvm::SmallVector as_int64; + if (!unroll_factors.empty()) { + tiling_for_unrolling.reserve(tile_sizes.size()); + for (auto pair : llvm::zip(tile_sizes, unroll_factors)) { + tiling_for_unrolling.push_back(std::get<0>(pair) * std::get<1>(pair)); + as_int64.push_back(std::get<1>(pair)); + } + } else { + tiling_for_unrolling.append(tile_sizes.begin(), tile_sizes.end()); + } + // Transform LHLO operations to LinAlg. + pm.addPass(::mlir::lmhlo::createLegalizeLhloToLinalgPass()); + // Fuse linalg operations. + pm.addPass(::mlir::lmhlo::createLhloFuseLinalgPass( + /*use_parallel_loops=*/true, tiling_for_unrolling)); + // Transform the Linalg operations inside of the loop nest into parallel + // loops. + pm.addPass(::mlir::createConvertLinalgToParallelLoopsPass()); + // Canonicalize the code to simplify index computations. This is needed so + // that loop bounds have the same value. + pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass()); + pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass()); + // Fuse the inner-most loops. + pm.addPass(xla::mlir_gpu::createFuseInnerParallelLoopsPass()); + // Run CSE to ensure that loads and stores to the same subview get + // recognized as such. + pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass()); + // Forward stores to buffers to loads. + pm.addPass(xla::mlir_gpu::createStoreForwardingPass()); + // Remove now unused temporary buffers. + pm.addPass(xla::mlir_gpu::createDeadTempBufferRemovalPass()); + if (!unroll_factors.empty()) { + pm.addPass(::mlir::createParallelLoopTilingPass(as_int64)); + } + // Some basic cleanup. 
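+ // (The passes added below first run canonicalization and CSE as cleanup,
+ // then greedily map the remaining parallel loops to GPU grid and block
+ // dimensions, outline them into gpu.module kernels, and lower leftover
+ // affine and structured control flow on the host side to plain CFG form.)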
+ pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass()); + pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass()); + // Greedily map the remaining loop to GPU hardware dimensions. + pm.addPass(xla::mlir_gpu::createMapParallelLoopsPass()); + // Apply the mapping. + pm.addPass(mlir::createParallelLoopToGpuPass()); + + // Embed TF Framework ops. + if (!gpu_binary_only) { + pm.addPass(mlir::kernel_gen::tf_framework::CreateEmbedTFFrameworkPass()); + } + + // Some basic cleanup. + pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass()); + pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass()); + // Make loops with min bounds into a conditional plus static bounds. + // Only do this if we unrolled in the first place. + if (!unroll_factors.empty()) { + pm.addNestedPass<::mlir::FuncOp>(mlir::createForLoopSpecializationPass()); + } + // Approximate Tanh using standard operations. + pm.addNestedPass<::mlir::FuncOp>( + ::mlir::mhlo::createLegalizeTrigonometricToApproximationPass()); + // Take launches to launches with kernels. + pm.addPass(::mlir::createGpuKernelOutliningPass()); + + if (gpu_binary_only) { + // Make kernel signature deterministic so that we can call it externally. + pm.addPass(xla::mlir_gpu::createRewriteKernelSignaturePass()); + } + pm.addPass(::mlir::createLowerAffinePass()); + pm.addPass(::mlir::createLowerToCFGPass()); + if (failed(pm.run(module))) { + return InternalError("Lowering to GPU kernels failed."); + } + return Status::OK(); +} + +Status LowerGPUToLLVM(mlir::ModuleOp module, bool gpu_binary_only, + llvm::ArrayRef same_shape, + llvm::StringRef gpu_binary_attr_name, + llvm::ArrayRef architectures, + bool generate_fatbin) { + mlir::PassManager pm(module.getContext()); + applyTensorflowAndCLOptions(pm); + + auto& kernel_pm = pm.nest(); + if (gpu_binary_only) { + // Grab the original signature from the single function. + kernel_pm.addNestedPass( + mlir::kernel_gen::transforms::CreatePropagateTensorFlowABIKnowledgePass( + same_shape)); + } + kernel_pm.addPass(mlir::createStripDebugInfoPass()); + kernel_pm.addPass(mlir::kernel_gen::transforms::CreateGpuKernelToBlobPass( + gpu_binary_attr_name, architectures, generate_fatbin)); + + if (!gpu_binary_only) { + pm.addPass(mlir::kernel_gen::transforms::CreateTFKernelToLLVMPass()); + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createCSEPass()); + } + return failed(pm.run(module)) ? InternalError("Lowering to LLVM IR failed.") + : Status::OK(); +} + +} // namespace + +StatusOr GenerateKernelForTfCode( + mlir::MLIRContext& context, llvm::StringRef tf_code, bool gpu_binary_only, + llvm::ArrayRef architectures, + llvm::ArrayRef tile_sizes, llvm::ArrayRef same_shape, + llvm::ArrayRef unroll_factors, bool generate_fatbin) { + mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry()); + mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context); + TF_RETURN_IF_ERROR( + LowerTFtoGPU(module.get(), gpu_binary_only, tile_sizes, unroll_factors)); +#if !defined(TENSORFLOW_USE_ROCM) && !defined(GOOGLE_CUDA) + return InternalError( + "Neither TENSORFLOW_USE_ROCM nor GOOGLE_CUDA are defined." 
+ " Did you specify either --config=rocm or --config=cuda ?"); +#endif + +#if TENSORFLOW_USE_ROCM + TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToROCDL(module.get())); +#elif GOOGLE_CUDA + TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get())); +#endif + TF_RETURN_IF_ERROR(LowerGPUToLLVM(module.get(), gpu_binary_only, same_shape, + kGpuBinaryAttrName, architectures, + generate_fatbin)); + return module; +} + +StatusOr ExtractGpuBinary(mlir::ModuleOp module) { + auto gpu_modules = module.getOps(); + if (std::distance(gpu_modules.begin(), gpu_modules.end()) != 1) { + return InternalError("There should be exactly one GPU Module"); + } + mlir::gpu::GPUModuleOp gpu_mod = *gpu_modules.begin(); + auto blob = gpu_mod.getAttrOfType(kGpuBinaryAttrName); + if (blob == nullptr) { + return InternalError("No binary blob found in the module"); + } + return blob.getValue().str(); +} + +} // namespace kernel_gen +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h similarity index 52% rename from tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h rename to tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h index 47626ba9d0d..6767944d539 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h @@ -13,30 +13,40 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -//===- cubin_creator.h ------------------------------------------*- C++ -*-===// +//===- kernel_creator.h -----------------------------------------*- C++ -*-===// // -// This file declares the function to compile a TF kernel function to a cubin. +// This file declares the function to compile a TF kernel function to gpu +// binary (hsaco for AMD, cubin for NVIDIA) or to a gpu binary with host side. // //===----------------------------------------------------------------------===// -#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CUBIN_CREATOR_H_ -#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CUBIN_CREATOR_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_KERNEL_CREATOR_H_ +#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_KERNEL_CREATOR_H_ #include -#include #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project #include "tensorflow/compiler/xla/statusor.h" namespace tensorflow { namespace kernel_gen { -xla::StatusOr> GenerateCubinForTfCode( - llvm::StringRef tf_code, - std::pair compute_capability = {7, 5}, + +// Converts TF code to LLVM/NVVM. If `gpu_binary_only` is true, then the +// conversion stops after gpu_binary blob is generated. If `gpu_binary_only` is +// false, lowers the host side to LLVM Dialect. +xla::StatusOr GenerateKernelForTfCode( + mlir::MLIRContext& context, llvm::StringRef tf_code, bool gpu_binary_only, + llvm::ArrayRef architectures = {"sm_75"}, llvm::ArrayRef tile_sizes = {16, 64}, llvm::ArrayRef same_shape = {}, - llvm::ArrayRef unroll_factors = {}); + llvm::ArrayRef unroll_factors = {}, bool generate_fatbin = true); + +// Extracts gpu_binary from the converted module. 
+xla::StatusOr ExtractGpuBinary(mlir::ModuleOp module); + } // namespace kernel_gen } // namespace tensorflow -#endif // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CUBIN_CREATOR_H_ +#endif // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_KERNEL_CREATOR_H_ diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/tests/BUILD new file mode 100644 index 00000000000..25568398442 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/BUILD @@ -0,0 +1,25 @@ +load("//tensorflow:tensorflow.bzl", "filegroup") +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") + +package(licenses = ["notice"]) + +glob_lit_tests( + data = [ + ":test_utilities", + "@llvm-project//mlir:run_lit.sh", + ], + driver = "//tensorflow/compiler/mlir:run_lit.sh", + test_file_exts = ["mlir"], +) + +# Bundle together all of the test utilities that are used by tests. +filegroup( + name = "test_utilities", + testonly = True, + data = [ + "//tensorflow/compiler/mlir:tf-opt", + "//tensorflow/compiler/mlir/hlo:mlir-hlo-opt", + "//tensorflow/compiler/mlir/tools/kernel_gen:kernel-gen-opt", + "@llvm-project//llvm:FileCheck", + ], +) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir new file mode 100644 index 00000000000..1a278365464 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir @@ -0,0 +1,78 @@ +// RUN: kernel-gen-opt %s --bufferize | FileCheck %s + +// CHECK-LABEL: @extract_element +// CHECK-SAME: (%[[ARG:.*]]: memref) -> f32 +func @extract_element(%arg : tensor) -> f32 { + // CHECK: %[[C0:.*]] = constant 0 : index + // CHECK: %[[RESULT:.*]] = load %[[ARG]][%[[C0]]] + // CHECK: return %[[RESULT]] + %c0 = constant 0 : index + %result = extract_element %arg[%c0] : tensor + return %result : f32 +} + +// CHECK-LABEL: @tensor_load +// CHECK-SAME: (%[[ARG:.*]]: memref) -> memref +func @tensor_load(%arg : memref) -> tensor { + // CHECK: return %[[ARG]] : memref + %result = tensor_load %arg : memref + return %result : tensor +} + +// CHECK-LABEL: @tensor_from_elements +// CHECK-SAME: (%[[A:.*]]: f32) -> memref<3xf32> +func @tensor_from_elements(%a : f32) -> tensor<3xf32> { + // CHECK: %[[B:.*]] = constant 1.2 + // CHECK: %[[C:.*]] = constant 2.3 + // CHECK: %[[MEM:.*]] = alloca() : memref<3xf32> + // CHECK: %[[C0:.*]] = constant 0 : index + // CHECK: store %[[A]], %[[MEM]][%[[C0]]] : memref<3xf32> + // CHECK: %[[C1:.*]] = constant 1 : index + // CHECK: store %[[B]], %[[MEM]][%[[C1]]] : memref<3xf32> + // CHECK: %[[C2:.*]] = constant 2 : index + // CHECK: store %[[C]], %[[MEM]][%[[C2]]] : memref<3xf32> + // CHECK: return %[[MEM]] : memref<3xf32> + %b = constant 1.2 : f32 + %c = constant 2.3 : f32 + %result = tensor_from_elements %a, %b, %c : tensor<3xf32> + return %result : tensor<3xf32> +} + +// CHECK-LABEL: @dynamic_tensor_from_elements +// CHECK-SAME: (%[[ARG:.*]]: memref<*xf32>) -> memref +func @dynamic_tensor_from_elements(%arg : tensor<*xf32>) -> tensor { + // CHECK: %[[C3:.*]] = constant 3 : index + // CHECK: %[[MEM:.*]] = alloca(%c3) : memref + // CHECK: %[[C0:.*]] = constant 0 : index + // CHECK: %[[C1:.*]] = constant 1 : index + // CHECK: scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[C3]]) step (%[[C1]]) { + // CHECK: %[[ELEM:.*]] = dim %[[ARG]], %[[I]] : memref<*xf32> + // CHECK: store %[[ELEM]], %[[MEM]][%[[I]]] : memref + // CHECK: scf.yield + // CHECK: } + // CHECK: return %[[MEM]] : memref + %c3 = constant 3 : index + %result 
= dynamic_tensor_from_elements %c3 { + ^bb0(%i : index): + %elem = dim %arg, %i : tensor<*xf32> + yield %elem : index + } : tensor + return %result : tensor +} + +// CHECK-LABEL: @assuming +// CHECK-SAME: (%[[WITNESS:.*]]: !shape.witness, %[[ARG:.*]]: memref) +// CHECK-SAME: -> memref +func @assuming(%witness: !shape.witness, %arg : memref) + -> tensor { + // CHECK-NEXT: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] + // CHECK-SAME: -> (memref) { + // CHECK-NEXT: shape.assuming_yield %[[ARG]] : memref + // CHECK-NEXT: } + // CHECK-NEXT: return %[[ASSUMING_RESULT]] : memref + %assuming_result = shape.assuming %witness -> (tensor) { + %result = tensor_load %arg : memref + shape.assuming_yield %result : tensor + } + return %assuming_result : tensor +} diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/embed_tf_framework.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/embed_tf_framework.mlir new file mode 100644 index 00000000000..bb0f1926cda --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/embed_tf_framework.mlir @@ -0,0 +1,37 @@ +// RUN: kernel-gen-opt %s -embed-tf-framework -split-input-file | FileCheck %s + +// CHECK-LABEL: func @tf_entry( +// CHECK-SAME: [[CTX:%.*]]: !tf_framework.op_kernel_context, +// CHECK-SAME: [[SIZE_0:%.*]]: index, +// CHECK-SAME: [[SIZE_2:%.*]]: index) -> index attributes {tf_entry} { +func @tf_entry(%size_0 : index , %size_2 : index) -> index + attributes {tf_entry} { + %buf = alloc(%size_0, %size_2)[] : memref + dealloc %buf : memref + std.return %size_0 : index +} +// CHECK-NEXT: [[VAL_3:%.*]] = tf_framework.alloc_raw +// CHECK-SAME: ([[CTX]], [[SIZE_0]], [[SIZE_2]]) : memref +// CHECK-NEXT: tf_framework.dealloc_raw([[CTX]], [[VAL_3]]) : memref +// CHECK-NEXT: return [[SIZE_0]] : index + +// ----- + +// CHECK-LABEL: func @non_tf_entry( +// CHECK-SAME: [[SIZE_0:%.*]]: index, [[SIZE_2:%.*]]: index) -> index +func @non_tf_entry(%size_0 : index , %size_2 : index) -> index { + std.return %size_0 : index +} + +// ----- + +// CHECK-LABEL: func @tf_entry( +func @tf_entry(%size : index) attributes {tf_entry} { + %buf = alloc()[%size] : memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>> + dealloc %buf : memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>> + std.return +} +// CHECK_NOT: alloc_raw +// CHECK: alloc() +// CHECK_NOT: dealloc_raw +// CHECK: dealloc % diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/invalid.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/invalid.mlir new file mode 100644 index 00000000000..1d1b3319515 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/invalid.mlir @@ -0,0 +1,7 @@ +// RUN: kernel-gen-opt %s -split-input-file -verify-diagnostics + +func @alloc_raw(%ctx: !tf_framework.op_kernel_context, %size : index) { + // expected-error @+1 {{`dyn_sizes` count 1 does not match dynamic dimensions}} + %buf = tf_framework.alloc_raw(%ctx, %size) : memref + return +} diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/ops.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/ops.mlir new file mode 100644 index 00000000000..fc8e7c97ec8 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/ops.mlir @@ -0,0 +1,25 @@ +// RUN: kernel-gen-opt %s | FileCheck %s +// Verify the printed output can be parsed. +// RUN: kernel-gen-opt %s | kernel-gen-opt | FileCheck %s +// Verify the generic form can be parsed. 
+// RUN: kernel-gen-opt -mlir-print-op-generic %s | kernel-gen-opt | FileCheck %s + +// CHECK-LABEL: func @alloc_raw +func @alloc_raw(%ctx: !tf_framework.op_kernel_context, + %size_0 : index , %size_2 : index) { + %buf_0 = tf_framework.alloc_raw(%ctx) : memref<10xi8> + %buf_1 = tf_framework.alloc_raw(%ctx, %size_0, %size_2) : memref + return +} + +// CHECK-LABEL: func @dealloc_raw +func @dealloc_raw(%ctx: !tf_framework.op_kernel_context, %memref : memref) { + tf_framework.dealloc_raw(%ctx, %memref) : memref + return +} + +// CHECK-LABEL: func @null_context +func @null_context() { + tf_framework.null_context() : !tf_framework.op_kernel_context + return +} diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/parallel_loops_to_sequential.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/parallel_loops_to_sequential.mlir new file mode 100644 index 00000000000..df059759ecc --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/parallel_loops_to_sequential.mlir @@ -0,0 +1,17 @@ +// RUN: kernel-gen-opt %s --parallel-loops-to-sequential | FileCheck %s + +// CHECK-LABEL: @parallel_loop +func @parallel_loop(%lb_0 : index, %lb_1 : index, + %ub_0 : index, %ub_1 : index, + %s_0 : index, %s_1 : index, + %buf: memref) { + scf.parallel (%i0, %i1) = (%lb_0, %lb_1) to (%ub_0, %ub_1) step (%s_0, %s_1) { + %sum_elem = addi %i0, %i1 : index + store %sum_elem, %buf[%i0, %i1] : memref + } + return +} +// CHECK: scf.for [[I_0:%.*]] = [[LB_0:%.*]] to [[UB_0:%.*]] step [[S_0:%.*]] +// CHECK: scf.for [[I_1:%.*]] = [[LB_1:%.*]] to [[UB_1:%.*]] step [[S_1:%.*]] +// CHECK: [[SUM:%.*]] = addi [[I_0]], [[I_1]] : index +// CHECK: store [[SUM]], {{%.*}}{{\[}}[[I_0]], [[I_1]]] : memref diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tanh.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tanh.mlir new file mode 100644 index 00000000000..53d02322c55 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tanh.mlir @@ -0,0 +1,20 @@ +// RUN: tf-opt %s --xla-legalize-tf | mlir-hlo-opt --transform-unranked-hlo | kernel-gen-opt -allow-unregistered-dialect --shape-to-descriptors --canonicalize --bufferize | FileCheck %s + +// Test whether all shape computations required for tanh can be lowered to +// the standard dialect, scf and descriptors. We check for a sparse pattern here, +// as each lowering pattern is already tested and we just care for the +// integration. +// TODO: Expand this pattern once things have stabilized. 
+// CHECK-LABEL: @tanh +func @tanh(%arg0: tensor<*xf32>) -> tensor<*xf32> { + // CHECK: alloca + // CHECK: scf.parallel + // CHECK-NOT: tensor_load + // CHECK: scf.for + // CHECK-NOT: tensor_from_elements + // CHECK: mhlo.reshape_memref_cast + // CHECK: lmhlo.tanh + // CHECK: mhlo.reshape_memref_cast + %0 = "tf.Tanh"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf-legalize-to-mlhlo.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf-legalize-to-mlhlo.mlir new file mode 100644 index 00000000000..2fc585d9e9d --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf-legalize-to-mlhlo.mlir @@ -0,0 +1,26 @@ +// RUN: tf-opt %s --xla-legalize-tf='legalize-chlo=false' | mlir-hlo-opt --transform-unranked-hlo --chlo-legalize-to-hlo | kernel-gen-opt --shape-to-descriptors --canonicalize --bufferize + +func @acos(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %0 = "tf.Acos"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +func @tan(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %0 = "tf.Tan"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +func @tanh(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %0 = "tf.Tanh"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +func @sin(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %0 = "tf.Sin"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +func @sinh(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %0 = "tf.Sinh"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir new file mode 100644 index 00000000000..b943321e95b --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir @@ -0,0 +1,75 @@ +// RUN: kernel-gen-opt %s -tf-kernel-to-llvm -split-input-file | FileCheck %s + +// CHECK: llvm.func @_mlir_ciface_tf_alloc_raw +// CHECK-SAME: (!llvm.ptr, !llvm.i64) -> !llvm.ptr + +// CHECK-LABEL: llvm.func @alloc_raw( +// CHECK-SAME: [[TF_CTX:%.*]]: !llvm.ptr, +// CHECK-SAME: [[SIZE_0:%.*]]: !llvm.i64, +// CHECK-SAME: [[SIZE_2:%.*]]: !llvm.i64) -> [[DESC_TY:!.*]] { +func @alloc_raw(%ctx: !tf_framework.op_kernel_context, + %size_0 : index , %size_2 : index) -> memref { + %buf = tf_framework.alloc_raw(%ctx, %size_0, %size_2) : memref + std.return %buf : memref +} +// Compute number of elements. +// CHECK: [[SIZE_1:%.*]] = llvm.mlir.constant(10 : index) : !llvm.i64 +// CHECK: [[NUM_ELEM_0:%.*]] = llvm.mul [[SIZE_0]], [[SIZE_1]] : !llvm.i64 +// CHECK: [[NUM_ELEM_1:%.*]] = llvm.mul [[NUM_ELEM_0]], [[SIZE_2]] : !llvm.i64 + +// Compute the size of an individual element. +// CHECK: [[NULL:%.*]] = llvm.mlir.null : !llvm.ptr +// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// CHECK: [[GEP:%.*]] = llvm.getelementptr [[NULL]]{{\[}}[[C1]]] +// CHECK-SAME: (!llvm.ptr, !llvm.i64) -> !llvm.ptr +// CHECK: [[SIZE_OF_FLOAT:%.*]] = llvm.ptrtoint [[GEP]] +// CHECK-SAME: !llvm.ptr to !llvm.i64 + +// Allocate memory. +// CHECK: [[NUM_BYTES:%.*]] = llvm.mul [[NUM_ELEM_1]], [[SIZE_OF_FLOAT]] +// CHECK: [[BYTES_PTR:%.*]] = llvm.call @{{.*}}([[TF_CTX]], [[NUM_BYTES]]) +// CHECK-SAME: (!llvm.ptr, !llvm.i64) -> !llvm.ptr + +// Build memref descriptor. 
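+// The descriptor is an LLVM struct of the form (allocated ptr, aligned ptr,
+// offset, sizes[rank], strides[rank]); the insertvalue chain below fills
+// positions [0] and [1] with the buffer pointer, [2] with a zero offset, and
+// [3, i] / [4, i] with the size and stride of each dimension.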
+// CHECK: [[DESC_0:%.*]] = llvm.mlir.undef : [[DESC_TY]] + +// Set pointers and offset. +// CHECK: [[FLOAT_PTR:%.*]] = llvm.bitcast [[BYTES_PTR]] +// CHECK-SAME: !llvm.ptr to !llvm.ptr +// CHECK: [[DESC_1:%.*]] = llvm.insertvalue [[FLOAT_PTR]], [[DESC_0]][0] +// CHECK: [[DESC_2:%.*]] = llvm.insertvalue [[FLOAT_PTR]], [[DESC_1]][1] +// CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK: [[DESC_3:%.*]] = llvm.insertvalue [[C0]], [[DESC_2]][2] : [[DESC_TY]] + +// Set sizes and strides. +// CHECK: [[STRIDE_2:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// CHECK: [[DESC_4:%.*]] = llvm.insertvalue [[SIZE_2]], [[DESC_3]][3, 2] +// CHECK: [[DESC_5:%.*]] = llvm.insertvalue [[STRIDE_2]], [[DESC_4]][4, 2] +// CHECK: [[STRIDE_1:%.*]] = llvm.mul [[STRIDE_2]], [[SIZE_2]] : !llvm.i64 +// CHECK: [[DESC_6:%.*]] = llvm.insertvalue [[SIZE_1]], [[DESC_5]][3, 1] +// CHECK: [[DESC_7:%.*]] = llvm.insertvalue [[STRIDE_1]], [[DESC_6]][4, 1] +// CHECK: [[STRIDE_0:%.*]] = llvm.mul [[STRIDE_1]], [[SIZE_1]] : !llvm.i64 +// CHECK: [[DESC_8:%.*]] = llvm.insertvalue [[SIZE_0]], [[DESC_7]][3, 0] +// CHECK: [[DESC_9:%.*]] = llvm.insertvalue [[STRIDE_0]], [[DESC_8]][4, 0] +// CHECK: llvm.return [[DESC_9]] : [[DESC_TY]] + +// ----- + +// CHECK: llvm.func @_mlir_ciface_tf_dealloc_raw(!llvm.ptr, !llvm.ptr) + +// CHECK-LABEL: llvm.func @dealloc_raw( +// CHECK-SAME: [[TF_CTX:%.*]]: !llvm.ptr, +func @dealloc_raw(%ctx: !tf_framework.op_kernel_context, + %memref : memref) { + tf_framework.dealloc_raw(%ctx, %memref) : memref + return +} +// Extract allocated ptr from the memref descriptor. +// CHECK: %{{.*}} = llvm.mlir.undef : [[DESC_TY:!.*]] +// CHECK: [[FLOAT_PTR:%.*]] = llvm.extractvalue %{{.*}}[0] : [[DESC_TY]] +// CHECK-NEXT: [[VOID_PTR:%.*]] = llvm.bitcast [[FLOAT_PTR]] +// CHECK-SAME: !llvm.ptr to !llvm.ptr + +// Deallocate. +// CHECK: llvm.call @_mlir_ciface_tf_dealloc_raw( +// CHECK-SAME: [[TF_CTX]], [[VOID_PTR]]) : (!llvm.ptr, !llvm.ptr) -> () diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/BUILD new file mode 100644 index 00000000000..6aef5c05fe9 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/BUILD @@ -0,0 +1,17 @@ +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") + +package(licenses = ["notice"]) + +glob_lit_tests( + data = [ + "//tensorflow/compiler/mlir/tools/kernel_gen:tf_to_gpu_binary", + "@llvm-project//mlir:run_lit.sh", + ], + default_tags = [ + # We need access to the CUDA SDK. 
+ "gpu", + "no_rocm", + ], + driver = "//tensorflow/compiler/mlir:run_lit.sh", + test_file_exts = ["mlir"], +) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/abs.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/abs.mlir new file mode 100644 index 00000000000..edb023e5fe7 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/abs.mlir @@ -0,0 +1,6 @@ +// RUN: tf_to_gpu_binary --input=%s --output=%t --same_shape=0,1 --unroll_factors=4 --tile_sizes=256 --arch=sm_70 +func @abs(%arg0: tensor) -> tensor { + %0 = "tf.Abs"(%arg0) { } + : (tensor) -> tensor + return %0 : tensor +} diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/ceil.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/ceil.mlir new file mode 100644 index 00000000000..25b79c47f4e --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/ceil.mlir @@ -0,0 +1,6 @@ +// RUN: tf_to_gpu_binary --input=%s --output=%t --same_shape=0,1 --unroll_factors=4 --tile_sizes=256 --arch=sm_70 +func @ceil(%arg0: tensor) -> tensor { + %0 = "tf.Ceil"(%arg0) { } + : (tensor) -> tensor + return %0 : tensor +} diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/tanh.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/tanh.mlir new file mode 100644 index 00000000000..69632f498a9 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/tanh.mlir @@ -0,0 +1,5 @@ +// RUN: tf_to_gpu_binary --input=%s --output=%t --same_shape=0,1 --unroll_factors=4 --tile_sizes=256 --arch=sm_70 +func @tanh(%arg0: tensor) -> tensor { + %0 = "tf.Tanh"(%arg0) : (tensor) -> tensor + return %0 : tensor +} diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel/BUILD new file mode 100644 index 00000000000..24e288c246c --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel/BUILD @@ -0,0 +1,17 @@ +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") + +package(licenses = ["notice"]) + +glob_lit_tests( + data = [ + "//tensorflow/compiler/mlir/tools/kernel_gen:tf_to_kernel", + "@llvm-project//mlir:run_lit.sh", + ], + default_tags = [ + # We need access to the CUDA SDK. + "gpu", + "no_rocm", + ], + driver = "//tensorflow/compiler/mlir:run_lit.sh", + test_file_exts = ["mlir"], +) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel/tanh.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel/tanh.mlir new file mode 100644 index 00000000000..85bea1795a5 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel/tanh.mlir @@ -0,0 +1,6 @@ +// RUN: tf_to_kernel --input=%s --output=%t --same_shape=0,1 --unroll_factors=4 --tile_sizes=256 --arch=sm_70,compute_75 + +func @tanh(%arg: tensor<*xf32>) -> tensor<*xf32> { + %0 = "tf.Tanh"(%arg) : (tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_cuda_runtime_wrappers.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_cuda_runtime_wrappers.cc new file mode 100644 index 00000000000..06d613e0599 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_cuda_runtime_wrappers.cc @@ -0,0 +1,113 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Implements C wrappers around the CUDA library for easy linking in ORC jit. +// Also adds some debugging helpers that are helpful when writing MLIR code to +// run on GPUs. + +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/ExecutionEngine/CRunnerUtils.h" // from @llvm-project + +#if GOOGLE_CUDA +#include "third_party/gpus/cuda/include/cuda.h" + +#define CUDA_REPORT_IF_ERROR(expr) \ + [](CUresult result) { \ + if (!result) \ + return; \ + const char *name = nullptr; \ + cuGetErrorName(result, &name); \ + if (!name) \ + name = ""; \ + llvm::errs() << "'" << #expr << "' failed with '" << name << "'\n"; \ + }(expr) + +extern "C" CUmodule mgpuModuleLoad(void *data) { + CUmodule module = nullptr; + CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data)); + return module; +} + +extern "C" CUfunction mgpuModuleGetFunction(CUmodule module, const char *name) { + CUfunction function = nullptr; + CUDA_REPORT_IF_ERROR(cuModuleGetFunction(&function, module, name)); + return function; +} + +// The wrapper uses intptr_t instead of CUDA's unsigned int to match +// the type of MLIR's index type. This avoids the need for casts in the +// generated MLIR code. +extern "C" void mgpuLaunchKernel(CUfunction function, intptr_t gridX, + intptr_t gridY, intptr_t gridZ, + intptr_t blockX, intptr_t blockY, + intptr_t blockZ, int32_t smem, CUstream stream, + void **params, void **extra) { + CUDA_REPORT_IF_ERROR(cuLaunchKernel(function, gridX, gridY, gridZ, blockX, + blockY, blockZ, smem, stream, params, + extra)); +} + +extern "C" CUstream mgpuStreamCreate() { + static CUstream stream = []() { + // TODO(b/170649852): This is neither thread-safe nor handles + // creation/descruction of one stream per context. + CUstream stream = nullptr; + CUDA_REPORT_IF_ERROR(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING)); + return stream; + }(); + return stream; +} + +extern "C" void mgpuStreamSynchronize(CUstream stream) { + CUDA_REPORT_IF_ERROR(cuStreamSynchronize(stream)); +} + +/// Helper functions for writing mlir example code + +// Allows to register byte array with the CUDA runtime. Helpful until we have +// transfer functions implemented. +extern "C" void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) { + CUDA_REPORT_IF_ERROR(cuMemHostRegister(ptr, sizeBytes, /*flags=*/0)); +} + +// Allows to register a MemRef with the CUDA runtime. Helpful until we have +// transfer functions implemented. +extern "C" void +mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType *descriptor, + int64_t elementSizeBytes) { + + llvm::SmallVector denseStrides(rank); + llvm::ArrayRef sizes(descriptor->sizes, rank); + llvm::ArrayRef strides(sizes.end(), rank); + + std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(), + std::multiplies()); + auto sizeBytes = denseStrides.front() * elementSizeBytes; + + // Only densely packed tensors are currently supported. 
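+ // For example, sizes [2, 3, 4] give denseStrides [24, 12, 4] from the
+ // partial_sum above (so sizeBytes covers all 24 elements); the rotate and
+ // the trailing 1 below turn that into the row-major strides [12, 4, 1],
+ // which must match the strides recorded in the descriptor.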
+ std::rotate(denseStrides.begin(), denseStrides.begin() + 1, + denseStrides.end()); + denseStrides.back() = 1; + assert(strides == llvm::makeArrayRef(denseStrides)); + + auto ptr = descriptor->data + descriptor->offset * elementSizeBytes; + mgpuMemHostRegister(ptr, sizeBytes); +} + +#endif diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_cubin.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_gpu_binary.cc similarity index 56% rename from tensorflow/compiler/mlir/tools/kernel_gen/tf_to_cubin.cc rename to tensorflow/compiler/mlir/tools/kernel_gen/tf_to_gpu_binary.cc index 96831689600..84c2bf46b55 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_cubin.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_gpu_binary.cc @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -//===- tf_to_cubin.cc -------------------------------------------*- C++ -*-===// +//===- tf_to_gpu_binary.cc --------------------------------------*- C++ -*-===// // -// This file implements the entry point to compile a tf op to a cubin file. +// This file implements the entry point to compile a tf op to a gpu binary // //===----------------------------------------------------------------------===// #include @@ -23,10 +23,44 @@ #include "absl/strings/string_view.h" #include "llvm/Support/CommandLine.h" +#include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/mlir/init_mlir.h" -#include "tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace tensorflow { +namespace kernel_gen { +namespace { + +xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file, + std::string architecture, llvm::ArrayRef tile_sizes, + llvm::ArrayRef same_shape, + llvm::ArrayRef unroll_factors) { + // Read TF code. + std::string tf_code; + TF_RETURN_IF_ERROR( + ReadFileToString(Env::Default(), input_file.str(), &tf_code)); + // Compile. + mlir::MLIRContext context; + TF_ASSIGN_OR_RETURN( + mlir::OwningModuleRef module, + GenerateKernelForTfCode(context, tf_code, /*gpu_binary_only=*/true, + architecture, tile_sizes, same_shape, + unroll_factors, /*generate_fatbin=*/false)); + // Extract gpu_binary. + TF_ASSIGN_OR_RETURN(std::string gpu_binary, ExtractGpuBinary(*module)); + + // Write gpu_binary blob. + TF_RETURN_IF_ERROR( + WriteStringToFile(Env::Default(), output_file.str(), gpu_binary)); + return xla::Status::OK(); +} + +} // namespace +} // namespace kernel_gen +} // namespace tensorflow int main(int argc, char** argv) { llvm::cl::opt input_file("input", llvm::cl::desc("input file"), @@ -35,9 +69,9 @@ int main(int argc, char** argv) { llvm::cl::opt output_file( "output", llvm::cl::desc("output file"), llvm::cl::value_desc("filename"), llvm::cl::init("foo.bin")); - llvm::cl::opt architecture( - "arch", llvm::cl::desc("target architecture (e.g. 50 for sm_50)"), - llvm::cl::init(50)); + llvm::cl::opt architecture( + "arch", llvm::cl::desc("target architecture (e.g. 
sm_50)"), + llvm::cl::init("sm_50")); llvm::cl::list tile_sizes( "tile_sizes", llvm::cl::desc("tile sizes to use"), llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated); @@ -51,38 +85,15 @@ int main(int argc, char** argv) { llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated); tensorflow::InitMlir y(&argc, &argv); + mlir::registerPassManagerCLOptions(); llvm::cl::ParseCommandLineOptions(argc, argv, "TF op GPU kernel generator\n"); - std::pair compute_capability(architecture / 10, - architecture % 10); - - std::string tf_code; - auto read_status = tensorflow::ReadFileToString(tensorflow::Env::Default(), - input_file, &tf_code); - if (!read_status.ok()) { - LOG(ERROR) << read_status; - return 1; - } - - auto cubin = tensorflow::kernel_gen::GenerateCubinForTfCode( - tf_code, compute_capability, tile_sizes, same_shape, unroll_factors); - - if (!cubin.ok()) { - LOG(ERROR) << cubin.status(); - return 1; - } - - std::vector cubin_data = cubin.ConsumeValueOrDie(); - - auto status = tensorflow::WriteStringToFile( - tensorflow::Env::Default(), output_file, - absl::string_view{reinterpret_cast(cubin_data.data()), - cubin_data.size()}); - + auto status = + tensorflow::kernel_gen::Run(input_file, output_file, architecture, + tile_sizes, same_shape, unroll_factors); if (!status.ok()) { LOG(ERROR) << status; return 1; } - return 0; } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc new file mode 100644 index 00000000000..87c8e57804b --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc @@ -0,0 +1,162 @@ +// Copyright 2020 The TensorFlow Runtime Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//===- tf_to_kernel.cc ------------------------------------------*- C++ -*-===// +// +// This file implements the entry point to compile a tf op to a kernel. 
+// +//===----------------------------------------------------------------------===// +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/CodeGen/CommandFlags.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Host.h" +#include "llvm/Support/TargetRegistry.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Target/TargetMachine.h" +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Target/LLVMIR.h" // from @llvm-project +#include "tensorflow/compiler/mlir/init_mlir.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace tensorflow { +namespace kernel_gen { +namespace { + +static llvm::codegen::RegisterCodeGenFlags CGF; + +std::unique_ptr GetTargetMachine(llvm::Module* module) { + llvm::Triple triple(module->getTargetTriple()); + if (triple.getTriple().empty()) { + triple = llvm::Triple(llvm::sys::getDefaultTargetTriple()); + module->setTargetTriple(triple.getTriple()); + } + + std::string error; + const llvm::Target* target = + llvm::TargetRegistry::lookupTarget("", triple, error); + if (!target) { + return nullptr; + } + + llvm::TargetOptions target_options = + llvm::codegen::InitTargetOptionsFromCodeGenFlags(llvm::Triple()); + return std::unique_ptr(target->createTargetMachine( + triple.str(), "generic", "", target_options, llvm::Reloc::Model::PIC_)); +} + +// Compiles the given MLIR module via LLVM into an executable binary format. +xla::StatusOr EmitToBinary(mlir::ModuleOp module) { + // Translate the module. + llvm::LLVMContext llvm_context; + std::unique_ptr llvm_module = + mlir::translateModuleToLLVMIR(module, llvm_context); + + // Set up the output stream. + llvm::SmallString<8> outstr; + llvm::raw_svector_ostream ostream(outstr); + ostream.SetUnbuffered(); + + llvm::legacy::PassManager codegen_passes; + codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass( + llvm::Triple(llvm_module->getTargetTriple()))); + + // TODO(b/163818770): Apply optimizations before dumping .a file. + auto target_machine = GetTargetMachine(llvm_module.get()); + llvm_module->setDataLayout(target_machine->createDataLayout()); + if (target_machine->addPassesToEmitFile(codegen_passes, ostream, nullptr, + llvm::CGFT_ObjectFile, false)) { + return xla::InternalError("Failed add passes to emit file"); + } + codegen_passes.run(*llvm_module); + return ostream.str().str(); +} + +xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file, + llvm::ArrayRef architectures, + llvm::ArrayRef tile_sizes, + llvm::ArrayRef same_shape, + llvm::ArrayRef unroll_factors) { + // Read TF code. + std::string tf_code; + TF_RETURN_IF_ERROR( + ReadFileToString(Env::Default(), input_file.str(), &tf_code)); + // Compile. + mlir::MLIRContext context; + TF_ASSIGN_OR_RETURN( + mlir::OwningModuleRef module, + GenerateKernelForTfCode(context, tf_code, /*gpu_binary_only=*/false, + architectures, tile_sizes, same_shape, + unroll_factors)); + // Get binary. + TF_ASSIGN_OR_RETURN(std::string binary, EmitToBinary(*module)); + + // Write .a file. 
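+ // `binary` is the relocatable object code emitted by EmitToBinary above and
+ // is written to the output file verbatim.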
+ TF_RETURN_IF_ERROR( + WriteStringToFile(Env::Default(), output_file.str(), binary)); + return xla::Status::OK(); +} + +} // namespace +} // namespace kernel_gen +} // namespace tensorflow + +int main(int argc, char** argv) { + llvm::cl::opt input_file("input", llvm::cl::desc("input file"), + llvm::cl::value_desc("filename"), + llvm::cl::init("foo.mlir")); + llvm::cl::opt output_file( + "output", llvm::cl::desc("output file"), llvm::cl::value_desc("filename"), + llvm::cl::init("foo.bin")); + llvm::cl::list architectures( + "arch", llvm::cl::desc("target architectures (e.g. sm_70 or compute_75)"), + llvm::cl::OneOrMore, llvm::cl::CommaSeparated); + llvm::cl::list tile_sizes( + "tile_sizes", llvm::cl::desc("tile sizes to use"), llvm::cl::ZeroOrMore, + llvm::cl::CommaSeparated); + llvm::cl::list unroll_factors( + "unroll_factors", + llvm::cl::desc("factors to unroll by, separated by commas"), + llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated); + llvm::cl::list same_shape( + "same_shape", + llvm::cl::desc("arguments with same shape, separated by commas"), + llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated); + + tensorflow::InitMlir y(&argc, &argv); + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + mlir::registerPassManagerCLOptions(); + llvm::cl::ParseCommandLineOptions(argc, argv, "TF op GPU kernel generator\n"); + + auto status = + tensorflow::kernel_gen::Run(input_file, output_file, architectures, + tile_sizes, same_shape, unroll_factors); + if (!status.ok()) { + LOG(ERROR) << status; + return 1; + } + return 0; +} diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD index b0f22b40f5b..b2595d2ad3a 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD @@ -1,4 +1,14 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//third_party/mlir:tblgen.bzl", "gentbl") +load( + "//tensorflow/core/platform/default:cuda_build_defs.bzl", + "if_cuda_is_configured", +) +load( + "@local_config_rocm//rocm:build_defs.bzl", + "if_rocm_is_configured", +) +load("//tensorflow:tensorflow.bzl", "get_compatible_with_cloud") package( default_visibility = ["//tensorflow/compiler/mlir/tools/kernel_gen:friends"], @@ -28,6 +38,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", @@ -51,6 +62,7 @@ cc_library( gentbl( name = "kernel_gen_passes_inc_gen", + compatible_with = get_compatible_with_cloud(), tbl_outs = [("-gen-pass-decls -name KernelGen", "kernel_gen_passes.h.inc")], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes.td", @@ -62,32 +74,59 @@ cc_library( srcs = [ "bufferize_pass.cc", "embed_tf_framework_pass.cc", + "gpu_kernel_to_blob_pass.cc", + "materialize_broadcasts_pass.cc", + "parallel_loops_to_sequential.cc", + "propagate_tf_abi_knowledge_pass.cc", "shape_to_descriptors_pass.cc", - "tf_framework_legalize_to_llvm_pass.cc", + "tf_kernel_to_llvm_pass.cc", + "unfuse_batch_norm_pass.cc", ], hdrs = ["passes.h"], + copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]) + if_rocm_is_configured(["-DTENSORFLOW_USE_ROCM=1"]), deps = [ + "//tensorflow/compiler/mlir/hlo:materialize_broadcasts", # buildcleaner: keep + "//tensorflow/compiler/mlir/hlo:unfuse_batch_norm", # buildcleaner: keep + 
"//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla/service/gpu:target_constants", + "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend", + "//tensorflow/core/platform:cuda_libdevice_path", + "//tensorflow/core:lib", ":bufferize", ":embed_tf_framework", ":kernel_gen_passes_inc_gen", ":tf_framework_legalize_to_llvm", - "//tensorflow/compiler/mlir/hlo", - "//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo", - "//tensorflow/compiler/mlir/hlo:lhlo", - "//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_llvm", - "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops", "@llvm-project//llvm:Support", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUToGPURuntimeTransforms", "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:LLVMTransforms", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Shape", - "@llvm-project//mlir:ShapeToSCF", + "@llvm-project//mlir:TargetNVVMIR", + "@llvm-project//mlir:TargetROCDLIR", "@llvm-project//mlir:ShapeToStandard", + "@llvm-project//mlir:SCFToStandard", "@llvm-project//mlir:ShapeTransforms", "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", - ], + "@llvm-project//llvm:TransformUtils", + "//tensorflow/compiler/mlir/hlo", + "//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo", + "//tensorflow/compiler/mlir/hlo:lhlo", + "//tensorflow/compiler/xla/service/gpu:stream_executor_util", + "//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_llvm", + "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops", + ] + if_cuda_is_configured([ + "//tensorflow/stream_executor/gpu:asm_compiler", + ]) + if_rocm_is_configured([ + "//tensorflow/core/platform:rocm_rocdl_path", + ]), ) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc index 3d5c820e6dd..f2b5e14bd30 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc @@ -15,17 +15,22 @@ limitations under the License. // This file implements logic for translating mixed IR to buffer form. 
+#include "mlir/Transforms/Bufferize.h" // from @llvm-project + #include #include #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project #include "mlir/IR/Function.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Transforms/BufferPlacement.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project namespace mlir { @@ -35,10 +40,10 @@ namespace transforms { namespace { class TensorFromElementsOpConverter - : public BufferAssignmentOpConversionPattern { + : public BufferizeOpConversionPattern { public: - using BufferAssignmentOpConversionPattern< - TensorFromElementsOp>::BufferAssignmentOpConversionPattern; + using BufferizeOpConversionPattern< + TensorFromElementsOp>::BufferizeOpConversionPattern; LogicalResult matchAndRewrite( TensorFromElementsOp op, ArrayRef operands, @@ -58,11 +63,63 @@ class TensorFromElementsOpConverter } }; -class TensorLoadOpConversion - : public BufferAssignmentOpConversionPattern { +class DynamicTensorFromElementsOpConverter + : public BufferizeOpConversionPattern { public: - using BufferAssignmentOpConversionPattern< - TensorLoadOp>::BufferAssignmentOpConversionPattern; + using BufferizeOpConversionPattern< + DynamicTensorFromElementsOp>::BufferizeOpConversionPattern; + + LogicalResult matchAndRewrite( + DynamicTensorFromElementsOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + // Allocate memory on stack. + Location loc = op.getLoc(); + DynamicTensorFromElementsOp::Adaptor transformed(operands); + RankedTensorType tensor_ty = op.getType().cast(); + MemRefType memref_type = + MemRefType::get(tensor_ty.getShape(), tensor_ty.getElementType()); + Value result = rewriter.create(loc, memref_type, + transformed.dynamicExtents()); + + // Collect loop bounds. + int64_t rank = tensor_ty.getRank(); + Value zero = rewriter.create(loc, 0); + Value one = rewriter.create(loc, 1); + SmallVector lower_bounds(rank, zero); + SmallVector steps(rank, one); + SmallVector upper_bounds; + int next_dynamic_index = 0; + for (int i = 0; i < rank; i++) { + Value ub = tensor_ty.isDynamicDim(i) + ? transformed.dynamicExtents()[next_dynamic_index++] + : rewriter.create( + loc, memref_type.getDimSize(i)); + upper_bounds.push_back(ub); + } + + // Generate tensor elements. 
+ rewriter.create( + loc, lower_bounds, upper_bounds, steps, + [&](OpBuilder &b, Location loc, ValueRange ivs) { + BlockAndValueMapping mapping; + mapping.map(op.body().getArguments(), ivs); + for (auto &nested_op : op.getBody()->without_terminator()) + b.clone(nested_op, mapping); + auto yield_op = llvm::cast(op.getBody()->getTerminator()); + b.create(loc, mapping.lookup(yield_op.value()), result, ivs); + b.create(loc); + }); + + rewriter.replaceOp(op, {result}); + return success(); + } +}; + +class TensorLoadOpConversion + : public BufferizeOpConversionPattern { + public: + using BufferizeOpConversionPattern< + TensorLoadOp>::BufferizeOpConversionPattern; LogicalResult matchAndRewrite( TensorLoadOp op, ArrayRef operands, @@ -74,17 +131,17 @@ class TensorLoadOpConversion }; class ExtractElementOpConversion - : public BufferAssignmentOpConversionPattern { + : public BufferizeOpConversionPattern { public: - using BufferAssignmentOpConversionPattern< - ExtractElementOp>::BufferAssignmentOpConversionPattern; + using BufferizeOpConversionPattern< + ExtractElementOp>::BufferizeOpConversionPattern; LogicalResult matchAndRewrite( ExtractElementOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { ExtractElementOpAdaptor adaptor(operands); - if (!adaptor.aggregate().getType().isa()) { + if (!adaptor.aggregate().getType().isa()) { return failure(); } @@ -94,15 +151,49 @@ class ExtractElementOpConversion } }; +template +class SimpleOpResultConversion : public BufferizeOpConversionPattern { + public: + using BufferizeOpConversionPattern::BufferizeOpConversionPattern; + using BufferizeOpConversionPattern::converter; + + LogicalResult matchAndRewrite( + OpTy op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + rewriter.replaceOpWithNewOp(op, converter.convertType(op.getType()), + operands); + return success(); + } +}; + +class TensorCastOpConverter + : public BufferizeOpConversionPattern { + public: + using BufferizeOpConversionPattern< + TensorCastOp>::BufferizeOpConversionPattern; + + LogicalResult matchAndRewrite( + TensorCastOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + Value arg = operands.front(); + if (!arg.getType().isa()) return failure(); + + auto result_ty = converter.convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, arg, result_ty); + + return success(); + } +}; + } // namespace void populateStandardBufferizePattern(MLIRContext *context, - BufferAssignmentPlacer *bufferAssignment, - TypeConverter *converter, + BufferizeTypeConverter *converter, OwningRewritePatternList *patterns) { patterns->insert(context, bufferAssignment, - converter); + DynamicTensorFromElementsOpConverter, + SimpleOpResultConversion, TensorLoadOpConversion, + TensorCastOpConverter>(context, *converter); } } // namespace transforms diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc index ef07c801bc4..9a531515012 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc @@ -19,6 +19,8 @@ limitations under the License. 
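The DynamicTensorFromElementsOp conversion above relies on the scf.parallel body-builder idiom: the op is created with bounds plus a callback that receives the induction variables and populates the loop body. A small self-contained sketch of that idiom follows; the helper name and the trivial store body are illustrative only.

#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"

// Fills `memref` with `value` over the given bounds using an scf.parallel
// loop nest; the lambda receives one induction variable per dimension.
void FillBuffer(mlir::OpBuilder &b, mlir::Location loc, mlir::Value memref,
                mlir::ValueRange lower, mlir::ValueRange upper,
                mlir::ValueRange steps, mlir::Value value) {
  b.create<mlir::scf::ParallelOp>(
      loc, lower, upper, steps,
      [&](mlir::OpBuilder &nested, mlir::Location nested_loc,
          mlir::ValueRange ivs) {
        nested.create<mlir::StoreOp>(nested_loc, value, memref, ivs);
        nested.create<mlir::scf::YieldOp>(nested_loc);
      });
}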
#include #include "mlir/Dialect/SCF/SCF.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/Dialect/Shape/Transforms/Passes.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Function.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project @@ -26,8 +28,7 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Transforms/BufferPlacement.h" // from @llvm-project +#include "mlir/Transforms/Bufferize.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" @@ -67,25 +68,25 @@ class UnrankedTensorStoreTestOnlyPattern }; struct BufferizePass : public BufferizePassBase { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + public: void runOnOperation() override { - OwningRewritePatternList patterns; auto& context = getContext(); ConversionTarget target(context); - target.addLegalDialect(); - target.addLegalDialect(); - target.addLegalDialect(); - target.addLegalOp(); - target.addLegalOp(); + target.addLegalDialect(); + target.addLegalOp(); target.addIllegalDialect(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); + target.addIllegalOp(); target.addDynamicallyLegalOp([&](TensorStoreOp op) { return !op.tensor().getType().isa(); }); - BufferAssignmentTypeConverter converter; + BufferizeTypeConverter converter; auto typesAreLegal = [&converter](Operation* op) { return converter.isLegal(op->getOperandTypes()) && converter.isLegal(op->getResultTypes()); @@ -96,26 +97,20 @@ struct BufferizePass : public BufferizePassBase { return converter.isLegal(inputs) && converter.isLegal(results) && converter.isLegal(&op.getBody()); }); - target.addDynamicallyLegalOp(typesAreLegal); - target.addDynamicallyLegalOp(typesAreLegal); + target.addDynamicallyLegalOp( + typesAreLegal); + + OwningRewritePatternList patterns; + mhlo::populateHLOToLHLOConversionPattern(&context, &converter, &patterns); + populateWithBufferizeOpConversionPatterns( + &context, converter, patterns); + populateStandardBufferizePattern(&context, &converter, &patterns); + populateShapeTypeConversionPatterns(&context, converter, patterns); + patterns.insert(&context); auto module = getOperation(); - WalkResult result = module.walk([&](FuncOp func) -> WalkResult { - BufferAssignmentPlacer bufferAssignment(func); - OwningRewritePatternList patterns; - mhlo::populateHLOToLHLOConversionPattern( - func.getContext(), &bufferAssignment, &converter, &patterns); - populateWithBufferAssignmentOpConversionPatterns< - ReturnOp, ReturnOp, lmhlo::CopyOp, - /*allowMemrefFunctionResults=*/true>(&context, &bufferAssignment, - &converter, &patterns); - populateStandardBufferizePattern(func.getContext(), &bufferAssignment, - &converter, &patterns); - patterns.insert(func.getContext()); - - return applyPartialConversion(func, target, patterns); - }); - if (result.wasInterrupted()) { + if (failed(applyPartialConversion(module, target, patterns))) { signalPassFailure(); } } diff --git 
a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework_pass.cc index a0cfcae65d1..6aea4d9c619 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework_pass.cc @@ -36,6 +36,10 @@ static constexpr StringRef kTFEntry = "tf_entry"; // * std.dealloc becomes tf_framework.dealloc_raw. class EmbedTFFrameworkPass : public EmbedTFFrameworkPassBase { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + public: void runOnOperation() override { ModuleOp m = getOperation(); @@ -68,7 +72,7 @@ class EmbedTFFrameworkPass } // namespace -std::unique_ptr > createEmbedTFFrameworkPass() { +std::unique_ptr > CreateEmbedTFFrameworkPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc new file mode 100644 index 00000000000..46bf13b7d20 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc @@ -0,0 +1,227 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
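Several of the passes touched by this patch gain a getDependentDialects override (the dialect lists inside registry.insert<...> did not survive the rendering above). The idiom, sketched here with an assumed dialect list, declares up front every dialect whose ops the pass may create:

#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Module.h"
#include "mlir/Pass/Pass.h"

struct ExamplePass
    : public mlir::PassWrapper<ExamplePass,
                               mlir::OperationPass<mlir::ModuleOp>> {
  // Lets the context load these dialects before (possibly multi-threaded)
  // pass execution begins.
  void getDependentDialects(mlir::DialectRegistry &registry) const override {
    registry.insert<mlir::scf::SCFDialect, mlir::StandardOpsDialect>();
  }
  void runOnOperation() override {}
};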
+==============================================================================*/ + +#include "llvm/Transforms/Utils/Cloning.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/Target/NVVMIR.h" // from @llvm-project +#include "mlir/Target/ROCDLIR.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" +#include "tensorflow/compiler/xla/service/gpu/target_constants.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/platform/cuda_libdevice_path.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/path.h" + +#if GOOGLE_CUDA +#include "tensorflow/stream_executor/gpu/asm_compiler.h" +#elif TENSORFLOW_USE_ROCM +#include "tensorflow/core/platform/rocm_rocdl_path.h" +#endif + +namespace mlir { +namespace kernel_gen { +namespace transforms { +namespace { + +#define GEN_PASS_CLASSES +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc" + +using xla::InternalError; + +class GpuKernelToBlobPass + : public GpuKernelToBlobPassBase { + public: + GpuKernelToBlobPass(mlir::StringRef blob_annotation, + llvm::ArrayRef architectures, + bool generate_fatbin) { + blob_annotation_ = blob_annotation.str(); + architectures_ = architectures; + generate_fatbin_ = generate_fatbin; + } + + void runOnOperation() override { + mlir::gpu::GPUModuleOp gpu_module = getOperation(); + auto blob_or = GetGpuBinaryBlob(gpu_module); + if (blob_or.ok()) { + const auto& blob = blob_or.ValueOrDie(); + std::string blob_string(blob.begin(), blob.end()); + gpu_module.setAttr(blob_annotation_, + mlir::StringAttr::get(blob_string, &getContext())); + return; + } + return signalPassFailure(); + } + + xla::StatusOr> GetGpuBinaryBlob( + mlir::gpu::GPUModuleOp gpu_module) { + if (architectures_.empty()) { + return InternalError("Expected at least one GPU architecture."); + } + if (!generate_fatbin_ && architectures_.size() > 1) { + return InternalError( + "Can only generate machine code for more than one architecture as a " + "fatbin."); + } + + llvm::LLVMContext llvmContext; + +#if TENSORFLOW_USE_ROCM + auto llvmModule = mlir::translateModuleToROCDLIR(gpu_module, llvmContext); + if (!llvmModule) { + return InternalError("Could not translate MLIR module to ROCDL IR"); + } + + llvmModule->setModuleIdentifier("acme"); + + xla::HloModuleConfig config; + config.set_debug_options(xla::GetDebugOptionsFromFlags()); + + // TODO(b/169066682): Support fatbin on ROCm. + if (generate_fatbin_) { + return InternalError("Fatbins are not yet supported for ROCm."); + } + + // Parse ROCm architecture. 
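Both branches of GetGpuBinaryBlob parse the architecture string by stripping a prefix and reading the numeric suffix: the ROCm branch below strips "gfx", while the CUDA branch further down strips "sm_"/"compute_" and splits the number into a compute capability (e.g. "sm_75" gives major 7, minor 5). A hypothetical standalone helper mirroring the CUDA arithmetic, not part of the patch:

#include <cstdint>
#include <utility>
#include "absl/strings/numbers.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"

// Parses "sm_NN" or "compute_NN" into a {major, minor} compute capability.
bool ParseCudaArch(absl::string_view arch, std::pair<uint32_t, uint32_t> *cc) {
  if (!absl::ConsumePrefix(&arch, "sm_") &&
      !absl::ConsumePrefix(&arch, "compute_"))
    return false;
  uint32_t value = 0;
  if (!absl::SimpleAtoi(arch, &value)) return false;
  *cc = {value / 10, value % 10};  // e.g. 75 -> {7, 5}
  return true;
}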
+ absl::string_view consumable_arch(architectures_.front()); + if (!absl::ConsumePrefix(&consumable_arch, "gfx")) { + return InternalError( + "Could not parse ROCm architecture prefix (expected gfx)"); + } + uint32_t arch; + if (!absl::SimpleAtoi(consumable_arch, &arch)) { + return InternalError("Could not parse ROCm architecture number"); + } + + std::string libdevice_dir = tensorflow::RocdlRoot(); + return xla::gpu::amdgpu::CompileToHsaco(llvmModule.get(), arch, config, + libdevice_dir); + +#elif GOOGLE_CUDA + auto llvmModule = mlir::translateModuleToNVVMIR(gpu_module, llvmContext); + if (!llvmModule) { + return InternalError("Could not translate MLIR module to NVVM"); + } + + llvmModule->setModuleIdentifier("acme"); + llvmModule->setDataLayout(xla::gpu::nvptx::kDataLayout); + + xla::HloModuleConfig config; + config.set_debug_options(xla::GetDebugOptionsFromFlags()); + + auto enable_fusion = [](llvm::TargetMachine* target) { + target->Options.AllowFPOpFusion = llvm::FPOpFusion::FPOpFusionMode::Fast; + }; + + // Compile and collect requested cubin and PTX images. + std::vector images; + TF_ASSIGN_OR_RETURN(std::string libdevice_dir, GetLibdeviceDir(config)); + auto gpu_asm_opts = xla::gpu::PtxOptsFromConfig(config); + for (const std::string& arch_str : architectures_) { + // Parse CUDA architecture. + absl::string_view consumable_arch(arch_str); + bool is_compute_profile; + if (absl::ConsumePrefix(&consumable_arch, "compute_")) { + is_compute_profile = true; + } else if (absl::ConsumePrefix(&consumable_arch, "sm_")) { + is_compute_profile = false; + } else { + return InternalError( + "Could not parse cuda architecture prefix (expected sm_ or " + "compute_)"); + } + uint32_t arch; + if (!absl::SimpleAtoi(consumable_arch, &arch)) { + return InternalError("Could not parse cuda architecture number"); + } + + uint32_t cc_major = arch / 10; + uint32_t cc_minor = arch % 10; + // Module may be changed by CompileToPtx. + auto llvm_module_copy = llvm::CloneModule(*llvmModule); + TF_ASSIGN_OR_RETURN( + std::string ptx, + xla::gpu::nvptx::CompileToPtx(llvm_module_copy.get(), + std::make_pair(cc_major, cc_minor), + config, libdevice_dir, enable_fusion)); + VLOG(1) << ptx; + TF_ASSIGN_OR_RETURN(std::vector gpu_asm, + tensorflow::se::CompileGpuAsm( + cc_major, cc_minor, ptx.c_str(), gpu_asm_opts)); + + if (!generate_fatbin_) { + // Skip fatbin generation and return the first and only GPU machine + // code. This is currently only used for `tf_to_gpu_binary` and will + // eventually disappear. + return gpu_asm; + } + + // Collect cubin (and ptx image if requested). + images.push_back({absl::StrCat("sm_", arch), std::move(gpu_asm)}); + if (is_compute_profile) { + std::vector ptx_bytes; + std::copy(ptx.begin(), ptx.end(), std::back_inserter(ptx_bytes)); + images.push_back( + {absl::StrCat("compute_", arch), std::move(ptx_bytes)}); + } + } + + // TODO(b/169870789): Revisit the use of fatbins. + // Bundle cubin and PTX images into a single fatbin. + return tensorflow::se::BundleGpuAsm(images, + gpu_asm_opts.preferred_cuda_dir); +#endif + + return InternalError( + "Neither TENSORFLOW_USE_ROCM nor GOOGLE_CUDA are defined." 
+ " Did you specify either --config=rocm or --config=cuda ?"); + } + + private: + xla::StatusOr GetLibdeviceDir( + const xla::HloModuleConfig& hlo_module_config) { + for (const std::string& cuda_root : tensorflow::CandidateCudaRoots( + hlo_module_config.debug_options().xla_gpu_cuda_data_dir())) { + std::string libdevice_dir = + tensorflow::io::JoinPath(cuda_root, "nvvm", "libdevice"); + VLOG(2) << "Looking for libdevice at " << libdevice_dir; + if (tensorflow::Env::Default()->IsDirectory(libdevice_dir).ok()) { + VLOG(2) << "Found libdevice dir " << libdevice_dir; + return libdevice_dir; + } + } + return InternalError( + "Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice"); + } +}; + +} // namespace + +std::unique_ptr> CreateGpuKernelToBlobPass( + mlir::StringRef blob_annotation, ArrayRef architectures, + bool generate_fatbin) { + return std::make_unique(blob_annotation, architectures, + generate_fatbin); +} + +} // namespace transforms +} // namespace kernel_gen +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/materialize_broadcasts_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/materialize_broadcasts_pass.cc new file mode 100644 index 00000000000..dd3f32e2b3c --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/materialize_broadcasts_pass.cc @@ -0,0 +1,61 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" + +namespace mlir { +namespace kernel_gen { +namespace transforms { +namespace { + +#define GEN_PASS_CLASSES +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc" + +struct MaterializeBroadcastsPass + : public MaterializeBroadcastsPassBase { + void runOnFunction() override { + mlir::ConversionTarget conversionTarget(getContext()); + mlir::OwningRewritePatternList conversionPatterns; + + // Consider the mhlo dialect legal for tests. + conversionTarget.addLegalDialect(); + // The conversion uses helpers from the Standard dialect. 
+ conversionTarget.addLegalDialect(); + + mlir::mhlo::SetupMaterializeBroadcastsLegality(&getContext(), + &conversionTarget); + mlir::mhlo::PopulateMaterializeBroadcastsPatterns(&getContext(), + &conversionPatterns); + + if (failed(applyPartialConversion(getFunction(), conversionTarget, + conversionPatterns))) { + return signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr CreateMaterializeBroadcastsPass() { + return std::make_unique(); +} + +} // namespace transforms +} // namespace kernel_gen +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/parallel_loops_to_sequential.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/parallel_loops_to_sequential.cc new file mode 100644 index 00000000000..7981dbe5534 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/parallel_loops_to_sequential.cc @@ -0,0 +1,52 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" // from @llvm-project +#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" + +namespace mlir { +namespace kernel_gen { +namespace transforms { +namespace { + +#define GEN_PASS_CLASSES +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc" + +struct ParallelLoopsToSequentialPass + : public ParallelLoopsToSequentialBase { + void runOnFunction() override { + mlir::OwningRewritePatternList patterns; + mlir::populateLoopToStdConversionPatterns(patterns, &getContext()); + + mlir::ConversionTarget target(getContext()); + target.addIllegalOp(); + target.addLegalOp(); + target.markUnknownOpDynamicallyLegal([](mlir::Operation*) { return true; }); + if (failed(applyPartialConversion(getOperation(), target, patterns))) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr CreateParallelLoopsToSequential() { + return std::make_unique(); +} + +} // namespace transforms +} // namespace kernel_gen +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h index e65d8402fb2..5fd4091b2c0 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h @@ -18,6 +18,8 @@ limitations under the License. #include +#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project @@ -25,20 +27,19 @@ namespace mlir { namespace kernel_gen { namespace tf_framework { -// Test pass for applying TF Framework -> LLVM patterns. 
-std::unique_ptr > -createTestTFFrameworkLegalizeToLLVMPass(); - // Pass to replace some of the Standard ops with TF Framework ops. // * adds tf_framework::OpKernelContextType argument to the function // * std.alloc becomes tf_framework.alloc_raw // * std.dealloc becomes tf_framework.dealloc_raw -std::unique_ptr > createEmbedTFFrameworkPass(); +std::unique_ptr > CreateEmbedTFFrameworkPass(); } // namespace tf_framework namespace transforms { +// Pass for applying LLVM legalization patterns. +std::unique_ptr > CreateTFKernelToLLVMPass(); + // Pass to tranform shape computations in shape dialect to standard and scf // using memref descriptors. std::unique_ptr > CreateShapeToDescriptorsPass(); @@ -47,6 +48,25 @@ std::unique_ptr > CreateShapeToDescriptorsPass(); // buffers. std::unique_ptr > CreateBufferizePass(); +// Pass to materialize broadcasts. +std::unique_ptr CreateMaterializeBroadcastsPass(); + +// Pass to convert scf::ParallelOp to scf::ForOp. +std::unique_ptr CreateParallelLoopsToSequential(); + +// Pass to propagate TF ABI knowledge, e.g. offsets, alignment. +std::unique_ptr> +CreatePropagateTensorFlowABIKnowledgePass( + llvm::ArrayRef same_shape = {}); + +// Pass to annotate GPU Module with its PTX. +std::unique_ptr> CreateGpuKernelToBlobPass( + mlir::StringRef blob_annotation = "", + ArrayRef architectures = {}, bool generate_fatbin = true); + +// Pass to unfuse batch norm. +std::unique_ptr CreateUnfuseBatchNormPass(); + } // namespace transforms #define GEN_PASS_REGISTRATION diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td index 6a0e328f212..a8b2506bd1c 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td @@ -13,30 +13,67 @@ See the License for the specific language governing permissions and limitations under the License. 
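To make the relationship between the pass-creation entry points declared in passes.h above concrete, here is a rough sketch of composing some of them in a pass manager. The ordering and nesting choices are assumptions for illustration, not the pipeline the kernel generator actually builds.

#include "mlir/IR/Function.h"
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"

void BuildIllustrativePipeline(mlir::PassManager &pm) {
  namespace kg = mlir::kernel_gen;
  // Function-level cleanups run nested on each function.
  pm.addNestedPass<mlir::FuncOp>(
      kg::transforms::CreateMaterializeBroadcastsPass());
  pm.addNestedPass<mlir::FuncOp>(kg::transforms::CreateUnfuseBatchNormPass());
  // Module-level passes: embed the TF framework, lower shapes, bufferize.
  pm.addPass(kg::tf_framework::CreateEmbedTFFrameworkPass());
  pm.addPass(kg::transforms::CreateShapeToDescriptorsPass());
  pm.addPass(kg::transforms::CreateBufferizePass());
  pm.addNestedPass<mlir::FuncOp>(
      kg::transforms::CreateParallelLoopsToSequential());
}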
==============================================================================*/ -#ifndef TF_FRAMEWORK_PASSES -#define TF_FRAMEWORK_PASSES +#ifndef TF_KERNEL_GEN_PASSES +#define TF_KERNEL_GEN_PASSES include "mlir/Pass/PassBase.td" -def TestTFFrameworkLegalizeToLLVMPass - : Pass<"test-tf-framework-legalize-to-llvm", "ModuleOp"> { - let summary = "Test pass for applying TF Framework -> LLVM patterns."; - let constructor = "tf_framework::createTestTFFrameworkLegalizeToLLVMPass()"; +def TFKernelToLLVMPass : Pass<"tf-kernel-to-llvm", "ModuleOp"> { + let summary = "Pass for applying LLVM legalization patterns."; + let constructor = "transforms::CreateTFKernelToLLVMPass()"; } def EmbedTFFrameworkPass : Pass<"embed-tf-framework", "ModuleOp"> { let summary = "Pass to embed TF Framework for allocation and error reporting"; - let constructor = "tf_framework::createEmbedTFFrameworkPass()"; + let constructor = "tf_framework::CreateEmbedTFFrameworkPass()"; } -def ShapeToDescriptorsPass : Pass<"test-shape-to-descriptors", "ModuleOp"> { +def ShapeToDescriptorsPass : Pass<"shape-to-descriptors", "ModuleOp"> { let summary = "Pass to transform shape computations to descriptors"; let constructor = "transforms::CreateShapeToDescriptorsPass()"; } -def BufferizePass : Pass<"test-bufferize", "ModuleOp"> { +def BufferizePass : Pass<"bufferize", "ModuleOp"> { let summary = "Pass to transform operations on values to buffer based ones"; let constructor = "transforms::CreateBufferizePass()"; } -#endif // TF_FRAMEWORK_PASSES +def MaterializeBroadcastsPass : FunctionPass<"materialize-broadcast"> { + let summary = "Pass to materialize broadcasts"; + let constructor = "transforms::CreateMaterializeBroadcastsPass()"; +} + +def UnfuseBatchNormPass : FunctionPass<"unfuse-batch-norm"> { + let summary = "Pass to unfuse batch norm"; + let constructor = "transforms::CreateUnfuseBatchNormPass()"; +} + +def GpuKernelToBlobPass : Pass<"gpu-kernel-to-blob", "gpu::GPUModuleOp"> { + let summary = "Pass to annotate GPU Module with its PTX"; + let options = [ + Option<"blob_annotation_", "blob-annotation", "std::string", + /*default=*/"", "Blob attribute name">, + ListOption<"architectures_", "arch", "std::string", "GPU architectures">, + Option<"generate_fatbin_", "generate-fatbin", "bool", /*default=*/"true", + "Bundle machine code for the different architectures in one " + "fatbin.">, + ]; + let constructor = "transforms::CreateGpuKernelToBlobPass()"; +} + +def ParallelLoopsToSequential : FunctionPass<"parallel-loops-to-sequential"> { + let summary = "Pass to convert scf::ParallelOp to scf::ForOp"; + let constructor = "transforms::CreateParallelLoopsToSequential()"; +} + +def PropagateTensorFlowABIKnowledgePass + : Pass<"propagate-tf-abi-knowledge", "LLVM::LLVMFuncOp"> { + let summary = "Pass to propagate TF ABI knowledge, e.g. offsets, alignment"; + let options = [ + ListOption<"same_shape_", "same-shape", "uint32_t", + "List of same shape args">, + ]; + let constructor = "transforms::CreatePropagateTensorFlowABIKnowledgePass()"; +} + +#endif // TF_KERNEL_GEN_PASSES diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/propagate_tf_abi_knowledge_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/propagate_tf_abi_knowledge_pass.cc new file mode 100644 index 00000000000..3b568f5f25f --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/propagate_tf_abi_knowledge_pass.cc @@ -0,0 +1,137 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" + +namespace mlir { +namespace kernel_gen { +namespace transforms { +namespace { + +#define GEN_PASS_CLASSES +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc" + +struct PropagateTensorFlowABIKnowledgePass + : public PropagateTensorFlowABIKnowledgePassBase< + PropagateTensorFlowABIKnowledgePass> { + explicit PropagateTensorFlowABIKnowledgePass( + llvm::ArrayRef same_shape) { + same_shape_ = same_shape; + } + + void runOnOperation() override { + // We know due to tensorflow ABI that the offset is always 0 and that the + // innermost stride is always 1. To make this visible to the compiler, + // we insert constants into the code and replace usages accordingly. + // We do not change the signature so that we keep a somewhat stable ABI + // that is easy to undertand by tools. + // We also know that tensorflow aligns all allocated pointers by 16, so + // we pass this on. Furthermore, we know that arguments never alias. More + // precicely, they may only alias (due to reuse) if the kernel does not + // read from a position it previously has written to. We express this with + // the noalias attribute. + mlir::LLVM::LLVMFuncOp func = getOperation(); + + // This only works if the function is local and we can rewrite it. + if (func.isExternal()) return; + + auto function_list = + func.getParentOfType().getOps(); + if (function_list.empty()) { + func.emitError() << "No possible kernel function found"; + return signalPassFailure(); + } + auto func_iterator = function_list.begin(); + if (std::next(func_iterator) != function_list.end()) { + func.emitError() << "More than one possible kernel function detected"; + return signalPassFailure(); + } + // Note that this dereference is necessary to prevent a + // stack-use-after-return error. + auto func_type = (*func_iterator).getType(); + + mlir::OpBuilder b(func.getBody()); + // Steal the LLVM representation of the index type from the third argument. + auto index_type = func.getArgument(3).getType(); + mlir::Value one = b.create( + func.getLoc(), index_type, b.getIntegerAttr(b.getIndexType(), 1)); + mlir::Value zero = b.create( + func.getLoc(), index_type, b.getIntegerAttr(b.getIndexType(), 0)); + uint32_t arg_pos = 0; + std::vector positions; + // Collect the agument and return types of the surrounding function. 
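For orientation before the loop that follows: after lowering to the LLVM dialect, each ranked memref argument is expanded into 3 + 2 * rank scalar arguments (allocated pointer, aligned pointer, offset, then one size and one stride per dimension); that is the layout the arg_pos arithmetic below indexes into. A purely illustrative rank-2 view:

#include <cstdint>

// Rank-2 memref descriptor fields, in the order they appear as expanded
// kernel arguments after std-to-LLVM lowering.
struct UnpackedRank2MemRef {
  void *allocated;     // base (allocated) pointer
  void *aligned;       // aligned pointer; TF allocations are 16-byte aligned
  int64_t offset;      // always 0 under the TF ABI; replaced by a constant
  int64_t sizes[2];    // one size per dimension
  int64_t strides[2];  // innermost stride is always 1; replaced by a constant
};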
+ auto arg_types = llvm::to_vector<4>(llvm::concat( + func_type.getInputs(), func_type.getResults())); + for (mlir::Type arg_type : arg_types) { + if (!arg_type.isa()) { + func.emitError() << "argument of surrounding func is not ranked memref"; + return signalPassFailure(); + } + positions.push_back(arg_pos); + // Set alignment and aliasing on the pointers. + func.setArgAttr(arg_pos + 1, "llvm.noalias", b.getBoolAttr(true)); + func.setArgAttr(arg_pos + 1, "llvm.align", b.getIndexAttr(16)); + // Replace the offset with zero. Offset is argument number 3. + func.getArgument(arg_pos + 2).replaceAllUsesWith(zero); + // Forward over base_ptr, aligned_ptr, offset, size and stride arguments. + arg_pos += 3 + arg_type.cast().getRank() * 2; + // Replace the last stride with constant 1. + func.getArgument(arg_pos - 1).replaceAllUsesWith(one); + } + + // If we have knowledge that some arguments have the same shape, we + // can use that here. Simply replace usages of the shape parameters within + // the function body to a single shape parameter. + if (same_shape_.empty()) { + return; + } + auto first = same_shape_.front(); + auto first_offset = positions.at(first); + auto first_type = arg_types[first].cast(); + uint32_t rank = first_type.getRank(); + for (int i = 1, e = same_shape_.size(); i < e; ++i) { + uint32_t same = same_shape_[i]; + uint32_t same_offset = positions.at(same); + auto same_type = arg_types[same].cast(); + if (same_type.getRank() != rank) { + func.emitOpError() << "same shape constraints on arguments with " + "non-matching shapes: #" + << first << " and #" << same; + return signalPassFailure(); + } + + for (uint32_t i = 0; i < 2 * rank; ++i) { + // Replace uses for second arg data with first arg. + auto same_arg = func.getArgument(same_offset + 3 + i); + auto first_arg = func.getArgument(first_offset + 3 + i); + same_arg.replaceAllUsesWith(first_arg); + } + } + } +}; + +} // namespace + +std::unique_ptr> +CreatePropagateTensorFlowABIKnowledgePass(llvm::ArrayRef same_shape) { + return std::make_unique(same_shape); +} + +} // namespace transforms +} // namespace kernel_gen +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h index 4efc1e95bc8..f73a14b9be0 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h @@ -21,6 +21,7 @@ limitations under the License. namespace mlir { class BufferAssignmentPlacer; +class BufferizeTypeConverter; class LLVMTypeConverter; class MLIRContext; class OwningRewritePatternList; @@ -44,8 +45,7 @@ namespace transforms { /// Collects a set of patterns that bufferize operations from the standard /// dialect. void populateStandardBufferizePattern(MLIRContext *context, - BufferAssignmentPlacer *bufferAssignment, - TypeConverter *converter, + BufferizeTypeConverter *converter, OwningRewritePatternList *patterns); } // namespace transforms } // namespace kernel_gen diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc index 28d3647bb63..f5d01808c1b 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc @@ -16,7 +16,6 @@ limitations under the License. 
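A hedged sketch of instantiating the ABI-knowledge pass above with a same-shape constraint; the pass-manager nesting and the element type of the index list are assumptions. Declaring same_shape = {0, 1} tells the pass that kernel arguments #0 and #1 always agree in shape, so uses of #1's size and stride values are forwarded to #0's.

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"

void AddAbiKnowledge(mlir::PassManager &pm) {
  // Arguments #0 and #1 are declared shape-equal; the pass then rewires
  // uses of #1's sizes/strides to #0's.
  pm.addNestedPass<mlir::LLVM::LLVMFuncOp>(
      mlir::kernel_gen::transforms::CreatePropagateTensorFlowABIKnowledgePass(
          /*same_shape=*/{0, 1}));
}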
// This file combines patterns for lowering shape dialect to standard ops, // structured control flow and descriptors. -#include "mlir/Conversion/ShapeToSCF/ShapeToSCF.h" // from @llvm-project #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" // from @llvm-project #include "mlir/Dialect/SCF/SCF.h" // from @llvm-project #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project @@ -24,7 +23,6 @@ limitations under the License. #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" @@ -38,6 +36,10 @@ namespace { struct ShapeToDescriptorsPass : public ShapeToDescriptorsPassBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + public: void runOnOperation() override { MLIRContext &ctx = getContext(); @@ -47,12 +49,15 @@ struct ShapeToDescriptorsPass target.addIllegalDialect(); target.addLegalDialect(); target.addLegalDialect(); + // Don't mark the primary Cstr/Assuming ops as illegal, so they can be + // lowered at a later time to assertions. + target.addLegalOp(); // Setup conversion patterns. OwningRewritePatternList patterns; populateShapeRewritePatterns(&ctx, patterns); populateShapeToStandardConversionPatterns(patterns, &ctx); - populateShapeToSCFConversionPatterns(patterns, &ctx); // Apply conversion. auto module = getOperation(); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc index 3ce111ff3ff..431919c2de7 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc @@ -171,7 +171,8 @@ class DeallocRawOpConverter : public ConvertToLLVMCallOpPattern { protected: StringRef GetFuncName() const override { return kCInterfaceDealloc; } LLVMType GetFuncType() const override { - return LLVM::LLVMType::getFunctionTy(getVoidType(), getVoidPtrType(), + return LLVM::LLVMType::getFunctionTy(getVoidType(), + {getVoidPtrType(), getVoidPtrType()}, /*isVarArg=*/false); } }; diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_kernel_to_llvm_pass.cc similarity index 70% rename from tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm_pass.cc rename to tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_kernel_to_llvm_pass.cc index 42e89433dff..b2fcc424a50 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_kernel_to_llvm_pass.cc @@ -13,11 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" // from @llvm-project #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project +#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" @@ -25,14 +26,17 @@ limitations under the License. namespace mlir { namespace kernel_gen { -namespace tf_framework { +namespace transforms { namespace { #define GEN_PASS_CLASSES #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc" -class TestTFFrameworkToLLVMPass - : public TestTFFrameworkLegalizeToLLVMPassBase { +class TFKernelToLLVMPass : public TFKernelToLLVMPassBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + public: void runOnOperation() override { ModuleOp m = getOperation(); @@ -46,16 +50,19 @@ class TestTFFrameworkToLLVMPass // Populate patterns. OwningRewritePatternList patterns; populateStdToLLVMConversionPatterns(type_converter, patterns); - PopulateTFFrameworkToLLVMConversionPatterns(&type_converter, &patterns); + tf_framework::PopulateTFFrameworkToLLVMConversionPatterns(&type_converter, + &patterns); + populateGpuToLLVMConversionPatterns(type_converter, patterns, "gpu.binary"); lmhlo::PopulateLhloToLLVMConversionPatterns(&type_converter, &patterns); // Set target. ConversionTarget target(getContext()); target.addLegalDialect(); - target.addIllegalDialect(); - target.addLegalOp(); + target + .addIllegalDialect(); + target.addIllegalOp(); - if (failed(applyFullConversion(m, target, patterns))) { + if (failed(applyPartialConversion(m, target, patterns))) { signalPassFailure(); } } @@ -63,11 +70,10 @@ class TestTFFrameworkToLLVMPass } // namespace -std::unique_ptr > -createTestTFFrameworkLegalizeToLLVMPass() { - return std::make_unique(); +std::unique_ptr > CreateTFKernelToLLVMPass() { + return std::make_unique(); } -} // namespace tf_framework +} // namespace transforms } // namespace kernel_gen } // namespace mlir diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/unfuse_batch_norm_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/unfuse_batch_norm_pass.cc new file mode 100644 index 00000000000..d2773d91b07 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/unfuse_batch_norm_pass.cc @@ -0,0 +1,44 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
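The string "gpu.binary" passed to populateGpuToLLVMConversionPatterns above has to match the attribute name under which GpuKernelToBlobPass stored the compiled device code. A hedged sketch of wiring the two passes together; the pass-manager nesting and the chosen architecture are assumptions.

#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"

void AddGpuCodegenPasses(mlir::PassManager &pm) {
  namespace kgt = mlir::kernel_gen::transforms;
  // Annotate every gpu.module with its compiled blob under "gpu.binary"...
  pm.addNestedPass<mlir::gpu::GPUModuleOp>(kgt::CreateGpuKernelToBlobPass(
      /*blob_annotation=*/"gpu.binary", /*architectures=*/{"sm_70"},
      /*generate_fatbin=*/true));
  // ...which the kernel-to-LLVM lowering then picks up when it rewrites
  // gpu.launch_func into runtime calls.
  pm.addPass(kgt::CreateTFKernelToLLVMPass());
}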
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" + +namespace mlir { +namespace kernel_gen { +namespace transforms { +namespace { + +#define GEN_PASS_CLASSES +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc" + +struct UnfuseBatchNormPass + : public UnfuseBatchNormPassBase { + void runOnFunction() override { + mlir::OwningRewritePatternList patterns; + mlir::mhlo::PopulateUnfuseBatchNormPatterns(&getContext(), &patterns); + mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); + } +}; + +} // namespace + +std::unique_ptr CreateUnfuseBatchNormPass() { + return std::make_unique(); +} + +} // namespace transforms +} // namespace kernel_gen +} // namespace mlir diff --git a/tensorflow/compiler/mlir/utils/array_container_utils.h b/tensorflow/compiler/mlir/utils/array_container_utils.h new file mode 100644 index 00000000000..c1a898185d9 --- /dev/null +++ b/tensorflow/compiler/mlir/utils/array_container_utils.h @@ -0,0 +1,51 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_UTILS_ARRAY_CONTAINER_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_UTILS_ARRAY_CONTAINER_UTILS_H_ + +#include "absl/types/span.h" +#include "llvm/ADT/ArrayRef.h" + +namespace mlir { + +template +inline llvm::ArrayRef SpanToArrayRef(absl::Span span) { + return llvm::ArrayRef(span.data(), span.size()); +} + +template +inline llvm::ArrayRef SpanToArrayRef(absl::Span span) { + return llvm::ArrayRef(span.data(), span.size()); +} + +template +inline llvm::MutableArrayRef SpanToMutableArrayRef(absl::Span span) { + return llvm::MutableArrayRef(span.data(), span.size()); +} + +template +inline absl::Span ArrayRefToSpan(llvm::ArrayRef ref) { + return absl::Span(ref.data(), ref.size()); +} + +template +inline absl::Span MutableArrayRefToSpan(llvm::MutableArrayRef ref) { + return absl::Span(ref.data(), ref.size()); +} + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_UTILS_ARRAY_CONTAINER_UTILS_H_ diff --git a/tensorflow/compiler/mlir/utils/name_utils.cc b/tensorflow/compiler/mlir/utils/name_utils.cc new file mode 100644 index 00000000000..bc4e80f5aa1 --- /dev/null +++ b/tensorflow/compiler/mlir/utils/name_utils.cc @@ -0,0 +1,99 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
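A small usage sketch for the container bridges defined in array_container_utils.h above; the element type and the surrounding function are illustrative.

#include <cstdint>
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/compiler/mlir/utils/array_container_utils.h"

// Bridges absl::Span-based XLA/TF code and llvm::ArrayRef-based MLIR code
// without copying the underlying buffer.
void RoundTrip(const std::vector<int64_t> &dims) {
  absl::Span<const int64_t> span = absl::MakeConstSpan(dims);
  llvm::ArrayRef<int64_t> ref = mlir::SpanToArrayRef(span);
  absl::Span<const int64_t> back = mlir::ArrayRefToSpan(ref);
  (void)back;  // Both views alias `dims`; no ownership is transferred.
}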
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/utils/name_utils.h" + +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "mlir/IR/Identifier.h" // from @llvm-project + +namespace mlir { + +namespace { +// Checks if a character is legal for a TensorFlow node name, with special +// handling if a character is at the beginning. +bool IsLegalChar(char c, bool first_char) { + if (isalpha(c)) return true; + if (isdigit(c)) return true; + if (c == '.') return true; + if (c == '_') return true; + + // First character of a node name can only be a letter, digit, dot or + // underscore. + if (first_char) return false; + + if (c == '/') return true; + if (c == '-') return true; + + return false; +} +} // anonymous namespace + +void LegalizeNodeName(std::string& name) { + if (name.empty()) return; + + if (!IsLegalChar(name[0], /*first_char=*/true)) name[0] = '.'; + + for (char& c : llvm::drop_begin(name, 1)) + if (!IsLegalChar(c, /*first_char=*/false)) c = '.'; +} + +std::string GetNameFromLoc(Location loc) { + llvm::SmallVector loc_names; + llvm::SmallVector locs; + locs.push_back(loc); + bool names_is_nonempty = false; + + while (!locs.empty()) { + Location curr_loc = locs.pop_back_val(); + + if (auto name_loc = curr_loc.dyn_cast()) { + // Add name in NameLoc. For NameLoc we also account for names due to ops + // in functions where the op's name is first. + auto name = name_loc.getName().strref().split('@').first; + loc_names.push_back(name); + if (!name.empty()) names_is_nonempty = true; + continue; + } else if (auto call_loc = curr_loc.dyn_cast()) { + // Add name if CallSiteLoc's callee has a NameLoc (as should be the + // case if imported with DebugInfo). + if (auto name_loc = call_loc.getCallee().dyn_cast()) { + auto name = name_loc.getName().strref().split('@').first; + loc_names.push_back(name); + if (!name.empty()) names_is_nonempty = true; + continue; + } + } else if (auto fused_loc = curr_loc.dyn_cast()) { + // Push all locations in FusedLoc in reverse order, so locations are + // visited based on order in FusedLoc. + auto reversed_fused_locs = llvm::reverse(fused_loc.getLocations()); + locs.append(reversed_fused_locs.begin(), reversed_fused_locs.end()); + continue; + } + + // Location is not a supported, so an empty StringRef is added. + loc_names.push_back(llvm::StringRef()); + } + + if (names_is_nonempty) + return llvm::join(loc_names.begin(), loc_names.end(), ";"); + + return ""; +} + +} // namespace mlir diff --git a/tensorflow/compiler/mlir/utils/name_utils.h b/tensorflow/compiler/mlir/utils/name_utils.h new file mode 100644 index 00000000000..4b08a41feec --- /dev/null +++ b/tensorflow/compiler/mlir/utils/name_utils.h @@ -0,0 +1,35 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
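A short usage sketch for the two helpers defined in name_utils.cc above and declared in the header that follows; the wrapper function and its fallback name are assumptions.

#include <string>
#include "mlir/IR/Operation.h"
#include "tensorflow/compiler/mlir/utils/name_utils.h"

// Derives a TensorFlow-legal node name from an op's location metadata.
std::string NodeNameForOp(mlir::Operation *op) {
  std::string name = mlir::GetNameFromLoc(op->getLoc());
  if (name.empty()) name = "unknown";  // assumed fallback
  mlir::LegalizeNodeName(name);        // illegal characters become '.'
  return name;
}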
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_UTILS_NAME_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_UTILS_NAME_UTILS_H_ + +#include + +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Location.h" // from @llvm-project + +namespace mlir { + +// Converts characters in name that are considered illegal in TensorFlow Node +// name to '.'. +void LegalizeNodeName(std::string& name); + +// Creates a TensorFlow node name from a location. +std::string GetNameFromLoc(Location loc); + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_UTILS_NAME_UTILS_H_ diff --git a/tensorflow/compiler/mlir/utils/string_container_utils.h b/tensorflow/compiler/mlir/utils/string_container_utils.h new file mode 100644 index 00000000000..fb2fa06ca4d --- /dev/null +++ b/tensorflow/compiler/mlir/utils/string_container_utils.h @@ -0,0 +1,34 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_UTILS_STRING_CONTAINER_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_UTILS_STRING_CONTAINER_UTILS_H_ + +#include "absl/strings/string_view.h" +#include "llvm/ADT/StringRef.h" + +namespace mlir { + +inline absl::string_view StringRefToView(llvm::StringRef ref) { + return absl::string_view(ref.data(), ref.size()); +} + +inline llvm::StringRef StringViewToRef(absl::string_view view) { + return llvm::StringRef(view.data(), view.size()); +} + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_UTILS_STRING_CONTAINER_UTILS_H_ diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 32a2ed1c272..1919446a365 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -1,5 +1,7 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//third_party/mlir:tblgen.bzl", "gentbl") load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_native_cc_binary") +load("//tensorflow:tensorflow.bzl", "get_compatible_with_cloud") package( default_visibility = [":friends"], @@ -15,6 +17,7 @@ package_group( "//learning/brain/experimental/mlir/...", "//learning/brain/google/xla/kernels/...", "//learning/brain/google/xla/mlir/...", + "//learning/deepmind/partir/...", "//learning/pathways/data_parallel/tf2xla/...", "//platforms/xla/...", "//tensorflow/compiler/mlir/...", @@ -27,6 +30,7 @@ package_group( gentbl( name = "xla_legalize_tf_inc_gen", + compatible_with = get_compatible_with_cloud(), tbl_outs = [ ("-gen-rewriters", "transforms/generated_legalize_tf.inc"), ], @@ -132,6 +136,7 @@ cc_library( ":hlo_module_importer", ":hlo_utils", ":mlir_hlo_to_hlo", + ":translate_cl_options", "//tensorflow/compiler/mlir/hlo", "//tensorflow/compiler/mlir/hlo:lhlo", "//tensorflow/compiler/xla:debug_options_flags", @@ -237,8 +242,8 @@ cc_library( hdrs = ["mlir_hlo_to_hlo.h"], deps = [ ":type_to_shape", + "//tensorflow/compiler/mlir:name_utils", 
"//tensorflow/compiler/mlir/hlo", - "//tensorflow/compiler/mlir/hlo:hlo_dialect_force_registration", "//tensorflow/compiler/mlir/tensorflow:convert_type", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/tf2xla:common", @@ -323,6 +328,16 @@ cc_library( ], ) +cc_library( + name = "translate_cl_options", + srcs = ["xla_mlir_translate_cl.cc"], + hdrs = ["xla_mlir_translate_cl.h"], + deps = [ + "@llvm-project//llvm:Support", + ], + alwayslink = 1, +) + cc_library( name = "xla_mlir_translate", srcs = ["xla_mlir_translate.cc"], @@ -331,6 +346,7 @@ cc_library( ":hlo_to_mlir_hlo", ":mhlo_to_lhlo_with_xla", ":mlir_hlo_to_hlo", + ":translate_cl_options", "//tensorflow/compiler/jit:xla_cpu_jit", "//tensorflow/compiler/jit:xla_gpu_jit", "//tensorflow/compiler/mlir/hlo", @@ -361,6 +377,7 @@ tf_native_cc_binary( gentbl( name = "operator_writer_inc", + compatible_with = get_compatible_with_cloud(), tbl_outs = [("", "operator_writers.inc")], tblgen = ":operator_writer_gen", td_file = "//tensorflow/compiler/mlir/hlo:include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td", @@ -389,14 +406,12 @@ cc_library( ":xla_legalize_tf_with_tf2xla", "//tensorflow/compiler/mlir/hlo", "//tensorflow/compiler/mlir/hlo:chlo_legalize_to_hlo", - "//tensorflow/compiler/mlir/hlo:hlo_dialect_force_registration", "//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo", "//tensorflow/compiler/mlir/hlo:legalize_control_flow", - "//tensorflow/compiler/mlir/hlo:legalize_tanh_to_approximation", "//tensorflow/compiler/mlir/hlo:legalize_to_linalg", "//tensorflow/compiler/mlir/hlo:legalize_to_standard", + "//tensorflow/compiler/mlir/hlo:legalize_trigonometric_to_approximation", "//tensorflow/compiler/mlir/hlo:lhlo", - "//tensorflow/compiler/mlir/hlo:lhlo_copy_removal", "//tensorflow/compiler/mlir/hlo:lhlo_fuse_linalg", "//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_affine", "//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_gpu", diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc index a63fc12c285..253156b44a5 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc @@ -46,13 +46,11 @@ limitations under the License. using llvm::APInt; using llvm::makeArrayRef; -using mlir::DenseElementsAttr; using mlir::DenseIntElementsAttr; using mlir::FuncOp; using mlir::NamedAttribute; using mlir::Operation; using mlir::RankedTensorType; -using mlir::ShapedType; using mlir::Type; using mlir::Value; @@ -142,31 +140,42 @@ tensorflow::Status HloFunctionImporter::ImportAsRegion( return ImportInstructions(computation, block); } -tensorflow::Status HloFunctionImporter::ImportInstructions( - const HloComputation& computation, mlir::Block* block) { +StatusOr HloFunctionImporter::ImportInstructionsImpl( + const xla::HloComputation& computation, + const llvm::SmallVectorImpl& arguments, mlir::OpBuilder* builder) { // Setup the input parameters. 
const int num_parameters = computation.num_parameters(); + + if (arguments.size() != num_parameters) + return InvalidArgument("Caller vs callee argument sizes do not match"); + for (int i = 0; i < num_parameters; i++) { auto hlo_parameter = computation.parameter_instruction(i); - instruction_value_map_[hlo_parameter] = block->getArgument(i); + instruction_value_map_[hlo_parameter] = arguments[i]; } - mlir::OpBuilder builder = mlir::OpBuilder::atBlockEnd(block); for (auto instruction : computation.MakeInstructionPostOrder()) { TF_ASSIGN_OR_RETURN(auto new_operation, - ImportInstruction(instruction, &builder)); + ImportInstruction(instruction, builder)); if (new_operation) { instruction_value_map_[instruction] = new_operation->getResult(0); } } + // Setup the return type (HLO only supports a single return value). + return GetMlirValue(computation.root_instruction()); +} + +Status HloFunctionImporter::ImportInstructions( + const HloComputation& computation, mlir::Block* block) { + llvm::SmallVector arguments(block->args_begin(), block->args_end()); + mlir::OpBuilder builder = mlir::OpBuilder::atBlockEnd(block); + TF_ASSIGN_OR_RETURN(Value result, + ImportInstructionsImpl(computation, arguments, &builder)); + // TODO(suderman): Add location tracking details. mlir::Location loc = builder.getUnknownLoc(); - // Setup the return type (HLO only supports a single return value). - TF_ASSIGN_OR_RETURN(auto result, - GetMlirValue(computation.root_instruction())); - // Create terminator op depending on the parent op of this region. if (llvm::isa(block->getParentOp())) { builder.create(loc, result); @@ -176,15 +185,29 @@ tensorflow::Status HloFunctionImporter::ImportInstructions( return tensorflow::Status::OK(); } -StatusOr HloFunctionImporter::ImportInstruction( +StatusOr HloFunctionImporter::ImportInstructions( + const xla::HloComputation& computation, + const llvm::SmallVectorImpl& arguments, mlir::OpBuilder* builder) { + mlir::Block* block = builder->getBlock(); + if (block == nullptr) + return InvalidArgument( + "ImportInstructions requires a valid block in the builder"); + + HloFunctionImporter importer( + block->getParent()->getParentOfType(), {}, builder); + return importer.ImportInstructionsImpl(computation, arguments, builder); +} + +StatusOr HloFunctionImporter::ImportInstructionImpl( HloInstruction* instruction, mlir::OpBuilder* func_builder) { TF_ASSIGN_OR_RETURN(auto operands, GetOperands(instruction)); TF_ASSIGN_OR_RETURN(auto result_type, ConvertShapeToType( instruction->shape(), *builder_)); - llvm::SmallVector attributes = {builder_->getNamedAttr( - "name", builder_->getStringAttr(instruction->name()))}; - mlir::Location loc = func_builder->getUnknownLoc(); + mlir::Location loc = + mlir::NameLoc::get(func_builder->getIdentifier(instruction->name()), + func_builder->getContext()); + llvm::SmallVector attributes; switch (instruction->opcode()) { case HloOpcode::kParameter: { return nullptr; @@ -216,8 +239,8 @@ StatusOr HloFunctionImporter::ImportInstruction( return new_operation; \ } case HloOpcode::kBroadcast: { - // Note that the HLO broadcast is more powerful than the XLA broadcast op. - // BroadcastInDim offers a superset of the HLO op's functionality. + // Note that the HLO broadcast is more powerful than the XLA broadcast + // op. BroadcastInDim offers a superset of the HLO op's functionality. 
attributes.push_back( builder_->getNamedAttr("broadcast_dimensions", ConvertDimensions(instruction->dimensions()))); @@ -419,13 +442,27 @@ StatusOr HloFunctionImporter::ImportInstruction( } case HloOpcode::kSort: { auto sort_instruction = Cast(instruction); + + llvm::SmallVector return_types = {result_type}; + if (mlir::TupleType tuple_ty = result_type.dyn_cast()) { + return_types = llvm::to_vector<6>(tuple_ty.getTypes()); + } + auto sort_op = func_builder->create( - loc, result_type, operands, + loc, return_types, operands, builder_->getI64IntegerAttr(sort_instruction->sort_dimension()), builder_->getBoolAttr(sort_instruction->is_stable())); TF_RETURN_IF_ERROR( ImportAsRegion(*sort_instruction->to_apply(), &sort_op.comparator())); - return sort_op.getOperation(); + + // Check if the output needs to be tupled. + if (return_types.size() == 1 && return_types.front() == result_type) { + return sort_op.getOperation(); + } + + return func_builder + ->create(loc, result_type, sort_op.getResults()) + .getOperation(); } case HloOpcode::kConditional: { llvm::SmallVector rets; @@ -446,7 +483,8 @@ StatusOr HloFunctionImporter::ImportInstruction( return op.getOperation(); } - // Otherwise, it is a indexed conditional and should be mapped to Case op. + // Otherwise, it is a indexed conditional and should be mapped to Case + // op. TF_RETURN_IF_ERROR(GetMlirTypes( {instruction->branch_computation(0)->root_instruction()}, &rets)); @@ -462,8 +500,8 @@ StatusOr HloFunctionImporter::ImportInstruction( return op.getOperation(); } case HloOpcode::kConcatenate: { - // TODO(b/132057942): Support taking an uint64_t instead of an IntegerAttr - // for concatenate dimension. + // TODO(b/132057942): Support taking an uint64_t instead of an + // IntegerAttr for concatenate dimension. return func_builder ->create( loc, result_type, operands, @@ -667,6 +705,7 @@ StatusOr HloFunctionImporter::ImportInstruction( NoAttributeCase(kAnd, AndOp); NoAttributeCase(kAtan2, Atan2Op); NoAttributeCase(kBitcastConvert, BitcastConvertOp); + NoAttributeCase(kCbrt, CbrtOp); NoAttributeCase(kConvert, ConvertOp); NoAttributeCase(kCeil, CeilOp); NoAttributeCase(kClamp, ClampOp); @@ -691,9 +730,9 @@ StatusOr HloFunctionImporter::ImportInstruction( NoAttributeCase(kReal, RealOp); NoAttributeCase(kRemainder, RemOp); NoAttributeCase(kReplicaId, ReplicaIdOp); - // The dimensions attribute is not present on the HLO Reshape instruction. - // If dimensions are non-default, the XLA builder implements it as a - // separate transpose. + // The dimensions attribute is not present on the HLO Reshape + // instruction. If dimensions are non-default, the XLA builder + // implements it as a separate transpose. NoAttributeCase(kReshape, ReshapeOp); NoAttributeCase(kRoundNearestAfz, RoundOp); NoAttributeCase(kRsqrt, RsqrtOp); @@ -708,9 +747,9 @@ StatusOr HloFunctionImporter::ImportInstruction( NoAttributeCase(kTanh, TanhOp); NoAttributeCase(kTuple, TupleOp); NoAttributeCase(kXor, XorOp); - // TODO(b/129422361) Copy needs special handling because it is not defined - // in tensorflow/compiler/xla/client/xla_builder.h. - // See operation semantics in + // TODO(b/129422361) Copy needs special handling because it is not + // defined in tensorflow/compiler/xla/client/xla_builder.h. 
See + // operation semantics in // g3doc/platforms/xla/g3doc/internal/hlo_semantics#copy NoAttributeCase(kCopy, CopyOp); #undef NoAttributeCase @@ -724,6 +763,20 @@ StatusOr HloFunctionImporter::ImportInstruction( &fusion.fused_computation())); return fusion.getOperation(); } + case HloOpcode::kBitcast: + return func_builder + ->create(loc, result_type, operands, + attributes) + .getOperation(); + case HloOpcode::kReducePrecision: { + auto op = func_builder->create( + loc, result_type, operands[0], attributes); + op.exponent_bitsAttr(func_builder->getIntegerAttr( + func_builder->getI32Type(), instruction->exponent_bits())); + op.mantissa_bitsAttr(func_builder->getIntegerAttr( + func_builder->getI32Type(), instruction->mantissa_bits())); + return op.getOperation(); + } case HloOpcode::kAddDependency: // Arbitrary op code that I suspect we will not implement for quite a // while and allows testing handling of unknown ops. Selected because it @@ -742,6 +795,28 @@ StatusOr HloFunctionImporter::ImportInstruction( } } +StatusOr HloFunctionImporter::ImportInstruction( + HloInstruction* instruction, mlir::OpBuilder* func_builder) { + TF_ASSIGN_OR_RETURN(mlir::Operation * op, + ImportInstructionImpl(instruction, func_builder)); + if (op == nullptr) return op; + + // See MlirToHloConversionOptions for more about layouts. + // + // Minor-to-major is a permutation of [0, rank), presenting tensor dimensions + // in physical minor-to-major order. + if (instruction->shape().IsArray() && + instruction->shape().layout() != + LayoutUtil::MakeDescendingLayout( + instruction->shape().dimensions().size())) { + llvm::SmallVector minor_to_major( + instruction->shape().layout().minor_to_major().begin(), + instruction->shape().layout().minor_to_major().end()); + op->setAttr("minor_to_major", builder_->getIndexTensorAttr(minor_to_major)); + } + return op; +} + StatusOr> HloFunctionImporter::GetOperands( HloInstruction* instruction) { llvm::SmallVector operands; diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.h b/tensorflow/compiler/mlir/xla/hlo_function_importer.h index e0cc89004cf..4a75b079d76 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.h +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.h @@ -55,6 +55,13 @@ class HloFunctionImporter { static Status ImportAsRegion(const xla::HloComputation& computation, mlir::Region* region, mlir::Builder* builder); + // Imports the given computation to the given place specified by `builder`. + // `arguments` contains values for all parameters. + static StatusOr ImportInstructions( + const xla::HloComputation& computation, + const llvm::SmallVectorImpl& arguments, + mlir::OpBuilder* builder); + private: HloFunctionImporter(mlir::ModuleOp module, std::unordered_map ImportInstructionsImpl( + const xla::HloComputation& computation, + const llvm::SmallVectorImpl& arguments, + mlir::OpBuilder* builder); // Imports an instruction. StatusOr ImportInstruction(xla::HloInstruction* instruction, mlir::OpBuilder* func_builder); + StatusOr ImportInstructionImpl( + HloInstruction* instruction, mlir::OpBuilder* func_builder); // Gets the MLIR operand values from an HLO Instruction. 
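Note on the new static HloFunctionImporter::ImportInstructions overload declared above: it imports a computation's instructions directly at the builder's current insertion point, substituting caller-supplied MLIR values for the computation's parameters instead of materializing a separate region, and it fails with InvalidArgument when the argument count does not match the parameter count. A minimal caller-side sketch, assuming a caller that already holds an xla::HloComputation (`reducer`) and two MLIR values (`lhs`, `rhs`) — these names are illustrative and not part of this change:

  // Inline `reducer` at the builder's insertion point, mapping its two
  // parameters to `lhs` and `rhs`; the returned value is the imported root.
  llvm::SmallVector<mlir::Value, 4> arguments = {lhs, rhs};
  TF_ASSIGN_OR_RETURN(
      mlir::Value root,
      xla::HloFunctionImporter::ImportInstructions(reducer, arguments,
                                                   &builder));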
StatusOr> GetOperands( diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc index ac5e01a0abf..daea2d9b8f6 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc @@ -135,12 +135,16 @@ StatusOr MlirHloBuilder::CustomCallInternal( const string& call_target_name, absl::Span operands, const Shape& shape, const string& opaque, absl::optional> operand_shapes_with_layout, - bool has_side_effect) { + bool has_side_effect, + absl::Span>> + output_operand_aliasing) { if (operand_shapes_with_layout.has_value()) return Unimplemented( "CustomCall doesn't support operands shapes with layout"); TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( shape, builder_)); + TF_RET_CHECK(output_operand_aliasing.empty()) + << "MLIR CustomCallOp does not support output_operand_aliasing yet"; auto op = builder_.create( loc_, ty, GetValues(operands), builder_.getStringAttr(call_target_name), /*has_side_effect=*/builder_.getBoolAttr(has_side_effect), @@ -239,11 +243,22 @@ StatusOr MlirHloBuilder::SortInternal(const Shape& shape, int64 dimension, bool is_stable) { TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( shape, builder_)); + llvm::SmallVector sort_types = {ty}; + if (auto tuple_ty = ty.dyn_cast()) { + sort_types = llvm::to_vector<6>(tuple_ty.getTypes()); + } + auto op = builder_.create( - loc_, ty, GetValues(operands), builder_.getI64IntegerAttr(dimension), - builder_.getBoolAttr(is_stable)); + loc_, sort_types, GetValues(operands), + builder_.getI64IntegerAttr(dimension), builder_.getBoolAttr(is_stable)); TF_RETURN_IF_ERROR(ImportComputation(comparator.proto(), &op.comparator())); - return MakeXlaOp(op); + + if (ty.isa()) { + auto tuple = builder_.create(loc_, op.getResults()); + return MakeXlaOp(tuple); + } + + return MakeXlaOp(op.getResult(0)); } StatusOr MlirHloBuilder::WhileInternal(const Shape& shape, diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h index 00b7aa4d0b0..59b4bc7b1e0 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h @@ -135,7 +135,9 @@ class MlirHloBuilder : public XlaBuilder { const string& call_target_name, absl::Span operands, const Shape& shape, const string& opaque, absl::optional> operand_shapes_with_layout, - bool has_side_effect) override; + bool has_side_effect, + absl::Span>> + output_operand_aliasing) override; StatusOr ReduceInternal( const Shape& shape, absl::Span all_operands, diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index 5398cd70777..ccfcebab60e 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -41,6 +41,7 @@ limitations under the License. #include "mlir/IR/UseDefLists.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" +#include "tensorflow/compiler/mlir/utils/name_utils.h" #include "tensorflow/compiler/mlir/xla/type_to_shape.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" @@ -105,6 +106,9 @@ static mlir::LogicalResult GetXlaOp( // TODO(hpucha): This should be consolidated into a general place. 
static int ConvertAPInt(llvm::APInt i) { return i.getSExtValue(); } +static uint32_t Convertuint32_t(uint32_t i) { return i; } +static uint64_t Convertuint64_t(uint64_t i) { return i; } + // Convert APFloat to double. static double ConvertAPFloat(llvm::APFloat value) { const auto& semantics = value.getSemantics(); @@ -430,6 +434,27 @@ static xla::FrontendAttributes CreateOpFrontendAttributesFromAttribute( return frontend_attributes; } +// Returns a OpMetadata proto based on the location of the op. If the location +// is unknown, an empty proto is returned. `op_name` are populated with the op +// location (converted). FileLineColLoc locations are populated by taking the +// file name and line number, and populating `source_file` and `source_line` +// respectively. +static xla::OpMetadata CreateOpMetadataFromLocation(mlir::Operation* op) { + xla::OpMetadata metadata; + if (op->getLoc().isa()) return metadata; + + std::string name = mlir::GetNameFromLoc(op->getLoc()); + mlir::LegalizeNodeName(name); + metadata.set_op_name(name); + + if (auto file_line_col_loc = op->getLoc().dyn_cast()) { + metadata.set_source_file(file_line_col_loc.getFilename().str()); + metadata.set_source_line(file_line_col_loc.getLine()); + } + + return metadata; +} + // Checks if all shardings are set. static bool AllOptionalShardingsAreSet( llvm::ArrayRef> shardings) { @@ -474,12 +499,14 @@ class ConvertToHloModule { // single value. explicit ConvertToHloModule( mlir::ModuleOp module, bool use_tuple_args, bool return_tuple, - tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn) + tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn, + MlirToHloConversionOptions options) : module_(module), module_builder_("main"), use_tuple_args_(use_tuple_args), return_tuple_(return_tuple), - shape_representation_fn_(shape_representation_fn) { + shape_representation_fn_(shape_representation_fn), + options_(options) { if (!shape_representation_fn_) shape_representation_fn_ = tensorflow::IdentityShapeRepresentationFn(); } @@ -560,6 +587,8 @@ class ConvertToHloModule { // Unique suffix to give to the name of the next lowered region. size_t region_id_ = 0; + + MlirToHloConversionOptions options_; }; } // namespace @@ -761,7 +790,7 @@ LogicalResult ExportXlaOp(InfeedOp op, OpLoweringContext ctx) { LogicalResult ExportXlaOp(IotaOp op, OpLoweringContext ctx) { auto& value_map = *ctx.values; value_map[op] = xla::Iota(ctx.builder, xla::TypeToShape(op.getType()), - op.iota_dimension().getSExtValue()); + op.iota_dimension()); return success(); } @@ -887,8 +916,8 @@ LogicalResult ExportXlaOp(RngBitGeneratorOp op, OpLoweringContext ctx) { auto result = op.getResult(); auto xla_arg_1 = value_map[*op.getODSOperands(0).begin()]; auto xla_result = xla::RngBitGenerator( - static_cast(op.rng_algorithm().getSExtValue()), - Unwrap(xla_arg_1), xla::TypeToShape(result.getType()).tuple_shapes(1)); + static_cast(op.rng_algorithm()), Unwrap(xla_arg_1), + xla::TypeToShape(result.getType()).tuple_shapes(1)); value_map[result] = xla_result; return mlir::success(); } @@ -983,9 +1012,14 @@ LogicalResult ExportXlaOp(SortOp op, OpLoweringContext ctx) { &comparator))) return failure(); + auto tupled = xla::Sort(GetTuple(op.operands(), ctx), comparator, + op.dimension(), op.is_stable()); + auto& value_map = *ctx.values; - value_map[op] = xla::Sort(GetTuple(op.operands(), ctx), comparator, - op.dimension().getSExtValue(), op.is_stable()); + // MLIR's sort supports multiple returns, untuple all the results of XLA's. 
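  // Illustrative note, not part of this change: when mhlo.sort has several
  // operands, xla::Sort yields a single tuple-shaped value, while the MLIR op
  // now returns each sorted operand as a separate result. The loop below
  // bridges the two by mapping result i of the MLIR op to element i of the
  // XLA tuple. Hand-written for a hypothetical two-operand (keys/values)
  // sort, the same mapping would read:
  //   xla::XlaOp sorted_keys = xla::GetTupleElement(tupled, 0);
  //   xla::XlaOp sorted_values = xla::GetTupleElement(tupled, 1);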
+ for (auto it : llvm::enumerate(op.getResults())) { + value_map[it.value()] = xla::GetTupleElement(tupled, it.index()); + } return success(); } @@ -1034,7 +1068,7 @@ LogicalResult ExportXlaOp(FusionOp op, OpLoweringContext ctx) { llvm::SmallVector operands; for (auto operand : op.operands()) operands.push_back(values[operand]); - xla::XlaOp fusion = xla::internal::XlaBuilderBuildFusion( + xla::XlaOp fusion = xla::internal::XlaBuilderFriend::BuildFusion( ctx.builder, operands, absl::string_view(op.fusion_kind()->data(), op.fusion_kind()->size()), fused_computation); @@ -1048,6 +1082,15 @@ LogicalResult ExportXlaOp(FusionOp op, OpLoweringContext ctx) { return success(); } +LogicalResult ExportXlaOp(BitcastOp op, OpLoweringContext ctx) { + auto& value_map = *ctx.values; + xla::XlaOp operand; + if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure(); + value_map[op] = xla::internal::XlaBuilderFriend::BuildBitcast( + ctx.builder, operand, xla::TypeToShape(op.getType())); + return success(); +} + } // namespace } // namespace mhlo } // namespace mlir @@ -1057,18 +1100,19 @@ LogicalResult ExportXlaOp(FusionOp op, OpLoweringContext ctx) { namespace mlir { namespace { -StatusOr CreateLiteralFromAttr(ElementsAttr attr) { +StatusOr CreateArrayLiteralFromAttr(ElementsAttr attr, + xla::Layout layout) { if (attr.isa()) return tensorflow::errors::Unimplemented( "Opaque elements attr not supported"); xla::Shape shape = xla::TypeToShape(attr.getType()); -#define ELEMENTS_ATTR_TO_LITERAL(xla_type, cpp_type) \ - case xla_type: { \ - xla::Array source_data(shape.dimensions()); \ - source_data.SetValues(attr.getValues()); \ - return xla::LiteralUtil::CreateFromArray(source_data); \ +#define ELEMENTS_ATTR_TO_LITERAL(xla_type, cpp_type) \ + case xla_type: { \ + xla::Array source_data(shape.dimensions()); \ + source_data.SetValues(attr.getValues()); \ + return xla::LiteralUtil::CreateFromArrayWithLayout(source_data, layout); \ } switch (shape.element_type()) { @@ -1098,7 +1142,7 @@ StatusOr CreateLiteralFromAttr(ElementsAttr attr) { } xla::Array source_data(shape.dimensions()); source_data.SetValues(values); - return xla::LiteralUtil::CreateFromArray(source_data); + return xla::LiteralUtil::CreateFromArrayWithLayout(source_data, layout); } case xla::PrimitiveType::BF16: { xla::Array source_data(shape.dimensions()); @@ -1115,7 +1159,7 @@ StatusOr CreateLiteralFromAttr(ElementsAttr attr) { } source_data.SetValues(values_double); return xla::LiteralUtil::ConvertF64ToBF16( - xla::LiteralUtil::CreateFromArray(source_data)); + xla::LiteralUtil::CreateFromArrayWithLayout(source_data, layout)); } default: return tensorflow::errors::Internal(absl::StrCat( @@ -1124,13 +1168,46 @@ StatusOr CreateLiteralFromAttr(ElementsAttr attr) { #undef ELEMENTS_ATTR_TO_LITERAL } +xla::Layout ExtractLayout(mlir::Operation* op, int rank) { + if (auto attr = + op->getAttrOfType("minor_to_major")) { + llvm::SmallVector minor_to_major; + minor_to_major.reserve(attr.size()); + for (const llvm::APInt& i : attr) { + minor_to_major.push_back(i.getZExtValue()); + } + return xla::LayoutUtil::MakeLayout(minor_to_major); + } + return xla::LayoutUtil::MakeDescendingLayout(rank); +} + LogicalResult ConvertToHloModule::Lower( mlir::Operation* inst, bool is_entry_function, llvm::ArrayRef> ret_shardings, xla::XlaBuilder* builder, ConvertToHloModule::ValueLoweringMap* value_lowering, xla::XlaComputation* result) { + // See MlirToHloConversionOptions for more about layouts. 
+ auto propagate_layouts = [this](mlir::Operation* inst, xla::XlaOp xla_op) { + if (options_.propagate_layouts) { + auto* shape = xla::internal::XlaBuilderFriend::GetInstruction(xla_op) + ->mutable_shape(); + if (shape->tuple_shapes().empty()) + *shape->mutable_layout() = + ExtractLayout(inst, shape->dimensions().size()).ToProto(); + } + }; + if (succeeded(ExportXlaOperator(inst, {value_lowering, this, builder}))) { + if (inst->getNumResults() == 1) { + auto iter = value_lowering->find(inst->getResult(0)); + if (iter == value_lowering->end()) { + inst->emitOpError( + "inst has a result, but it's not found in value_lowering"); + return failure(); + } + propagate_layouts(inst, iter->second); + } return success(); } @@ -1156,16 +1233,19 @@ LogicalResult ConvertToHloModule::Lower( if (failed(GetXlaOp(operand, value_map, &xla_operand, op))) return failure(); value_map[op.getResult()] = xla_operand; + propagate_layouts(inst, xla_operand); return success(); } - // TODO(jpienaar): This doesn't support layouts yet. if (matchPattern(inst, m_Constant(&const_attr))) { - auto literal_or = CreateLiteralFromAttr(const_attr); + xla::Layout layout; + layout = ExtractLayout(inst, const_attr.getType().getRank()); + auto literal_or = CreateArrayLiteralFromAttr(const_attr, layout); if (!literal_or.ok()) return inst->emitError(literal_or.status().ToString()); - value_map[inst->getResult(0)] = - xla::ConstantLiteral(builder, literal_or.ValueOrDie()); + auto constant = xla::ConstantLiteral(builder, literal_or.ValueOrDie()); + value_map[inst->getResult(0)] = constant; + return success(); } @@ -1618,22 +1698,24 @@ LogicalResult AddDynamicParameterBindings(mlir::ModuleOp module, } // namespace Status ConvertRegionToComputation(mlir::Region* region, - xla::XlaComputation* func) { + xla::XlaComputation* func, + MlirToHloConversionOptions options) { mlir::ModuleOp module; - ConvertToHloModule converter(module, true, true, {}); + ConvertToHloModule converter(module, true, true, {}, options); if (failed(converter.LowerRegionAsComputation(region, func))) return tensorflow::errors::Internal( "failed to convert region to computation"); return Status::OK(); } -Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto, - bool use_tuple_args, bool return_tuple, - const tensorflow::XlaHelpers::ShapeRepresentationFn - shape_representation_fn) { +Status ConvertMlirHloToHlo( + mlir::ModuleOp module, xla::HloProto* hlo_proto, bool use_tuple_args, + bool return_tuple, + const tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn, + MlirToHloConversionOptions options) { mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext()); ConvertToHloModule converter(module, use_tuple_args, return_tuple, - shape_representation_fn); + shape_representation_fn, options); if (failed(converter.Run())) return diag_handler.ConsumeStatus(); auto hlo_module = converter.ConsumeMainProto(); hlo_proto->mutable_hlo_module()->Swap(&hlo_module); diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h index 6f2b5a6db95..4ca3e586128 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h @@ -25,6 +25,18 @@ limitations under the License. namespace mlir { +struct MlirToHloConversionOptions { + // Best-effort propagation of the layouts. These layouts serve as performance + // hints to the backend. 
+ // + // Note that non-array shapes are not carrying layouts, and users have to + // figure out the proper layouts of them through context. This is one of the + // reasons why the attribute-based solution is temporary. + // + // TODO(timshen): Investigate the necessity of having layouts in MHLO. + bool propagate_layouts = false; +}; + // Converts a MLIR module in HLO dialect into a HloModuleProto. If // use_tuple_args is set, then the entry computations's arguments are converted // to a tuple and passed as a single parameter. @@ -32,15 +44,19 @@ namespace mlir { // are converted to a tuple even when there is only a single return value. // Multiple return values are always converted to a tuple and returned as a // single value. +// +// TODO(timshen): move other options into `options`. Status ConvertMlirHloToHlo(mlir::ModuleOp module, ::xla::HloProto* hlo_proto, bool use_tuple_args, bool return_tuple, const tensorflow::XlaHelpers::ShapeRepresentationFn - shape_representation_fn = nullptr); + shape_representation_fn = nullptr, + MlirToHloConversionOptions options = {}); // Converts a region to a computation. It returns a standalone module that // contains the converted region as the entry computation. Status ConvertRegionToComputation(mlir::Region* region, - ::xla::XlaComputation* func); + ::xla::XlaComputation* func, + MlirToHloConversionOptions options = {}); // Creates XlaOp equivalent of a given MLIR operation using the operand info // from `value_lowering` map. diff --git a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc index 407a7d3da38..801c04496f0 100644 --- a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc +++ b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc @@ -165,6 +165,11 @@ static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) { "frontend_attributes(lowering_context.builder, " "CreateOpFrontendAttributesFromAttribute(op));\n\n"; + // Create a scoped object to assign op metadata to generated XLA ops. + os << " xla::XlaScopedOpMetadataAssignment " + "op_metadata(lowering_context.builder, " + "CreateOpMetadataFromLocation(op));\n\n"; + // Retrieve all the definitions derived from HLO_Op and sort by record name. for (const auto* def : records.getAllDerivedDefinitions("HLO_Op")) { // Skip operations that have a custom exporter. 
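To make the layout plumbing above concrete: a minor-to-major list names tensor dimensions from fastest- to slowest-varying in memory, so for an f32[3,2] array {1,0} is the default descending (row-major) layout and {0,1} is column-major. ExtractLayout only finds a "minor_to_major" attribute when the importer recorded a non-default layout, and otherwise falls back to xla::LayoutUtil::MakeDescendingLayout(rank); the fusion_layouts.hlotxt test added below exercises exactly this for f32[3,2]{0,1}. A minimal caller-side sketch of opting in, assuming only the signatures introduced in this change (the `module` and `hlo_proto` variables are illustrative):

  // Opt in to best-effort layout propagation when lowering MHLO to HLO.
  mlir::MlirToHloConversionOptions options;
  options.propagate_layouts = true;
  TF_RETURN_IF_ERROR(mlir::ConvertMlirHloToHlo(
      module, &hlo_proto, /*use_tuple_args=*/false, /*return_tuple=*/false,
      /*shape_representation_fn=*/nullptr, options));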
diff --git a/tensorflow/compiler/mlir/xla/tests/BUILD b/tensorflow/compiler/mlir/xla/tests/BUILD index 2631e2b6757..754b14f4b13 100644 --- a/tensorflow/compiler/mlir/xla/tests/BUILD +++ b/tensorflow/compiler/mlir/xla/tests/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow:tensorflow.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") load("//tensorflow:tensorflow.bzl", "tf_cc_test") diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/fusion_layouts.hlotxt b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/fusion_layouts.hlotxt new file mode 100644 index 00000000000..781e203510b --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/fusion_layouts.hlotxt @@ -0,0 +1,16 @@ +// RUN: tf-mlir-translate -hlo-text-to-lhlo -optimize-xla-hlo=false %s | FileCheck %s + +HloModule TestModule + +// CHECK: func @TestComputation + +FusedComputation { + // CHECK: tensor_load %arg0 {minor_to_major = dense<[0, 1]> : tensor<2xindex>} + x = f32[3, 2]{0,1} parameter(0) + ROOT y = f32[3, 2]{0,1} add(x, x) +} + +ENTRY TestComputation { + x = f32[3, 2]{0,1} parameter(0) + ROOT y = f32[3, 2]{0,1} fusion(x), kind=kLoop, calls=FusedComputation +} diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir index 2e1b63b0db7..e7312e2114c 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir @@ -316,12 +316,61 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW0:.*]] = std.view %[[ARG2]]{{.*}} : memref<100xi8> to memref<5x5xi32> // CHECK: %[[VIEW1:.*]] = std.view %[[ARG3]]{{.*}} : memref<100xi8> to memref<5x5xf32> // CHECK: "lmhlo.sort"(%[[ARG0]], %[[ARG1]], %[[VIEW0]], %[[VIEW1]]) -func @main(%key: tensor<5x5xi32>, %value: tensor<5x5xf32>) -> tuple, tensor<5x5xf32>> { - %res = "mhlo.sort"(%key, %value) ({ +func @main(%key: tensor<5x5xi32>, %value: tensor<5x5xf32>) -> (tensor<5x5xi32>, tensor<5x5xf32>) { + %res:2 = "mhlo.sort"(%key, %value) ({ ^bb0(%a: tensor, %b: tensor, %c: tensor, %d: tensor): %ret = "mhlo.compare"(%c, %d) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%ret) : (tensor) -> () - }) {dimension = 1 : i64, is_stable = true}: (tensor<5x5xi32>, tensor<5x5xf32>) -> tuple, tensor<5x5xf32>> + }) {dimension = 1 : i64, is_stable = true}: (tensor<5x5xi32>, tensor<5x5xf32>) -> (tensor<5x5xi32>, tensor<5x5xf32>) - return %res : tuple, tensor<5x5xf32>> + return %res#0, %res#1 : tensor<5x5xi32>, tensor<5x5xf32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref {{.*}}lmhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref {{.*}}lmhlo.params = 1 +// CHECK-SAME: %[[ARG2:.*]]: memref<4xi8> +// CHECK: "lmhlo.fusion"() ( { +// CHECK: %[[VAR0:.*]] = tensor_load %[[ARG0]] : memref +// CHECK: %[[VAR1:.*]] = tensor_load %[[ARG1]] : memref +// CHECK: %[[VAR2:.*]] = mhlo.add %[[VAR0]], %[[VAR1]] : tensor +// CHECK: tensor_store %[[VAR2]], %[[MEMREF:.*]] : memref +// CHECK: "lmhlo.terminator"() : () -> () +// CHECK: }) : () -> () +func @main(%arg0: tensor, %arg1: tensor) -> tensor { + %result = "mhlo.fusion"(%arg0, %arg1) ( { + ^bb0(%arg2: tensor, %arg3: tensor): + %result = "mhlo.add"(%arg2, %arg3): (tensor, tensor) -> tensor + "mhlo.return"(%result) : (tensor) -> () + }) { fusion_kind = "kLoop" } : (tensor, tensor) -> tensor + + return %result : tensor +} + +// ----- + +// CHECK-LABEL: func @main 
+// CHECK: "lmhlo.fusion"() ( { +// CHECK: %[[VAL0:.*]] = tensor_load %{{.*}} : memref +// CHECK: %[[VAL1:.*]] = tensor_load %{{.*}} : memref +// CHECK: %[[VAL2:.*]] = tensor_load %{{.*}} : memref +// CHECK: tensor_store %[[VAL0]], %{{.*}} : memref +// CHECK: tensor_store %[[VAL1]], %{{.*}} : memref +// CHECK: tensor_store %[[VAL2]], %{{.*}} : memref +// CHECK: "lmhlo.terminator"() : () -> () +// CHECK: }) : () -> () +func @main(%arg0: tuple>, tensor>, %arg1: tuple>) -> tuple, tensor, tensor> { + %result = "mhlo.fusion"(%arg0, %arg1) ( { + ^bb0(%arg2: tuple>, tensor>, %arg3: tuple>): + %0 = "mhlo.get_tuple_element"(%arg2) {index = 0 : i32} : (tuple>, tensor>) -> tuple> + %1 = "mhlo.get_tuple_element"(%0) {index = 0 : i32} : (tuple>) -> tensor + %2 = "mhlo.get_tuple_element"(%arg2) {index = 1 : i32} : (tuple>, tensor>) -> tensor + %3 = "mhlo.get_tuple_element"(%arg3) {index = 0 : i32} : (tuple>) -> tensor + %4 = "mhlo.tuple"(%1, %2, %3) : (tensor, tensor, tensor) -> tuple, tensor, tensor> + "mhlo.return"(%4) : (tuple, tensor, tensor>) -> () + }) { fusion_kind = "kLoop" } : (tuple>, tensor>, tuple>) -> tuple, tensor, tensor> + + return %result : tuple, tensor, tensor> } diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir index cffb15022b0..5a07d9303f0 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir @@ -60,20 +60,20 @@ func @batchmatmulv2_dynamic(%arg0: tensor, %arg1: tensor) return %0 : tensor } -func @batchmatmulv2_adj_real(%arg0: tensor<5x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<5x4xf32> { +func @batchmatmulv2_adj_real(%arg0: tensor<2x5xf32>, %arg1: tensor<4x2xf32>) -> tensor<5x4xf32> { // CHECK-LABEL: func @batchmatmulv2_adj_real // CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = { // CHECK-SAME: lhs_batching_dimensions = dense<> : tensor<0xi64>, // CHECK-SAME: lhs_contracting_dimensions = dense<0> : tensor<1xi64>, // CHECK-SAME: rhs_batching_dimensions = dense<> : tensor<0xi64>, // CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64>}} - %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xf32>, tensor<2x4xf32>) -> tensor<5x4xf32> + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<2x5xf32>, tensor<4x2xf32>) -> tensor<5x4xf32> return %0 : tensor<5x4xf32> } -func @batchmatmulv2_adj_complex(%arg0: tensor<5x2xcomplex>, %arg1: tensor<2x4xcomplex>) -> tensor<5x4xcomplex> { +func @batchmatmulv2_adj_complex(%arg0: tensor<2x5xcomplex>, %arg1: tensor<4x2xcomplex>) -> tensor<5x4xcomplex> { // CHECK-LABEL: func @batchmatmulv2_adj_complex( -// CHECK-SAME: [[LHS:%.*]]: tensor<5x2xcomplex>, [[RHS:%.*]]: tensor<2x4xcomplex>) -> tensor<5x4xcomplex> { +// CHECK-SAME: [[LHS:%.*]]: tensor<2x5xcomplex>, [[RHS:%.*]]: tensor<4x2xcomplex>) -> tensor<5x4xcomplex> { // CHECK: [[LHSRE:%.*]] = "mhlo.real"([[LHS]]) // CHECK: [[LHSIM:%.*]] = "mhlo.imag"([[LHS]]) // CHECK: [[LHSIMNEG:%.*]] = "mhlo.negate"([[LHSIM]]) @@ -84,6 +84,6 @@ func @batchmatmulv2_adj_complex(%arg0: tensor<5x2xcomplex>, %arg1: tensor<2 // CHECK: [[RHSCONJ:%.*]] = "mhlo.complex"([[RHSRE]], [[RHSIMNEG]]) // CHECK: shape.shape_of [[LHSCONJ]] // CHECK: shape.shape_of [[RHSCONJ]] - %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xcomplex>, tensor<2x4xcomplex>) -> tensor<5x4xcomplex> + %0 = 
"tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<2x5xcomplex>, tensor<4x2xcomplex>) -> tensor<5x4xcomplex> return %0 : tensor<5x4xcomplex> } diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir index 5f3e40f923f..7f37dbb0479 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir @@ -2,6 +2,7 @@ // (unlike the rest), since this is the primary use case for such ops and // verification of shapes and broadcasts is desired. // RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=true" -canonicalize %s | FileCheck %s +// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=false" %s | FileCheck --check-prefix CHLO %s //===----------------------------------------------------------------------===// // Binary op legalizations. @@ -58,6 +59,15 @@ func @add_dynamic(%arg0: tensor, %arg1: tensor) -> tensor } +// CHECK-LABEL: func @broadcast_add_unranked +// CHLO-LABEL: func @broadcast_add_unranked +func @broadcast_add_unranked(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> { + // CHECK: tf.Add + // CHLO: chlo.broadcast_add %arg0, %arg1 + %0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<*xi32>) -> tensor<*xi32> + return %0: tensor<*xi32> +} + // CHECK-LABEL: func @div func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> { // CHECK-NEXT: %0 = mhlo.divide %arg0, %arg0 : tensor<2xi32> @@ -139,9 +149,9 @@ func @broadcast_shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<2x4xui8 } // CHECK-LABEL: func @and -func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> { +func @and(%arg0: tensor<2xi1>, %arg1: tensor<2xi1>) -> tensor<2xi1> { // CHECK-NEXT: mhlo.and - %0 = "tf.LogicalAnd"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> + %0 = "tf.LogicalAnd"(%arg0, %arg1) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> return %0: tensor<2xi1> } @@ -153,9 +163,9 @@ func @and_unranked(%arg0: tensor<*xi1>, %arg1: tensor<*xi1>) -> tensor<*xi1> { } // CHECK-LABEL: func @or -func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> { +func @or(%arg0: tensor<2xi1>, %arg1: tensor<2xi1>) -> tensor<2xi1> { // CHECK-NEXT: mhlo.or - %0 = "tf.LogicalOr"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> + %0 = "tf.LogicalOr"(%arg0, %arg1) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> return %0: tensor<2xi1> } @@ -187,9 +197,9 @@ func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> { //===----------------------------------------------------------------------===// // CHECK-LABEL: func @equal -func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} - %0 = "tf.Equal"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +func @equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { + // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} + %0 = "tf.Equal"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0: tensor<2xi1> } @@ -255,9 +265,9 @@ func @equal_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi1> } // CHECK-LABEL: func @notequal -func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} - %0 = "tf.NotEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +func @notequal(%arg0: tensor<2xi32>, %arg1: 
tensor<2xi32>) -> tensor<2xi1> { + // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "NE"} + %0 = "tf.NotEqual"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0: tensor<2xi1> } @@ -268,9 +278,9 @@ func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> { //===----------------------------------------------------------------------===// // CHECK-LABEL: func @greater -func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} - %0 = "tf.Greater"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +func @greater(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { + // CHECK: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} + %0 = "tf.Greater"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0: tensor<2xi1> } @@ -300,29 +310,29 @@ func @greater_dynamic(%arg0: tensor, %arg1: tensor) -> tensor) -> tensor<*xi1> { +func @greater_uranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi1> { // CHECK: "tf.Greater" - %0 = "tf.Greater"(%arg0, %arg0) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1> + %0 = "tf.Greater"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1> return %0: tensor<*xi1> } // CHECK-LABEL: func @greater_equal -func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} - %0 = "tf.GreaterEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +func @greater_equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { + // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GE"} + %0 = "tf.GreaterEqual"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0: tensor<2xi1> } // CHECK-LABEL: func @less -func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} - %0 = "tf.Less"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +func @less(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { + // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LT"} + %0 = "tf.Less"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0: tensor<2xi1> } // CHECK-LABEL: func @less_equal -func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} - %0 = "tf.LessEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +func @less_equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { + // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LE"} + %0 = "tf.LessEqual"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0: tensor<2xi1> } diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir index 93eac3821b2..767e0be8d6a 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir @@ -1,27 +1,27 @@ // RUN: tf-opt -xla-legalize-tf-control-flow %s | FileCheck %s // CHECK-LABEL: @if -func @if(%arg0: tensor, %arg1: tensor) -> (tensor) -attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} { - // CHECK: [[VAL0:%.+]] = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor +// CHECK-SAME: ([[ARG0:%.+]]: tensor, [[ARG1:%.+]]: tensor) +func @if(%arg0: 
tensor, %arg1: tensor) -> (tensor) { + // CHECK: [[VAL0:%.+]] = "mhlo.compare"([[ARG0]], [[ARG1]]) {comparison_direction = "GT"} : (tensor, tensor) -> tensor %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor - // CHECK: [[VAL1:%.+]] = "mhlo.tuple"(%arg0, %arg1) + // CHECK: [[VAL1:%.+]] = "mhlo.tuple"([[ARG0]], [[ARG1]]) // CHECK: [[VAL2:%.+]] = "mhlo.if"([[VAL0]], [[VAL1]], [[VAL1]]) ( { - // CHECK: ^bb0(%arg2: tuple, tensor>): - // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"(%arg2) {index = 0 : i32} - // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"(%arg2) {index = 1 : i32} + // CHECK: ^bb0([[THEN_ARG:%.+]]: tuple, tensor>): + // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[THEN_ARG]]) {index = 0 : i32} + // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[THEN_ARG]]) {index = 1 : i32} // CHECK: [[VAL6:%.+]] = call @cond_true([[VAL4]], [[VAL5]]) // CHECK: [[VAL7:%.+]] = "mhlo.tuple"([[VAL6]]) // CHECK: "mhlo.return"([[VAL7]]) : (tuple>) -> () // CHECK: }, { - // CHECK: ^bb0(%arg2: tuple, tensor>) - // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"(%arg2) {index = 0 : i32} - // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"(%arg2) {index = 1 : i32} + // CHECK: ^bb0([[ELSE_ARG:%.+]]: tuple, tensor>) + // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[ELSE_ARG]]) {index = 0 : i32} + // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[ELSE_ARG]]) {index = 1 : i32} // CHECK: [[VAL6:%.+]] = call @cond_false([[VAL4]], [[VAL5]]) // CHECK: [[VAL7:%.+]] = "mhlo.tuple"([[VAL6]]) - // CHECK: "mhlo.return"([[VAL7]]) : (tuple>) -> () + // CHECK: "mhlo.return"([[VAL7]]) : (tuple>) -> () // CHECK: }) - %1 = "tf.If"(%0, %arg0, %arg1) {Tcond = "tfdtype$DT_BOOL", Tin = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _lower_using_switch_merge = true, _output_shapes = ["tfshape$"], device = "", else_branch = @cond_false, is_stateless = true, name = "cond", output_shapes = [#tf.shape<>], then_branch = @cond_true} : (tensor, tensor, tensor) -> tensor + %1 = "tf.If"(%0, %arg0, %arg1) {else_branch = @cond_false, is_stateless = true, then_branch = @cond_true} : (tensor, tensor, tensor) -> tensor // CHECK: [[VAL3:%.+]] = "mhlo.get_tuple_element"([[VAL2]]) {index = 0 : i32} // CHECK: return [[VAL3]] @@ -41,6 +41,38 @@ attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} { } +// CHECK-LABEL: @ifRegion +// CHECK-SAME: ([[ARG0:%.+]]: tensor, [[ARG1:%.+]]: tensor) +func @ifRegion(%arg0: tensor, %arg1: tensor) -> (tensor) { + // CHECK: [[VAL0:%.+]] = "mhlo.compare"([[ARG0]], [[ARG1]]) {comparison_direction = "GT"} + %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + // CHECK: [[VAL1:%.+]] = "mhlo.tuple"([[ARG0]]) + // CHECK: [[VAL2:%.+]] = "mhlo.tuple"([[ARG1]]) + // CHECK: [[VAL3:%.+]] = "mhlo.if"([[VAL0]], [[VAL1]], [[VAL2]]) ( { + %1 = "tf.IfRegion"(%0) ( { + // CHECK: ^{{[a-z0-9]+}}([[TRUE_ARG:%.+]]: tuple>): + // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[TRUE_ARG]]) {index = 0 : i32} + // CHECK: [[VAL6:%.+]] = "mhlo.log"([[VAL5]]) + %2 = "mhlo.log"(%arg0) : (tensor) -> tensor + // CHECK: [[VAL7:%.+]] = "mhlo.tuple"([[VAL6]]) + // CHECK: "mhlo.return"([[VAL7]]) + "tf.Yield"(%2) : (tensor) -> () + }, { + // CHECK: ^{{[a-z0-9]+}}([[FALSE_ARG:%.+]]: tuple>): + // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[FALSE_ARG]]) {index = 0 : i32} + // CHECK: [[VAL6:%.+]] = "mhlo.exponential"([[VAL5]]) + %2 = "mhlo.exponential"(%arg1) : (tensor) -> tensor + // CHECK: [[VAL7:%.+]] = 
"mhlo.tuple"([[VAL6]]) + // CHECK: "mhlo.return"([[VAL7]]) + "tf.Yield"(%2) : (tensor) -> () + // CHECK: }) : (tensor, tuple>, tuple>) -> tuple> + }) {is_stateless = true} : (tensor) -> tensor + // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 0 : i32} + // CHECK: return [[VAL4]] + return %1 : tensor +} + + // CHECK-LABEL: func @case // CHECK-SAME: %[[BRANCH_INDEX:.*]]: tensor, %[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) -> (tensor, tensor) func @case(%index: tensor, %arg0: tensor, %arg1: tensor) -> (tensor, tensor) { @@ -85,26 +117,62 @@ func @floor(%arg0: tensor, %arg1: tensor) -> (tensor, tensor } +// CHECK-LABEL: func @caseRegion +// CHECK-SAME: ([[BRANCH_INDEX:%.+]]: tensor, [[ARG0:.+]]: tensor, [[ARG1:%.+]]: tensor) +func @caseRegion(%index: tensor, %arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + // CHECK: [[VAL0:%.+]] = "mhlo.tuple"([[ARG1]]) + // CHECK: [[VAL1:%.+]] = "mhlo.tuple"([[ARG0]], [[ARG1]]) + // CHECK: [[VAL2:%.+]] = "mhlo.tuple"([[ARG0]], [[ARG1]]) + // CHECK: [[VAL3:%.+]]:2 = "mhlo.case"([[BRANCH_INDEX]], [[VAL0]], [[VAL1]], [[VAL2]]) ( { + %0:2 = "tf.CaseRegion"(%index) ( { + // CHECK: ^{{[a-z0-9]+}}([[BRANCH0_ARG:%.+]]: tuple>): + // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[BRANCH0_ARG]]) {index = 0 : i32} + // CHECK: [[VAL5:%.+]] = "mhlo.exponential"([[VAL4]]) + %1 = "mhlo.exponential"(%arg1) : (tensor) -> tensor + // CHECK: "mhlo.return"([[VAL5]], [[VAL4]]) + "tf.Yield"(%1, %arg1) : (tensor, tensor) -> () + }, { + // CHECK: ^{{[a-z0-9]+}}([[BRANCH1_ARG:%.+]]: tuple, tensor>): + // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[BRANCH1_ARG]]) {index = 0 : i32} + // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[BRANCH1_ARG]]) {index = 1 : i32} + // CHECK: [[VAL6:%.+]] = "mhlo.log"([[VAL4]]) + %1 = "mhlo.log"(%arg0) : (tensor) -> tensor + // CHECK: "mhlo.return"([[VAL6]], [[VAL5]]) + "tf.Yield"(%1, %arg1) : (tensor, tensor) -> () + }, { + // CHECK: ^{{[a-z0-9]+}}([[BRANCH2_ARG:%.+]]: tuple, tensor>): + // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[BRANCH2_ARG]]) {index = 0 : i32} + // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[BRANCH2_ARG]]) {index = 1 : i32} + // CHECK: [[VAL6:%.+]] = "mhlo.floor"([[VAL4]]) + %1 = "mhlo.floor"(%arg0) : (tensor) -> tensor + // CHECK: "mhlo.return"([[VAL6]], [[VAL5]]) + "tf.Yield"(%1, %arg1) : (tensor, tensor) -> () + // CHECK: }) : (tensor, tuple>, tuple, tensor>, tuple, tensor>) -> (tensor, tensor) + }) {is_stateless = true} : (tensor) -> (tensor, tensor) + // CHECK: return [[VAL3]]#0, [[VAL3]]#1 : tensor, tensor + return %0#0, %0#1 : tensor, tensor +} + + // CHECK-LABEL: func @while -func @while(%arg0: tensor {tf_saved_model.index_path = [0]}) -> (tensor {tf_saved_model.index_path = []}) -attributes {tf._input_shapes = ["tfshape$"]} { +func @while() -> tensor { // CHECK: [[VAL0:%.+]] = mhlo.constant dense<0> // CHECK: [[VAL1:%.+]] = mhlo.constant dense<-1> %0 = mhlo.constant dense<0> : tensor %1 = mhlo.constant dense<-1> : tensor // CHECK: [[VAL2:%.+]] = "mhlo.tuple"([[VAL0]], [[VAL1]], [[VAL0]]) // CHECK: [[VAL3:%.+]] = "mhlo.while"([[VAL2]]) ( { - // CHECK: ^bb0(%arg1: tuple, tensor, tensor>): - // CHECK: [[VAL7:%.+]] = "mhlo.get_tuple_element"(%arg1) {index = 0 : i32} - // CHECK: [[VAL8:%.+]] = "mhlo.get_tuple_element"(%arg1) {index = 1 : i32} - // CHECK: [[VAL9:%.+]] = "mhlo.get_tuple_element"(%arg1) {index = 2 : i32} + // CHECK: ^bb0([[COND_ARG:%.+]]: tuple, tensor, tensor>): + // CHECK: [[VAL7:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 0 : i32} + // CHECK: 
[[VAL8:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 1 : i32} + // CHECK: [[VAL9:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 2 : i32} // CHECK: [[VAL10:%.+]] = call @while_cond([[VAL7]], [[VAL8]], [[VAL9]]) // CHECK: "mhlo.return"([[VAL10]]) // CHECK: }, { - // CHECK: ^bb0(%arg1: tuple, tensor, tensor>): - // CHECK: [[VAL7:%.+]] = "mhlo.get_tuple_element"(%arg1) {index = 0 : i32} - // CHECK: [[VAL8:%.+]] = "mhlo.get_tuple_element"(%arg1) {index = 1 : i32} - // CHECK: [[VAL9:%.+]] = "mhlo.get_tuple_element"(%arg1) {index = 2 : i32} + // CHECK: ^bb0([[BODY_ARG:%.+]]: tuple, tensor, tensor>): + // CHECK: [[VAL7:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 0 : i32} + // CHECK: [[VAL8:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 1 : i32} + // CHECK: [[VAL9:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 2 : i32} // CHECK: [[VAL10:%.+]]:3 = call @while_body([[VAL7]], [[VAL8]], [[VAL9]]) // CHECK: [[VAL11:%.+]] = "mhlo.tuple"([[VAL10]]#0, [[VAL10]]#1, [[VAL10]]#2) // CHECK: "mhlo.return"([[VAL11]]) @@ -113,19 +181,134 @@ attributes {tf._input_shapes = ["tfshape$"]} { // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 1 : i32} // CHECK: [[VAL6:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 2 : i32} // CHECK: return [[VAL6]] - %2:3 = "tf.While"(%0, %1, %0) {T = ["tfdtype$DT_INT32", "tfdtype$DT_INT32", "tfdtype$DT_INT32"], _lower_using_switch_merge = true, _num_original_outputs = 3 : i64, _output_shapes = ["tfshape$", "tfshape$", "tfshape$"], body = @while_body, cond = @while_cond, device = "", is_stateless = true, name = "while", output_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], parallel_iterations = 10 : i64} : (tensor, tensor, tensor) -> (tensor, tensor, tensor) + %2:3 = "tf.While"(%0, %1, %0) {body = @while_body, cond = @while_cond, is_stateless = true, parallel_iterations = 10 : i64} : (tensor, tensor, tensor) -> (tensor, tensor, tensor) return %2#2 : tensor } -func @while_cond(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor -attributes {tf._input_shapes = ["tfshape$", "tfshape$", "tfshape$"]} { +func @while_cond(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { %0 = mhlo.constant dense<10> : tensor %1 = "mhlo.compare"(%arg2, %0) {comparison_direction = "LT"} : (tensor, tensor) -> tensor return %1 : tensor } -func @while_body(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor, tensor, tensor) -attributes {tf._input_shapes = ["tfshape$", "tfshape$", "tfshape$"]} { +func @while_body(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor, tensor, tensor) { %0 = mhlo.constant dense<1> : tensor %1 = mhlo.add %arg2, %0 : tensor %2 = mhlo.add %arg0, %0 : tensor return %2, %arg1, %1 : tensor, tensor, tensor } + + +// CHECK-LABEL: func @whileRegion +func @whileRegion() -> tensor { + // CHECK: [[VAL0:%.+]] = mhlo.constant dense<0> + %0 = mhlo.constant dense<0> : tensor + // CHECK: [[VAL1:%.+]] = mhlo.constant dense<-1> + %1 = mhlo.constant dense<-1> : tensor + // CHECK: [[VAL2:%.+]] = "mhlo.tuple"([[VAL0]], [[VAL1]], [[VAL0]]) + // CHECK: [[VAL3:%.+]] = "mhlo.while"([[VAL2]]) ( { + %2:3 = "tf.WhileRegion"(%0, %1, %0) ( { + // CHECK: ^bb0([[COND_ARG:%.+]]: tuple, tensor, tensor>): + ^cond(%carg0: tensor, %carg1: tensor, %carg2: tensor): + // CHECK: [[VAL7:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 0 : i32} + // CHECK: [[VAL8:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 1 : i32} + // CHECK: [[VAL9:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 2 : i32} + // CHECK: 
[[VAL10:%.+]] = mhlo.constant dense<10> + %3 = mhlo.constant dense<10> : tensor + // CHECK: [[VAL11:%.+]] = "mhlo.compare"([[VAL9]], [[VAL10]]) {comparison_direction = "LT"} + %4 = "mhlo.compare"(%carg2, %3) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + // CHECK: "mhlo.return"([[VAL11]]) + "tf.Yield"(%4) : (tensor) -> () + }, { + // CHECK: ^bb0([[BODY_ARG:%.+]]: tuple, tensor, tensor>): + ^body(%barg0: tensor, %barg1: tensor, %barg2: tensor): + // CHECK: [[VAL7:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 0 : i32} + // CHECK: [[VAL8:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 1 : i32} + // CHECK: [[VAL9:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 2 : i32} + // CHECK: [[VAL10:%.+]] = mhlo.constant dense<1> + %5 = mhlo.constant dense<1> : tensor + // CHECK: [[VAL11:%.+]] = mhlo.add [[VAL9]], [[VAL10]] + %6 = mhlo.add %barg2, %5 : tensor + // CHECK: [[VAL12:%.+]] = mhlo.add [[VAL7]], [[VAL10]] + %7 = mhlo.add %barg0, %5 : tensor + // CHECK: [[VAL13:%.+]] = "mhlo.tuple"([[VAL12]], [[VAL8]], [[VAL11]]) + // CHECK: "mhlo.return"([[VAL13]]) + "tf.Yield"(%7, %barg1, %6) : (tensor, tensor, tensor) -> () + }) {is_stateless = true, parallel_iterations = 10 : i64} : (tensor, tensor, tensor) -> (tensor, tensor, tensor) + // CHECK: }) : (tuple, tensor, tensor>) -> tuple, tensor, tensor> + // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 0 : i32} + // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 1 : i32} + // CHECK: [[VAL6:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 2 : i32} + // CHECK: return [[VAL6]] + return %2#2 : tensor +} + + +// CHECK-LABEL: func @whileRegionImplicitInputs +// CHECK-SAME: ([[ARG0:%.+]]: tensor) +func @whileRegionImplicitInputs(%arg0: tensor) -> tensor { + // CHECK: [[VAL0:%.+]] = mhlo.constant dense<0> + %0 = mhlo.constant dense<0> : tensor + // CHECK: [[VAL1:%.+]] = mhlo.constant dense<-1> + %1 = mhlo.constant dense<-1> : tensor + // CHECK: [[VAL2:%.+]] = "mhlo.tuple"([[ARG0]], [[VAL0]], [[VAL1]]) + // CHECK: [[VAL3:%.+]] = "mhlo.while"([[VAL2]]) ( { + %2 = "tf.WhileRegion"(%arg0) ( { + // CHECK: ^bb0([[COND_ARG:%.+]]: tuple, tensor, tensor>): + ^cond(%carg0: tensor): + // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 0 : i32} + // CHECK: [[VAL6:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 1 : i32} + // CHECK: [[VAL7:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 2 : i32} + // CHECK: [[VAL8:%.+]] = "mhlo.compare"([[VAL5]], [[VAL6]]) {comparison_direction = "LT"} + %3 = "mhlo.compare"(%carg0, %0) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + // CHECK: "mhlo.return"([[VAL8]]) + "tf.Yield"(%3) : (tensor) -> () + }, { + // CHECK: ^bb0([[BODY_ARG:%.+]]: tuple, tensor, tensor>): + ^body(%barg0: tensor): + // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 0 : i32} + // CHECK: [[VAL6:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 1 : i32} + // CHECK: [[VAL7:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 2 : i32} + // CHECK: [[VAL8:%.+]] = mhlo.add [[VAL5]], [[VAL7]] + %3 = mhlo.add %barg0, %1 : tensor + // CHECK: [[VAL9:%.+]] = mhlo.add [[VAL5]], [[VAL8]] + %4 = mhlo.add %barg0, %3 : tensor + // CHECK: [[VAL10:%.+]] = "mhlo.tuple"([[VAL9]], [[VAL6]], [[VAL7]]) + // CHECK: "mhlo.return"([[VAL10]]) + "tf.Yield"(%4) : (tensor) -> () + }) {is_stateless = true, parallel_iterations = 10 : i64} : (tensor) -> tensor + // CHECK: }) : (tuple, tensor, tensor>) -> tuple, tensor, tensor> + // CHECK: 
[[VAL4:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 0 : i32} + // CHECK: return [[VAL4]] + return %2 : tensor +} + + +// CHECK-LABEL: func @whileRegionMultipleImplicitInputs +func @whileRegionMultipleImplicitInputs() { + // CHECK: [[VAL0:%.+]] = mhlo.constant dense<0> + %0 = mhlo.constant dense<0> : tensor + // CHECK: [[VAL1:%.+]] = mhlo.constant dense<-1> + %1 = mhlo.constant dense<-1> : tensor + // CHECK: [[VAL2:%.+]] = "mhlo.tuple"([[VAL0]], [[VAL1]]) + // CHECK: [[VAL3:%.+]] = "mhlo.while"([[VAL2]]) ( { + "tf.WhileRegion"() ( { + // CHECK: ^bb0([[COND_ARG:%.+]]: tuple, tensor>): + // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 0 : i32} + // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 1 : i32} + // CHECK: [[VAL6:%.+]] = "mhlo.compare"([[VAL4]], [[VAL5]]) {comparison_direction = "LT"} + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + // CHECK: "mhlo.return"([[VAL6]]) + "tf.Yield"(%2) : (tensor) -> () + }, { + // CHECK: ^bb0([[BODY_ARG:%.+]]: tuple, tensor>): + // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 0 : i32} + // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 1 : i32} + // CHECK: [[VAL6:%.+]] = mhlo.add [[VAL4]], [[VAL5]] + %2 = mhlo.add %0, %1 : tensor + // CHECK: [[VAL7:%.+]] = "mhlo.tuple"([[VAL4]], [[VAL5]]) + // CHECK: "mhlo.return"([[VAL7]]) + "tf.Yield"() : () -> () + }) {is_stateless = true, parallel_iterations = 10 : i64} : () -> () + // CHECK: }) : (tuple, tensor>) -> tuple, tensor> + // CHECK: return + return +} diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir index df4f0303a84..a21a78cf7f4 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir @@ -125,9 +125,9 @@ func @constant(%arg0: tensor<2xf32>) -> tensor<2xf32> { } // CHECK-LABEL: func @greater -func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} - %0 = "tf.Greater"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +func @greater(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { + // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} + %0 = "tf.Greater"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0: tensor<2xi1> } @@ -220,13 +220,6 @@ func @sparse_to_dense(%arg0: tensor<3x2xi32>, %arg1: tensor<3xf32>, %arg2: tenso return %0 : tensor<3x3xf32> } -// CHECK-LABEL: fft -func @fft(%arg0: tensor<3x5x8xcomplex>) -> tensor<3x5x8xcomplex> { - // CHECK: "mhlo.fft"(%arg0) - %0 = "tf.FFT"(%arg0) : (tensor<3x5x8xcomplex>) -> tensor<3x5x8xcomplex> - return %0 : tensor<3x5x8xcomplex> -} - // CHECK-LABEL: reverse_sequence func @reverse_sequence(%arg0: tensor<4x2x3x1x1xi32>, %arg1: tensor<3xi32>) -> tensor<4x2x3x1x1xi32> { // CHECK-NOT: tf.ReverseSequence diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index 56d4236c0a0..23137eff774 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -342,7 +342,7 @@ func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor< } // CHECK-LABEL: fusedBatchNormGradV3_Training -func @fusedBatchNormGradV3_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: 
tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { +func @fusedBatchNormGradV3_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<0xf32>, tensor<*xf32>) { // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[training:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple, tensor<8xf32>, tensor<8xf32>> @@ -350,10 +350,11 @@ func @fusedBatchNormGradV3_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x // CHECK-NEXT: %[[scale_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32> + // CHECK: return %[[x_backprop]] + // CHECK-SAME: tensor<8x8x8x8xf32> - %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - return %0#0 : tensor<8x8x8x8xf32> + %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<0xf32>, tensor<*xf32>) + return %0#0, %0#3, %0#4 : tensor<8x8x8x8xf32>, tensor<0xf32>, tensor<*xf32> } // CHECK-LABEL: fusedBatchNormGradV3_noTraining_mixed_precision @@ -439,6 +440,17 @@ func @fusedBatchNormGradV3_Training_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tens // Bias op legalizations. 
//===----------------------------------------------------------------------===// +// CHECK-LABEL: func @biasAdd_default +func @biasAdd_default(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { + // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 + // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]] + // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) + // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} + // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]] + %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + return %0 : tensor<1x32x10x32xi32> +} + // CHECK-LABEL: func @biasAdd_NHWC func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 @@ -472,6 +484,57 @@ func @biasAdd_dynamic(%arg0: tensor, %arg1: tensor) -> tenso return %0 : tensor } + +//===----------------------------------------------------------------------===// +// ClipByValue +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @clip +func @clip(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { + // CHECK: [[VAL:%.+]] = "mhlo.clamp"(%arg1, %arg0, %arg2) + + %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + // CHECK: return [[VAL]] + return %0 : tensor +} + +// CHECK-LABEL: @clip_dynamic +func @clip_dynamic(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { + // CHECK-DAG: [[CLAMP:%.+]] = "mhlo.clamp"(%arg1, %arg0, %arg2) + %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + + // CHECK: return [[CLAMP]] + return %0 : tensor +} + +// CHECK-LABEL: @clip_static_broadcast +func @clip_static_broadcast(%arg0 : tensor<5xf32>, %arg1 : tensor, %arg2 : tensor) -> tensor<5xf32> { + // CHECK-DAG: [[SHP:%.+]] = mhlo.constant dense<5> + // CHECK-DAG: [[SHPIDX:%.+]] = tensor_cast [[SHP]] + // CHECK-DAG: [[BROADCAST_MIN:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, [[SHPIDX]]) {broadcast_dimensions = dense<> : tensor<0xi64>} + // CHECK-DAG: [[BROADCAST_MAX:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg2, [[SHPIDX]]) {broadcast_dimensions = dense<> : tensor<0xi64>} + // CHECK-DAG: [[CLAMP:%.+]] = "mhlo.clamp"([[BROADCAST_MIN]], %arg0, [[BROADCAST_MAX]]) + %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor<5xf32>, tensor, tensor) -> tensor<5xf32> + + // CHECK: return [[CLAMP]] + return %0 : tensor<5xf32> +} + + +// CHECK-LABEL: @clip_dynamic_broadcast +func @clip_dynamic_broadcast(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { + // CHECK-DAG: [[SHP:%.+]] = shape.shape_of %arg0 + // CHECK-DAG: [[EXT:%.+]] = shape.to_extent_tensor [[SHP]] + // CHECK-DAG: [[SHPIDX:%.+]] = index_cast %1 + // CHECK-DAG: [[BROADCAST_MIN:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, [[SHPIDX]]) {broadcast_dimensions = dense<> : tensor<0xi64>} + // CHECK-DAG: [[BROADCAST_MAX:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg2, [[SHPIDX]]) {broadcast_dimensions = dense<> : tensor<0xi64>} + // CHECK-DAG: [[CLAMP:%.+]] = "mhlo.clamp"([[BROADCAST_MIN]], %arg0, [[BROADCAST_MAX]]) + %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + + // CHECK: return [[CLAMP]] + return %0 : tensor +} + //===----------------------------------------------------------------------===// // DiagPart 
//===----------------------------------------------------------------------===// @@ -1269,6 +1332,15 @@ func @maxpool_3d_same_padding(%arg0: tensor<2x8x13x25x7xf32>) -> tensor<2x8x4x7x return %0 : tensor<2x8x4x7x7xf32> } +// CHECK-LABEL: maxpool_explicit_padding +func @maxpool_explicit_padding(%arg0: tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> { + // CHECK: tf.MaxPool + // TODO(b/165938852): need to support explicit padding in max_pool. + + %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "EXPLICIT", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> + return %0 : tensor<2x3x5x7xi32> +} + //===----------------------------------------------------------------------===// // MaxPoolGrad op legalizations. //===----------------------------------------------------------------------===// @@ -1755,6 +1827,20 @@ func @simple_logsoftmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // Fast Fourier Transform op legalization. //===----------------------------------------------------------------------===// +// CHECK-LABEL: func @fft_1D +func @fft_1D(%arg0: tensor<8xcomplex>) -> tensor<8xcomplex> { + // CHECK: "mhlo.fft"(%arg0) {fft_length = dense<8> : tensor<1xi64>, fft_type = "FFT"} : (tensor<8xcomplex> + %0 = "tf.FFT"(%arg0) : (tensor<8xcomplex>) -> tensor<8xcomplex> + return %0 : tensor<8xcomplex> +} + +// CHECK-LABEL: func @ifft_1D +func @ifft_1D(%arg0: tensor<8xcomplex>) -> tensor<8xcomplex> { + // CHECK: "mhlo.fft"(%arg0) {fft_length = dense<8> : tensor<1xi64>, fft_type = "IFFT"} : (tensor<8xcomplex> + %0 = "tf.IFFT"(%arg0) : (tensor<8xcomplex>) -> tensor<8xcomplex> + return %0 : tensor<8xcomplex> +} + // CHECK-LABEL: func @rfft_1D func @rfft_1D(%arg0: tensor<8xf32>) -> tensor<8xcomplex> { %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) @@ -1763,6 +1849,48 @@ func @rfft_1D(%arg0: tensor<8xf32>) -> tensor<8xcomplex> { return %0 : tensor<8xcomplex> } +// CHECK-LABEL: func @rfft_1D_padded +func @rfft_1D_padded(%arg0: tensor<7xf32>) -> tensor<8xcomplex> { + %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) + // CHECK: %[[PADDED:.*]] = "mhlo.pad"(%arg0, %2) {edge_padding_high = dense<1> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<7xf32>, tensor) -> tensor<8xf32> + // CHECK: "mhlo.fft"(%[[PADDED]]) {fft_length = dense<8> : tensor<1xi64>, fft_type = "RFFT"} : (tensor<8xf32> + %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<7xf32>, tensor<1xi32>) -> tensor<8xcomplex> + return %0 : tensor<8xcomplex> +} + +// CHECK-LABEL: func @rfft_1D_sliced +func @rfft_1D_sliced(%arg0: tensor<2x9xf32>) -> tensor<2x8xcomplex> { + %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) + // CHECK: %[[SLICED:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<[2, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x9xf32>) -> tensor<2x8xf32> + // CHECK: "mhlo.fft"(%[[SLICED]]) {fft_length = dense<8> : tensor<1xi64>, fft_type = "RFFT"} : (tensor<2x8xf32> + %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<2x9xf32>, tensor<1xi32>) -> tensor<2x8xcomplex> + return %0 : tensor<2x8xcomplex> +} + +// CHECK-LABEL: func @irfft_1D +func @irfft_1D(%arg0: tensor<8xcomplex>) -> tensor<5xf32> { + %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) + // CHECK: %[[SLICED:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<5> : 
tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<8xcomplex>) -> tensor<5xcomplex> + // CHECK: "mhlo.fft"(%[[SLICED]]) {fft_length = dense<5> : tensor<1xi64>, fft_type = "IRFFT"} : (tensor<5xcomplex> + %0 = "tf.IRFFT"(%arg0, %fftlength) : (tensor<8xcomplex>, tensor<1xi32>) -> tensor<5xf32> + return %0 : tensor<5xf32> +} + +// CHECK-LABEL: fft_1D_dynamic +func @fft_1D_dynamic(%arg0: tensor>) -> tensor<8xcomplex> { + // CHECK: "tf.FFT" + %0 = "tf.FFT"(%arg0) : (tensor>) -> tensor<8xcomplex> + return %0 : tensor<8xcomplex> +} + +// CHECK-LABEL: rfft_1D_dynamic +func @rfft_1D_dynamic(%arg0: tensor) -> tensor<8xcomplex> { + %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) + // CHECK: "tf.RFFT" + %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor, tensor<1xi32>) -> tensor<8xcomplex> + return %0 : tensor<8xcomplex> +} + //===----------------------------------------------------------------------===// // Shape op legalization. //===----------------------------------------------------------------------===// @@ -1881,7 +2009,7 @@ func @abs_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { // CHECK-LABEL: @acos // CHLO-LABEL: @acos func @acos(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: "chlo.acos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: chlo.acos %arg0 : tensor<2xf32> // CHLO: %[[VAL_1:.*]] = "mhlo.compare"({{.*}}) {comparison_direction = "NE"} // CHLO: %[[VAL_5:.*]] = mhlo.multiply %arg0, %arg0 // CHLO: %[[VAL_4:.*]] = mhlo.constant dense<1.000000e+00> @@ -1902,24 +2030,41 @@ func @acos(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK-LABEL: @acos_dynamic // CHLO-LABEL: @acos_dynamic func @acos_dynamic(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "chlo.acos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> -// CHLO: %[[VAL_1:.*]] = "mhlo.compare"({{.*}}) {comparison_direction = "NE"} -// CHLO: %[[VAL_5:.*]] = mhlo.multiply %arg0, %arg0 -// CHLO: %[[VAL_4:.*]] = "chlo.constant_like"(%arg0) {value = 1.000000e+00 : f32} -// CHLO: %[[VAL_6:.*]] = mhlo.subtract %[[VAL_4]], %[[VAL_5]] -// CHLO: %[[VAL_7:.*]] = "mhlo.sqrt"(%[[VAL_6]]) -// CHLO: %[[VAL_8:.*]] = "chlo.constant_like"(%arg0) {value = 1.000000e+00 : f32} -// CHLO: %[[VAL_9:.*]] = mhlo.add %[[VAL_8]], %arg0 -// CHLO: %[[VAL_10:.*]] = mhlo.atan2 %[[VAL_7]], %[[VAL_9]] -// CHLO: %[[VAL_3:.*]] = "chlo.constant_like"(%arg0) {value = 2.000000e+00 : f32} -// CHLO: %[[VAL_11:.*]] = mhlo.multiply %[[VAL_3]], %[[VAL_10]] -// CHLO: %[[VAL_12:.*]] = "chlo.constant_like"(%arg0) {value = 3.14159274 : f32} -// CHLO: %[[VAL_13:.*]] = "mhlo.select"(%[[VAL_1]], %[[VAL_11]], %[[VAL_12]]) -// CHLO: return %[[VAL_13]] + // CHECK: chlo.acos %arg0 : tensor<*xf32> + // `tf.Acos` is lowered to `chlo.constant_like` operations which can only be + // lowered further on ranked tensors. Unranked CHLO must be transformed to + // ranked code before further lowering. 
+ // CHLO: "tf.Acos" %0 = "tf.Acos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } +// CHECK-LABEL: @tan +// CHECK-SAME: (%[[ARG:.*]]: tensor<2xf32>) -> tensor<2xf32> +// CHLO-LABEL: @tan +// CHLO-SAME: (%[[ARG:.*]]: tensor<2xf32>) -> tensor<2xf32> +func @tan(%arg : tensor<2xf32>) -> tensor<2xf32> { + // CHECK: chlo.tan %[[ARG]] : tensor<2xf32> + // CHLO: %[[SINE:.*]] = "mhlo.sine"(%[[ARG]]) + // CHLO %[[COSINE:.*]] = "mhlo.cosine"(%[[ARG]]) + // CHLO %[[RESULT:.*]] = "mhlo.divide"(%[[SINE]], %[[COSINE]]) + %result = "tf.Tan"(%arg) : (tensor<2xf32>) -> tensor<2xf32> + return %result : tensor<2xf32> +} + +// CHECK-LABEL: @tan_unranked +// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> tensor<*xf32> +// CHLO-LABEL: @tan_unranked +// CHLO-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> tensor<*xf32> +func @tan_unranked(%arg : tensor<*xf32>) -> tensor<*xf32> { + // CHECK: chlo.tan %[[ARG]] : tensor<*xf32> + // CHLO: %[[SINE:.*]] = "mhlo.sine"(%[[ARG]]) + // CHLO %[[COSINE:.*]] = "mhlo.cosine"(%[[ARG]]) + // CHLO %[[RESULT:.*]] = "mhlo.divide"(%[[SINE]], %[[COSINE]]) + %result = "tf.Tan"(%arg) : (tensor<*xf32>) -> tensor<*xf32> + return %result : tensor<*xf32> +} + // CHECK-LABEL: func @cast_dynamic_i2f func @cast_dynamic_i2f(%arg0: tensor) -> tensor { // CHECK: "mhlo.convert"(%arg0) : (tensor) -> tensor @@ -2032,6 +2177,13 @@ func @floor_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { return %0 : tensor<*xf32> } +// CHECK-LABEL: func @invert_op_unranked +func @invert_op_unranked(%arg0: tensor<*xi32>) -> tensor<*xi32> { + // CHECK: "mhlo.not"(%arg0) : (tensor<*xi32>) -> tensor<*xi32> + %0 = "tf.Invert"(%arg0) : (tensor<*xi32>) -> tensor<*xi32> + return %0 : tensor<*xi32> +} + // CHECK-LABEL: @is_finite func @is_finite(%arg0: tensor<2xf32>) -> tensor<2xi1> { // CHECK: "mhlo.is_finite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1> @@ -2316,10 +2468,10 @@ func @reshape(%arg0: tensor<2xf32>, %arg1: tensor<2xi32>) -> tensor<2x1xf32> { } // CHECK-LABEL: reshape_dynamic -func @reshape_dynamic(%arg0: tensor, %arg1: tensor<2xi32>) -> tensor<1x1xf32> { - // CHECK: "mhlo.reshape" - %0 = "tf.Reshape"(%arg0, %arg1) : (tensor, tensor<2xi32>) -> tensor<1x1xf32> - return %0 : tensor<1x1xf32> +func @reshape_dynamic(%arg0: tensor, %arg1: tensor<2xi32>) -> tensor { + // CHECK: "mhlo.dynamic_reshape" + %0 = "tf.Reshape"(%arg0, %arg1) : (tensor, tensor<2xi32>) -> tensor + return %0 : tensor } // CHECK-LABEL: reshape_unranked @@ -2350,6 +2502,25 @@ func @expand_dims(%arg0: tensor<2xf32>, %axis: tensor) -> tensor<1x2xf32> { return %0 : tensor<1x2xf32> } +// CHECK-LABEL: expand_dims_dynamic +func @expand_dims_dynamic(%arg0: tensor) -> tensor { + %axis = "tf.Const"() {value = dense<1> : tensor} : () -> (tensor) + + // CHECK-DAG: [[SHAPEOF:%.+]] = shape.shape_of %arg0 + // CHECK-DAG: [[CST0:%.+]] = constant 0 + // CHECK-DAG: [[CST1:%.+]] = constant 1 + // CHECK-DAG: [[GETEXTENT0:%.+]] = shape.get_extent [[SHAPEOF]], [[CST0]] + // CHECK-DAG: [[CST1_0:%.+]] = constant 1 + // CHECK-DAG: [[GETEXTENT1:%.+]] = shape.get_extent [[SHAPEOF]], [[CST1_0]] + // CHECK-DAG: [[FROMEXTENTS:%.+]] = shape.from_extents [[GETEXTENT0]], [[CST1]], [[GETEXTENT1]] + // CHECK-DAG: [[TOEXTENTS:%.+]] = shape.to_extent_tensor [[FROMEXTENTS]] + // CHECK-DAG: [[RESHAPE:%.+]] = "mhlo.dynamic_reshape"(%arg0, [[TOEXTENTS]]) + %0 = "tf.ExpandDims"(%arg0, %axis) : (tensor, tensor) -> tensor + + // CHECK: return [[RESHAPE]] + return %0 : tensor +} + // CHECK-LABEL: func @sign // CHECK-SAME: [[ARG:%arg.*]]: tensor<1x2x3x4xf32> func @sign(%arg0: 
tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { @@ -2942,6 +3113,15 @@ func @max(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { return %0 : tensor<4x1xf16> } +// CHECK-LABEL: func @max_qint +// Regression test to ensure we don't crash getting the initial value for +// tf.Max when using quantized integer types. +func @max_qint(%arg0: tensor<4x8x!tf.qint8>) -> tensor<4x1x!tf.qint8> { + %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> + %0 = "tf.Max"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8x!tf.qint8>, tensor<1xi64>) -> tensor<4x1x!tf.qint8> + return %0 : tensor<4x1x!tf.qint8> +} + // CHECK-LABEL: func @max_dynamic func @max_dynamic(%arg0: tensor<4x?xf16>) -> tensor<4x1xf16> { // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x?xf16>) -> tensor<4x?xf16> @@ -2976,6 +3156,15 @@ func @min(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { return %0 : tensor<4x1xf16> } +// CHECK-LABEL: func @min_qint +// Regression test to ensure we don't crash getting the initial value for +// tf.Min when using quantized integer types. +func @min_qint(%arg0: tensor<4x8x!tf.qint8>) -> tensor<4x1x!tf.qint8> { + %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> + %0 = "tf.Min"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8x!tf.qint8>, tensor<1xi64>) -> tensor<4x1x!tf.qint8> + return %0 : tensor<4x1x!tf.qint8> +} + // CHECK-LABEL: func @prod func @prod(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf32> @@ -2993,6 +3182,15 @@ func @prod(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { return %0 : tensor<4x1xf16> } +// CHECK-LABEL: func @prod_qint +// Regression test to ensure we don't crash getting the initial value for +// tf.Prod when using quantized integer types. 
+func @prod_qint(%arg0: tensor<4x8x!tf.qint8>) -> tensor<4x1x!tf.qint8> { + %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> + %0 = "tf.Prod"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8x!tf.qint8>, tensor<1xi64>) -> tensor<4x1x!tf.qint8> + return %0 : tensor<4x1x!tf.qint8> +} + // CHECK-LABEL: @all func @all(%input: tensor<4x8xi1>) -> tensor<4xi1> { %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> @@ -3685,15 +3883,13 @@ func @topk_v2(%input: tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) %k = "tf.Const"() {value = dense<8> : tensor} : () -> tensor // CHECK: %[[IOTA:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} - // CHECK-NEXT: %[[SORT:.*]] = "mhlo.sort"(%[[INPUT]], %[[IOTA]]) ( { + // CHECK-NEXT: %[[SORT:.*]]:2 = "mhlo.sort"(%[[INPUT]], %[[IOTA]]) ( { // CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor, %[[RHS:.*]]: tensor, %{{.*}}: tensor, %{{.*}}: tensor): // CHECK-NEXT: %[[CMP:.*]] = "mhlo.compare"(%[[LHS]], %[[RHS]]) {comparison_direction = "GT"} // CHECK-NEXT: "mhlo.return"(%[[CMP]]) - // CHECK-NEXT: }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> - // CHECK-NEXT: %[[TUPL0:.*]] = "mhlo.get_tuple_element"(%[[SORT]]) {index = 0 : i32} - // CHECK-NEXT: %[[TUPL1:.*]] = "mhlo.get_tuple_element"(%[[SORT]]) {index = 1 : i32} - // CHECK-NEXT: %[[VAL:.*]] = "mhlo.slice"(%[[TUPL0]]) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} - // CHECK-NEXT: %[[IDX:.*]] = "mhlo.slice"(%[[TUPL1]]) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + // CHECK-NEXT: }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) + // CHECK-NEXT: %[[VAL:.*]] = "mhlo.slice"(%[[SORT]]#0) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + // CHECK-NEXT: %[[IDX:.*]] = "mhlo.slice"(%[[SORT]]#1) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} // CHECK-NEXT: return %[[VAL]], %[[IDX]] %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor) -> (tensor<16x8xf32>, tensor<16x8xi32>) return %0#0, %0#1: tensor<16x8xf32>, tensor<16x8xi32> @@ -4060,12 +4256,11 @@ func @random_shuffle_1D_16(%input: tensor<16xf32>) -> tensor<16xf32> { // CHECK: [[LOWER:%.*]] = mhlo.constant dense<0> : tensor // CHECK: [[UPPER:%.*]] = mhlo.constant dense<-1> : tensor // CHECK: [[RNG:%.*]] = "mhlo.rng_uniform"([[LOWER]], [[UPPER]], [[SHAPE]]) - // CHECK: [[SORT:%.*]] = "mhlo.sort"([[RNG]], [[INPUT]]) ( { + // CHECK: [[SORT:%.*]]:2 = "mhlo.sort"([[RNG]], [[INPUT]]) ( { // CHECK: ^{{.*}}([[ARG1:%.*]]: tensor, [[ARG2:%.*]]: tensor, {{.*}}: tensor, {{.*}}: tensor): // CHECK: "mhlo.compare"([[ARG1]], [[ARG2]]) {comparison_direction = "LT"} - // CHECK: }) {dimension = -1 : i64, is_stable = true} : (tensor<16xi32>, tensor<16xf32>) -> tuple, tensor<16xf32>> - // CHECK: [[RES:%.*]] = "mhlo.get_tuple_element"([[SORT]]) {index = 1 : i32} - // CHECK: return [[RES]] + // CHECK: }) {dimension = -1 : i64, is_stable = true} : (tensor<16xi32>, tensor<16xf32>) -> (tensor<16xi32>, tensor<16xf32>) + // CHECK: return [[SORT]]#1 %0 = "tf.RandomShuffle"(%input) : (tensor<16xf32>) -> (tensor<16xf32>) return %0: tensor<16xf32> } @@ -4074,10 +4269,8 @@ func 
@random_shuffle_1D_16(%input: tensor<16xf32>) -> tensor<16xf32> { func @random_shuffle_1D_10240(%input: tensor<10240xf32>) -> tensor<10240xf32> { // CHECK: mhlo.rng_uniform // CHECK: mhlo.sort - // CHECK: mhlo.get_tuple_element // CHECK: mhlo.rng_uniform // CHECK: mhlo.sort - // CHECK: mhlo.get_tuple_element %0 = "tf.RandomShuffle"(%input) : (tensor<10240xf32>) -> (tensor<10240xf32>) return %0: tensor<10240xf32> } @@ -4859,3 +5052,20 @@ func @xla_gather_i32(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> %0 = "tf.XlaGather"(%arg0, %arg1, %cst) {dimension_numbers = "\0A\01\01\12\01\00\1A\01\00 \01", indices_are_sorted = true} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<3xi32>) -> tensor<10x1x300xf32> return %0 : tensor<10x1x300xf32> } + + +// CHECK: func @stridedslice_with_i32 +func @stridedslice_with_i32(%arg0: tensor) -> tensor<4xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "const_0_arg", outputs = "identity_0_retval_RetVal"}} { +// CHECK-NOT: tf.StridedSlice +// CHECK: [[DYNSLICE:%.*]] = "mhlo.dynamic-slice +// CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"([[DYNSLICE]]) +// CHECK: return [[RESHAPE]] + %0 = "tf.Const"() {value = dense<[[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00, 7.000000e+00]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32> + %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %3 = "tf.AddV2"(%arg0, %1) {_xla_inferred_shapes = [#tf.shape<>], device = ""} : (tensor, tensor) -> tensor + %4 = "tf.Pack"(%3) {_xla_inferred_shapes = [#tf.shape<1>], axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %5 = "tf.Pack"(%arg0) {_xla_inferred_shapes = [#tf.shape<1>], axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %6 = "tf.StridedSlice"(%0, %5, %4, %2) {_xla_inferred_shapes = [#tf.shape<4>], begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2x4xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xf32> + return %6 : tensor<4xf32> +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/BUILD b/tensorflow/compiler/mlir/xla/tests/translate/BUILD index c4e747c90f3..7dc66edd9e1 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/BUILD +++ b/tensorflow/compiler/mlir/xla/tests/translate/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow:tensorflow.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package(licenses = ["notice"]) diff --git a/tensorflow/compiler/mlir/xla/tests/translate/case.mlir b/tensorflow/compiler/mlir/xla/tests/translate/case.mlir index 1032bb723c5..cea0599adb0 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/case.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/case.mlir @@ -1,10 +1,10 @@ // RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FILECHECK_OPTS="" FileCheck %s func @main() -> tensor { - %cst = constant {name = "constant"} dense<1> : tensor - %cst_0 = constant {name = "constant.1"} dense<5.600000e+01> : tensor - %cst_1 = constant {name = "constant.2"} dense<1.200000e+01> : tensor - %cst_2 = constant {name = "constant.3"} dense<1.300000e+01> : tensor + %cst = constant dense<1> : tensor + %cst_0 = constant dense<5.600000e+01> : tensor + %cst_1 = constant dense<1.200000e+01> : tensor + %cst_2 = constant dense<1.300000e+01> : tensor %0 = "mhlo.case"(%cst, %cst_0, %cst_1, %cst_2) ( { ^bb0(%arg0: 
tensor): %1 = "mhlo.negate"(%arg0) : (tensor) -> tensor @@ -17,7 +17,7 @@ func @main() -> tensor { ^bb0(%arg0: tensor): %1 = "mhlo.floor"(%arg0) : (tensor) -> tensor "mhlo.return"(%1) : (tensor) -> () - }) {name = "conditional"} : (tensor, tensor, tensor, tensor) -> tensor + }) : (tensor, tensor, tensor, tensor) -> tensor return %0 : tensor } @@ -48,23 +48,23 @@ func @main() -> tensor { // ----- func @main() -> (tensor, tensor) { - %cst = constant {name = "constant"} dense<1> : tensor - %cst_0 = constant {name = "constant.1"} dense<5.600000e+01> : tensor - %cst_1 = constant {name = "constant.2"} dense<1.200000e+01> : tensor - %cst_2 = constant {name = "constant.3"} dense<1.300000e+01> : tensor + %cst = constant dense<1> : tensor + %cst_0 = constant dense<5.600000e+01> : tensor + %cst_1 = constant dense<1.200000e+01> : tensor + %cst_2 = constant dense<1.300000e+01> : tensor %0:2 = "mhlo.case"(%cst, %cst_0, %cst_1, %cst_2) ( { ^bb0(%arg0: tensor): - %1 = "mhlo.negate"(%arg0) {name = "negate"} : (tensor) -> tensor + %1 = "mhlo.negate"(%arg0) : (tensor) -> tensor "mhlo.return"(%1, %1) : (tensor, tensor) -> () }, { ^bb0(%arg0: tensor): - %1 = "mhlo.copy"(%arg0) {name = "copy"} : (tensor) -> tensor + %1 = "mhlo.copy"(%arg0) : (tensor) -> tensor "mhlo.return"(%1, %1) : (tensor, tensor) -> () }, { ^bb0(%arg0: tensor): - %1 = "mhlo.floor"(%arg0) {name = "floor"} : (tensor) -> tensor + %1 = "mhlo.floor"(%arg0) : (tensor) -> tensor "mhlo.return"(%1, %1) : (tensor, tensor) -> () - }) {name = "conditional"} : (tensor, tensor, tensor, tensor) -> (tensor, tensor) + }) : (tensor, tensor, tensor, tensor) -> (tensor, tensor) return %0#0, %0#1 : tensor, tensor } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/case_conditional.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/case_conditional.hlotxt index 62f0d7a59e4..1fa7367763e 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/case_conditional.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/case_conditional.hlotxt @@ -26,21 +26,21 @@ ENTRY %indexed_conditional () -> f32[] { } // CHECK-LABEL: func @main() -> tensor -// CHECK: %[[INDEX:.*]] = constant {name = "constant"} dense<1> : tensor -// CHECK: %[[OPERAND_1:.*]] = constant {name = "{{.*}}"} dense<5.600000e+01> : tensor -// CHECK: %[[OPERAND_2:.*]] = constant {name = "{{.*}}"} dense<1.200000e+01> : tensor -// CHECK: %[[OPERAND_3:.*]] = constant {name = "{{.*}}"} dense<1.300000e+01> : tensor +// CHECK: %[[INDEX:.*]] = constant dense<1> : tensor +// CHECK: %[[OPERAND_1:.*]] = constant dense<5.600000e+01> : tensor +// CHECK: %[[OPERAND_2:.*]] = constant dense<1.200000e+01> : tensor +// CHECK: %[[OPERAND_3:.*]] = constant dense<1.300000e+01> : tensor // CHECK: %[[RESULT:.*]] = "mhlo.case"(%[[INDEX]], %[[OPERAND_1]], %[[OPERAND_2]], %[[OPERAND_3]]) ( { // CHECK: ^bb0(%[[ARG_1:.*]]: tensor): -// CHECK: %[[RES_1:.*]] = "mhlo.negate"(%[[ARG_1]]) {name = "{{.*}}"} : (tensor) -> tensor +// CHECK: %[[RES_1:.*]] = "mhlo.negate"(%[[ARG_1]]) : (tensor) -> tensor // CHECK: "mhlo.return"(%[[RES_1]]) : (tensor) -> () // CHECK: }, { // CHECK: ^bb0(%[[ARG_2:.*]]: tensor): -// CHECK: %[[RES_2:.*]] = "mhlo.copy"(%[[ARG_2]]) {name = "{{.*}}"} : (tensor) -> tensor +// CHECK: %[[RES_2:.*]] = "mhlo.copy"(%[[ARG_2]]) : (tensor) -> tensor // CHECK: "mhlo.return"(%[[RES_2]]) : (tensor) -> () // CHECK: }, { // CHECK: ^bb0(%[[ARG_3:.*]]: tensor): -// CHECK: %[[RES_3:.*]] = "mhlo.floor"(%[[ARG_3]]) {name = "{{.*}}"} : (tensor) -> tensor +// CHECK: %[[RES_3:.*]] = "mhlo.floor"(%[[ARG_3]]) : (tensor) -> 
tensor // CHECK: "mhlo.return"(%[[RES_3]]) : (tensor) -> () -// CHECK: }) {name = "{{.*}}"} : (tensor, tensor, tensor, tensor) -> tensor +// CHECK: }) : (tensor, tensor, tensor, tensor) -> tensor // CHECK: return %[[RESULT]] : tensor diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir index 316eda4c4aa..c078191d170 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir @@ -362,7 +362,9 @@ func @main(%arg0: tensor<2x3xf32>, %arg1: tensor<5x5xf32>) -> tensor<1x2x3xf32> // CHECK: [[VAL_1:%.*]] = f32[2,3] parameter(0) // CHECK: [[VAL_2:%.*]] = f32[5,5] parameter(1) // CHECK: ROOT -// CHECK-SAME: f32[1,2,3] custom-call(f32[2,3] [[VAL_1]], f32[5,5] [[VAL_2]]), custom_call_target="foo", backend_config="bar" +// CHECK-SAME: f32[1,2,3] custom-call(f32[2,3] [[VAL_1]], f32[5,5] [[VAL_2]]) +// CHECK-SAME: custom_call_target="foo" +// CHECK-SAME: backend_config="bar" // ----- @@ -437,7 +439,7 @@ func @main(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<10 // CHECK-SAME: index_vector_dim=1 // CHECK-SAME: slice_sizes={1,1,300} // CHECK-SAME: indices_are_sorted=true - %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<[0, 1]> : tensor<2xi64>, index_vector_dim = 1 : i64, offset_dims = dense<1> : tensor<1xi64>, start_index_map = dense<[0, 1]> : tensor<2xi64>}, indices_are_sorted = true, name = "gather", slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>} : (tensor<200x100x300xf32>, tensor<10x2xi32>) -> tensor<10x300xf32> + %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<[0, 1]> : tensor<2xi64>, index_vector_dim = 1 : i64, offset_dims = dense<1> : tensor<1xi64>, start_index_map = dense<[0, 1]> : tensor<2xi64>}, indices_are_sorted = true, slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>} : (tensor<200x100x300xf32>, tensor<10x2xi32>) -> tensor<10x300xf32> return %0 : tensor<10x300xf32> } @@ -500,7 +502,7 @@ func @main() -> tensor<1x10xf32> { func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { %0 = "mhlo.map"(%arg0, %arg1) ( { ^bb0(%arg2: tensor, %arg3: tensor): // no predecessors - %1 = mhlo.add %arg2, %arg3 {name = "add"} : tensor + %1 = mhlo.add %arg2, %arg3 : tensor "mhlo.return"(%1) : (tensor) -> () }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> @@ -737,7 +739,7 @@ func @main(%arg0: tensor, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> // CHECK: %[[ARG2:.*]] = s32[2,3] parameter(2) // CHECK: ROOT %[[RES:.*]] = s32[2,3] select(pred[2,3] %[[COND]], s32[2,3] %[[ARG1]], s32[2,3] %[[ARG2]]) - %0 = "mhlo.select"(%arg0, %arg1, %arg2) {name = "select.4"} : (tensor, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %0 : tensor<2x3xi32> } @@ -946,19 +948,20 @@ func @main(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { // CHECK: HloModule func @main(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { - %0 = "mhlo.sort"(%input0, %input1) ( { + %0:2 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () - }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, 
tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) return } // CHECK: %[[SORT_CMP:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[], {{.*}}: s32[], {{.*}}: s32[]) -> pred[] { // CHECK: ROOT %compare.8 = pred[] compare(f32[] %[[ARG0]], f32[] %[[ARG1]]), direction=GT -// CHECK: ENTRY %{{.*}} ([[MAIN_ARG0:.*]]: f32[16,16], [[MAIN_ARG1:.*]]: s32[16,16]) -> (f32[16,16], s32[16,16]) { -// CHECK: ROOT %{{.*}} = (f32[16,16], s32[16,16]) sort(f32[16,16] %[[MAIN_ARG0]], s32[16,16] %[[MAIN_ARG1]]), dimensions={1}, is_stable=true, to_apply=%[[SORT_CMP]] +// CHECK: [[SORT:%.+]] = (f32[16,16], s32[16,16]) sort(f32[16,16] %Arg_0.1, s32[16,16] %Arg_1.2), dimensions={1}, is_stable=true, to_apply=%[[SORT_CMP]] +// CHECK: [[GET0:%.+]] = f32[16,16] get-tuple-element((f32[16,16], s32[16,16]) [[SORT]]), index=0 +// CHECK: ROOT [[GET1:%.+]] = s32[16,16] get-tuple-element((f32[16,16], s32[16,16]) [[SORT]]), index=1 // ----- @@ -1099,3 +1102,33 @@ func @main(%arg: tensor<3xui64>) -> tuple, tensor<2x2xui32>> { %0 = "mhlo.rng_bit_generator"(%arg) {rng_algorithm = 2 : i32} : (tensor<3xui64>) -> tuple, tensor<2x2xui32>> return %0 : tuple, tensor<2x2xui32>> } + +// ----- + +// CHECK: HloModule +func @main(%arg: tensor<3x4xf32>) -> tensor<3x4xf32> { +// CHECK: %[[ARG0:.*]] = f32[3,4] parameter(0) +// CHECK: ROOT %[[RESULT:.*]] = f32[3,4] cbrt(f32[3,4] %[[ARG0]]) + %0 = "mhlo.cbrt"(%arg) : (tensor<3x4xf32>) -> tensor<3x4xf32> + return %0 : tensor<3x4xf32> +} + +// ----- + +// CHECK: HloModule +func @main(%arg: tensor<3x4xf32>) -> tensor<3x4xf32> { +// CHECK: %[[ARG0:.*]] = f32[3,4] parameter(0) +// CHECK: ROOT %[[RESULT:.*]] = f32[3,4] reduce-precision(f32[3,4] %[[ARG0]]), exponent_bits=8, mantissa_bits=10 + %0 = "mhlo.reduce_precision"(%arg) {exponent_bits = 8 : i32, mantissa_bits = 10 : i32} : (tensor<3x4xf32>) -> tensor<3x4xf32> + return %0 : tensor<3x4xf32> +} + +// ----- + +// CHECK: HloModule +func @main(%arg: tensor<3x4xf32>) -> tensor<3x4x1xf32> { +// CHECK: %[[ARG0:.*]] = f32[3,4] parameter(0) +// CHECK: ROOT %[[RESULT:.*]] = f32[3,4,1] bitcast(f32[3,4] %[[ARG0]]) + %0 = "mhlo.bitcast"(%arg) : (tensor<3x4xf32>) -> tensor<3x4x1xf32> + return %0 : tensor<3x4x1xf32> +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt index 86adcf0710f..4cc70be0965 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt @@ -9,95 +9,95 @@ ENTRY %tfcompile.48 { %arg0.1 = f32[1,300] parameter(0) %arg1.2 = f32[1,300,3,1] parameter(1) - // CHECK-NEXT: %0 = "mhlo.reshape"(%arg0) {name = "reshape.3"} : (tensor<1x300xf32>) -> tensor<1x300xf32> + // CHECK-NEXT: %0 = "mhlo.reshape"(%arg0) : (tensor<1x300xf32>) -> tensor<1x300xf32> %reshape.3 = f32[1,300] reshape(%arg0.1) - // CHECK-NEXT: %1 = "mhlo.transpose"(%0) {name = "transpose.27", permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x300xf32>) -> tensor<300x1xf32> + // CHECK-NEXT: %1 = "mhlo.transpose"(%0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x300xf32>) -> tensor<300x1xf32> %transpose.27 = f32[300,1] transpose(%reshape.3), dimensions={1,0} - // CHECK-NEXT: %2 = "mhlo.reshape"(%1) {name = "reshape.28"} : (tensor<300x1xf32>) -> tensor<300x1x1xf32> + // CHECK-NEXT: %2 = "mhlo.reshape"(%1) 
: (tensor<300x1xf32>) -> tensor<300x1x1xf32> %reshape.28 = f32[300,1,1] reshape(%transpose.27) - // CHECK-NEXT: %3 = "mhlo.reshape"(%2) {name = "reshape.29"} : (tensor<300x1x1xf32>) -> tensor<300x1xf32> + // CHECK-NEXT: %3 = "mhlo.reshape"(%2) : (tensor<300x1x1xf32>) -> tensor<300x1xf32> %reshape.29 = f32[300,1] reshape(%reshape.28) - // CHECK-NEXT: %4 = "mhlo.broadcast_in_dim"(%3) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, name = "broadcast.30"} : (tensor<300x1xf32>) -> tensor<300x1x5xf32> + // CHECK-NEXT: %4 = "mhlo.broadcast_in_dim"(%3) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<300x1xf32>) -> tensor<300x1x5xf32> %broadcast.30 = f32[300,1,5] broadcast(%reshape.29), dimensions={0,1} - // CHECK-NEXT: %cst = constant {name = "constant.8"} dense<1.000000e+00> : tensor + // CHECK-NEXT: %cst = constant dense<1.000000e+00> : tensor %constant.8 = f32[] constant(1) - // CHECK-NEXT: %5 = "mhlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<> : tensor<0xi64>, name = "broadcast.9"} : (tensor) -> tensor<300x1x5xf32> + // CHECK-NEXT: %5 = "mhlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<300x1x5xf32> %broadcast.9 = f32[300,1,5] broadcast(%constant.8), dimensions={} - // CHECK-NEXT: %6 = mhlo.multiply %4, %5 {name = "multiply.31"} : tensor<300x1x5xf32> + // CHECK-NEXT: %6 = mhlo.multiply %4, %5 : tensor<300x1x5xf32> %multiply.31 = f32[300,1,5] multiply(%broadcast.30, %broadcast.9) - // CHECK-NEXT: %cst_0 = constant {name = "constant.32"} dense<0.000000e+00> : tensor + // CHECK-NEXT: %cst_0 = constant dense<0.000000e+00> : tensor %constant.32 = f32[] constant(0) - // CHECK-NEXT: %7 = "mhlo.broadcast_in_dim"(%cst_0) {broadcast_dimensions = dense<> : tensor<0xi64>, name = "broadcast.33"} : (tensor) -> tensor<300x1x5xf32> + // CHECK-NEXT: %7 = "mhlo.broadcast_in_dim"(%cst_0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<300x1x5xf32> %broadcast.33 = f32[300,1,5] broadcast(%constant.32), dimensions={} - // CHECK-NEXT: %8 = "mhlo.compare"(%6, %7) {comparison_direction = "GT", name = "compare.34"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xi1> + // CHECK-NEXT: %8 = "mhlo.compare"(%6, %7) {comparison_direction = "GT"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xi1> %compare.34 = pred[300,1,5] compare(%multiply.31, %broadcast.33), direction=GT - // CHECK-NEXT: %cst_1 = constant {name = "constant.10"} dense<0.000000e+00> : tensor + // CHECK-NEXT: %cst_1 = constant dense<0.000000e+00> : tensor %constant.10 = f32[] constant(0) - // CHECK-NEXT: %9 = "mhlo.broadcast_in_dim"(%cst_1) {broadcast_dimensions = dense<> : tensor<0xi64>, name = "broadcast.11"} : (tensor) -> tensor<300x1x5xf32> + // CHECK-NEXT: %9 = "mhlo.broadcast_in_dim"(%cst_1) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<300x1x5xf32> %broadcast.11 = f32[300,1,5] broadcast(%constant.10), dimensions={} - // CHECK-NEXT: %cst_2 = constant {name = "constant.40"} dense<0.000000e+00> : tensor + // CHECK-NEXT: %cst_2 = constant dense<0.000000e+00> : tensor %constant.40 = f32[] constant(0) - // CHECK-NEXT: %10 = "mhlo.broadcast_in_dim"(%cst_2) {broadcast_dimensions = dense<> : tensor<0xi64>, name = "broadcast.41"} : (tensor) -> tensor<300x5xf32> + // CHECK-NEXT: %10 = "mhlo.broadcast_in_dim"(%cst_2) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<300x5xf32> %broadcast.41 = f32[300,5] broadcast(%constant.40), dimensions={} - // CHECK-NEXT: %11 = "mhlo.copy"(%arg1) 
{name = "copy.1"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32> + // CHECK-NEXT: %11 = "mhlo.copy"(%arg1) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32> %copy.1 = f32[1,300,3,1] copy(%arg1.2) - // CHECK-NEXT: %12 = "mhlo.reshape"(%11) {name = "reshape.4"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32> + // CHECK-NEXT: %12 = "mhlo.reshape"(%11) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32> %reshape.4 = f32[1,300,3,1] reshape(%copy.1) - // CHECK-NEXT: %13 = "mhlo.reshape"(%12) {name = "reshape.24"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3xf32> + // CHECK-NEXT: %13 = "mhlo.reshape"(%12) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3xf32> %reshape.24 = f32[1,300,3] reshape(%reshape.4) - // CHECK-NEXT: %14 = "mhlo.transpose"(%13) {name = "transpose.25", permutation = dense<[1, 0, 2]> : tensor<3xi64>} : (tensor<1x300x3xf32>) -> tensor<300x1x3xf32> + // CHECK-NEXT: %14 = "mhlo.transpose"(%13) {permutation = dense<[1, 0, 2]> : tensor<3xi64>} : (tensor<1x300x3xf32>) -> tensor<300x1x3xf32> %transpose.25 = f32[300,1,3] transpose(%reshape.24), dimensions={1,0,2} - // CHECK-NEXT: %15 = "mhlo.reshape"(%14) {name = "reshape.26"} : (tensor<300x1x3xf32>) -> tensor<300x3xf32> + // CHECK-NEXT: %15 = "mhlo.reshape"(%14) : (tensor<300x1x3xf32>) -> tensor<300x3xf32> %reshape.26 = f32[300,3] reshape(%transpose.25) - // CHECK-NEXT: %cst_3 = constant {name = "constant.35"} dense<{{\[\[}}-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]> : tensor<3x5xf32> + // CHECK-NEXT: %cst_3 = constant dense<{{\[\[}}-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]> : tensor<3x5xf32> %constant.35 = f32[3,5] constant({ { -0.106023, 0.121505, 0.800239, -0.768885, 0.0966113 }, { 0.689014, -0.407056, -0.797853, 0.00378925, -0.208881 }, { -0.608529, 0.0276617, 0.268557, 0.577401, -0.428437 } }) // TODO(b/129709049) consider making this default precision config implied. 
- // CHECK-NEXT: %16 = "mhlo.dot"(%15, %cst_3) {name = "dot.36", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32> + // CHECK-NEXT: %16 = "mhlo.dot"(%15, %cst_3) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32> %dot.36 = f32[300,5] dot(%reshape.26, %constant.35), lhs_contracting_dims={1}, rhs_contracting_dims={0} - // CHECK-NEXT: %cst_4 = constant {name = "constant.37"} dense<0.000000e+00> : tensor<5xf32> + // CHECK-NEXT: %cst_4 = constant dense<0.000000e+00> : tensor<5xf32> %constant.37 = f32[5]{0} constant({0, 0, 0, 0, 0}) - // CHECK-NEXT: %17 = "mhlo.broadcast_in_dim"(%cst_4) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "broadcast.38"} : (tensor<5xf32>) -> tensor<300x5xf32> + // CHECK-NEXT: %17 = "mhlo.broadcast_in_dim"(%cst_4) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<5xf32>) -> tensor<300x5xf32> %broadcast.38 = f32[300,5] broadcast(%constant.37), dimensions={1} - // CHECK-NEXT: %18 = mhlo.add %16, %17 {name = "add.39"} : tensor<300x5xf32> + // CHECK-NEXT: %18 = mhlo.add %16, %17 : tensor<300x5xf32> %add.39 = f32[300,5] add(%dot.36, %broadcast.38) - // CHECK-NEXT: %19 = mhlo.maximum %10, %18 {name = "maximum.42"} : tensor<300x5xf32> + // CHECK-NEXT: %19 = mhlo.maximum %10, %18 : tensor<300x5xf32> %maximum.42 = f32[300,5] maximum(%broadcast.41, %add.39) - // CHECK-NEXT: %20 = "mhlo.reshape"(%19) {name = "reshape.44"} : (tensor<300x5xf32>) -> tensor<300x1x5xf32> + // CHECK-NEXT: %20 = "mhlo.reshape"(%19) : (tensor<300x5xf32>) -> tensor<300x1x5xf32> %reshape.44 = f32[300,1,5] reshape(%maximum.42) - // CHECK-NEXT: %21 = "mhlo.select"(%8, %9, %20) {name = "select.45"} : (tensor<300x1x5xi1>, tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32> + // CHECK-NEXT: %21 = "mhlo.select"(%8, %9, %20) : (tensor<300x1x5xi1>, tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32> %select.45 = f32[300,1,5] select(%compare.34, %broadcast.11, %reshape.44) - // CHECK-NEXT: %22 = "mhlo.reshape"(%21) {name = "reshape.46"} : (tensor<300x1x5xf32>) -> tensor<300x1x5xf32> + // CHECK-NEXT: %22 = "mhlo.reshape"(%21) : (tensor<300x1x5xf32>) -> tensor<300x1x5xf32> %reshape.46 = f32[300,1,5] reshape(%select.45) - // CHECK-NEXT: %23 = "mhlo.tuple"(%22) {name = "tuple.47"} : (tensor<300x1x5xf32>) -> tuple> + // CHECK-NEXT: %23 = "mhlo.tuple"(%22) : (tensor<300x1x5xf32>) -> tuple> // CHECK-NEXT: return %23 : tuple> ROOT %tuple.47 = (f32[300,1,5]) tuple(%reshape.46) } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt index 4d4e0213da8..cce49b16c6c 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt @@ -13,12 +13,12 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] { %Arg_0.1 = f32[4]{0} parameter(0) %Arg_1.2 = f32[4]{0} parameter(1) - // CHECK-NEXT: mhlo.add %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> - %add.3 = f32[4]{0} add(f32[4]{0} %Arg_0.1, f32[4]{0} %Arg_1.2) + // CHECK-NEXT: mhlo.add %arg0, %arg1 : tensor<4xf32> + %add.42 = f32[4]{0} add(f32[4]{0} %Arg_0.1, f32[4]{0} %Arg_1.2) // TODO(b/129709049) consider making this default precision config inferred. 
- // CHECK-NEXT: "mhlo.dot"(%0, %arg1) {name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor - ROOT %dot.4 = f32[] dot(f32[4]{0} %add.3, f32[4]{0} %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0} + // CHECK-NEXT: "mhlo.dot"(%0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor + ROOT %dot.4 = f32[] dot(f32[4]{0} %add.42, f32[4]{0} %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0} } // CHECK-LABEL: func @test_after_all @@ -26,7 +26,7 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] { %test_after_all (token0: token[], token1: token[] ) -> token[] { token0 = token[] parameter(0) token1 = token[] parameter(1) - // CHECK-NEXT: "mhlo.after_all"([[VAL_0]], [[VAL_1]]) {name = "{{.*}}"} : (!mhlo.token, !mhlo.token) -> !mhlo.token + // CHECK-NEXT: "mhlo.after_all"([[VAL_0]], [[VAL_1]]) : (!mhlo.token, !mhlo.token) -> !mhlo.token ROOT after-all = token[] after-all(token0, token1) } @@ -75,10 +75,10 @@ add { %test_broadcast_in_dim { %Arg_0.1 = f32[1, 2] parameter(0) - // CHECK-NEXT: "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, name = "{{.*}}"} : (tensor<1x2xf32>) -> tensor<1x2x3xf32> + // CHECK-NEXT: "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<1x2x3xf32> %broadcast.2 = f32[1,2,3] broadcast(%Arg_0.1), dimensions={0,1} - // CHECK-NEXT: "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>, name = "{{.*}}"} : (tensor<1x2xf32>) -> tensor<3x1x2xf32> + // CHECK-NEXT: "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<3x1x2xf32> ROOT broadcast.4 = f32[3,1,2] broadcast(%Arg_0.1), dimensions={1, 2} } @@ -113,7 +113,7 @@ add { // CHECK-SAME: ([[ARG:%.*]]: tensor<1x291x291xf32>) -> tensor<1x291x291xf32> %test_cholesky (a: f32[1,291,291]) -> f32[1,291,291] { %a = f32[1,291,291] parameter(0) - // CHECK-NEXT: "mhlo.cholesky"([[ARG]]) {lower = true, name = {{.*}}} : (tensor<1x291x291xf32>) -> tensor<1x291x291xf32> + // CHECK-NEXT: "mhlo.cholesky"([[ARG]]) {lower = true} : (tensor<1x291x291xf32>) -> tensor<1x291x291xf32> ROOT %out = f32[1,291,291] cholesky(f32[1,291,291] %a), lower=true } @@ -124,16 +124,16 @@ add { %Arg_1.2 = f32[4] parameter(1) %Arg_2.3 = f32[] parameter(2) - // CHECK-NEXT: "mhlo.clamp"(%arg0, %arg1, %arg2) {name = "{{.*}}"} : (tensor, tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK-NEXT: "mhlo.clamp"(%arg0, %arg1, %arg2) : (tensor, tensor<4xf32>, tensor) -> tensor<4xf32> ROOT %clamp.3 = f32[4] clamp(f32[] %Arg_0.1, f32[4] %Arg_1.2, f32[] %Arg_2.3) } // CHECK-LABEL: func @test_collective_permute // CHECK-SAME: ([[ARG:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> %test_collective_permute (input: f32[128,32]) -> f32[128,32] { - %input = f32[128,32]{0,1} parameter(0) - // CHECK-NEXT: "mhlo.collective_permute"([[ARG]]) {name = {{.*}}, source_target_pairs = dense<{{\[\[}}0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>} : (tensor<128x32xf32>) -> tensor<128x32xf32> - ROOT root = f32[128,32]{0,1} collective-permute(%input), source_target_pairs={{0,1},{1,2},{2,3}} + %input = f32[128,32]{1,0} parameter(0) + // CHECK-NEXT: "mhlo.collective_permute"([[ARG]]) {source_target_pairs = dense<{{\[\[}}0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>} : (tensor<128x32xf32>) -> tensor<128x32xf32> + ROOT root = f32[128,32]{1,0} collective-permute(%input), source_target_pairs={{0,1},{1,2},{2,3}} } @@ -143,14 
+143,14 @@ add { %Arg_1.2 = f32[3] parameter(1) %Arg_2.3 = f32[3] parameter(2) - // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> + // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> %compare.4 = pred[3] compare(Arg_0.1, Arg_1.2), direction=EQ - // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LE", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> + // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> %compare.5 = pred[3] compare(Arg_0.1, Arg_1.2), direction=LE // Requires broadcast of compatible tensors. - // CHECK-NEXT: "mhlo.compare"(%arg0, %arg2) {comparison_direction = "GT", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> + // CHECK-NEXT: "mhlo.compare"(%arg0, %arg2) {comparison_direction = "GT"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> ROOT %compare.6 = pred[3] compare(Arg_0.1, Arg_2.3), direction=GT } @@ -159,7 +159,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: "mhlo.complex"(%arg0, %arg1) {name = "{{.*}}"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex> + // CHECK-NEXT: "mhlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex> ROOT %complex.3 = c64[4] complex(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } @@ -176,12 +176,12 @@ add { %test_constant { // Scalar/0D tensor constant - // CHECK-NEXT: %cst = constant {name = "{{.*}}"} dense<1> : tensor + // CHECK-NEXT: %cst = constant dense<1> : tensor %constant.0 = s64[] constant(1) // Note that double brackets "[[" have to be escaped as they denote variables // in FileCheck. The only way to do so is to drop into regex with "{{" - // CHECK-NEXT: constant {name = "{{.*}}"} dense<{{\[\[\[\[}}1.000000e+00]], {{\[\[}}2.000000e+00]]], {{\[\[\[}}3.000000e+00]], {{\[\[}}4.000000e+00]]]]> : tensor<2x2x1x1xf32> + // CHECK-NEXT: constant dense<{{\[\[\[\[}}1.000000e+00]], {{\[\[}}2.000000e+00]]], {{\[\[\[}}3.000000e+00]], {{\[\[}}4.000000e+00]]]]> : tensor<2x2x1x1xf32> %constant.1 = f32[2,2,1,1]{3,2,1,0} constant({{{{1.0}},{{2.0}}},{{{3.0}},{{4.0}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"} // CHECK: dense<[1, 2, 4, 8]> : tensor<4xui64> @@ -206,15 +206,15 @@ add { %test_conv { %arg0.1 = f32[256,32,32,6]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"} - // CHECK-NEXT: %0 = "mhlo.copy"(%arg0) {name = "{{.*}}"} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32> + // CHECK-NEXT: %0 = "mhlo.copy"(%arg0) {minor_to_major = dense<[2, 1, 3, 0]> : tensor<4xindex>} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32> %copy.1 = f32[256,32,32,6]{2,1,3,0} copy(%arg0.1), metadata={op_name="HLO_Args"} - // CHECK-NEXT: %1 = "mhlo.reshape"(%0) {name = "{{.*}}"} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32> + // CHECK-NEXT: %1 = "mhlo.reshape"(%0) {minor_to_major = dense<[2, 1, 3, 0]> : tensor<4xindex>} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32> %reshape.2 = f32[256,32,32,6]{2,1,3,0} reshape(%copy.1) // Note that double brackets "[[" have to be escaped as they denote variables // in FileCheck. 
The only way to do so is to drop into regex with "{{" - // CHECK-NEXT: %cst = constant {name = "{{.*}}"} dense<{{\[\[\[\[}}5.000000e-01]], {{\[\[}}-6.000000e-01]]], {{\[\[\[}}3.000000e-01]], {{\[\[}}-1.000000e-01]]]]> : tensor<2x2x1x1xf32> + // CHECK-NEXT: %cst = constant dense<{{\[\[\[\[}}5.000000e-01]], {{\[\[}}-6.000000e-01]]], {{\[\[\[}}3.000000e-01]], {{\[\[}}-1.000000e-01]]]]> : tensor<2x2x1x1xf32> %constant.3 = f32[2,2,1,1]{3,2,1,0} constant({{{{0.5}}, {{-0.6}}}, {{{0.3}}, {{-0.1}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"} // CHECK-NEXT: %2 = "mhlo.convolution"(%1, %cst) { @@ -241,10 +241,10 @@ add { %convolution.4 = f32[16,30,30,256]{2,1,3,0} convolution(%reshape.2, %constant.3), window={size=3x3 stride=4x5 pad=44_45x60_60 rhs_dilate=2x3}, dim_labels=b01f_01io->f01b, metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"} - // CHECK-NEXT: %3 = "mhlo.reshape"(%2) {name = "{{.*}}"} : (tensor<16x30x30x256xf32>) -> tensor<256x30x30x16xf32> + // CHECK-NEXT: %3 = "mhlo.reshape"(%2) : (tensor<16x30x30x256xf32>) -> tensor<256x30x30x16xf32> %reshape.5 = f32[256,30,30,16]{3,2,1,0} reshape(%convolution.4), metadata={op_name="HLO_Retvals"} - // CHECK-NEXT: "mhlo.tuple"(%3) {name = "{{.*}}"} : (tensor<256x30x30x16xf32>) -> tuple> + // CHECK-NEXT: "mhlo.tuple"(%3) : (tensor<256x30x30x16xf32>) -> tuple> ROOT %tuple.6 = (f32[256,30,30,16]{3,2,1,0}) tuple(%reshape.5), metadata={op_name="HLO_Retvals"} } @@ -263,10 +263,10 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: %0 = "mhlo.convert"(%arg0) {name = "{{.*}}"} : (tensor<4xf32>) -> tensor<4xf64> + // CHECK-NEXT: %0 = "mhlo.convert"(%arg0) : (tensor<4xf32>) -> tensor<4xf64> %convert.3 = f64[4] convert(f32[4] %Arg_0.1) - // CHECK-NEXT: %1 = "mhlo.convert"(%arg1) {name = "{{.*}}"} : (tensor<4xf32>) -> tensor<4xf64> + // CHECK-NEXT: %1 = "mhlo.convert"(%arg1) : (tensor<4xf32>) -> tensor<4xf64> %convert.4 = f64[4] convert(f32[4] %Arg_1.2) // CHECK-NEXT: mhlo.add %0, %1 @@ -277,7 +277,7 @@ add { %test_cosine (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] { %arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"} - // CHECK-NEXT: "mhlo.cosine"(%arg0) {name = "{{.*}}"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> + // CHECK-NEXT: "mhlo.cosine"(%arg0) : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> ROOT %cosine.3 = f32[1,16,16,3]{3,2,1,0} cosine(f32[1,16,16,3]{3,2,1,0} %arg0.1) } @@ -286,7 +286,7 @@ add { %test_custom_call (arg1: f32[2,3], arg2: f32[5,5]) -> f32[1,2,3] { %arg1 = f32[2,3] parameter(0) %arg2 = f32[5,5] parameter(1) -// CHECK: "mhlo.custom_call"([[ARG_0]], [[ARG_1]]) {backend_config = "bar", call_target_name = "foo", has_side_effect = true, name = {{.*}}} : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32> +// CHECK: "mhlo.custom_call"([[ARG_0]], [[ARG_1]]) {backend_config = "bar", call_target_name = "foo", has_side_effect = true, minor_to_major = {{.*}}} : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32> ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[2,3] %arg1, f32[5,5] %arg2), custom_call_target="foo", backend_config="bar", custom_call_has_side_effect=true } @@ -295,7 +295,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: mhlo.divide %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> + // CHECK-NEXT: mhlo.divide %arg0, %arg1 : tensor<4xf32> ROOT %divide.3 = f32[4] divide(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } @@ -304,17 +304,17 @@ add { %Arg_0.1 = f32[1, 
4] parameter(0) %Arg_1.2 = f32[4, 1] parameter(1) - // CHECK-NEXT: %0 = "mhlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["HIGH", "HIGHEST"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor + // CHECK-NEXT: %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["HIGH", "HIGHEST"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor dot.3 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={high,highest} - // CHECK-NEXT: %1 = "mhlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["HIGHEST", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor + // CHECK-NEXT: %1 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["HIGHEST", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor dot.4 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,default} - // CHECK-NEXT: %2 = "mhlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor + // CHECK-NEXT: %2 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor %dot.5 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={default,default} // TODO(b/129709049) consider making this default precision config inferred. - // CHECK-NEXT: "mhlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor + // CHECK-NEXT: "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor ROOT %dot.6 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} } @@ -325,17 +325,17 @@ add { %Arg_0.1 = f32[4, 1] parameter(0) %Arg_1.2 = f32[1, 4] parameter(1) - // CHECK-NEXT: [[R0:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["HIGH", "HIGHEST"]} + // CHECK-NEXT: [[R0:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["HIGH", "HIGHEST"]} dot.3 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1}, operand_precision={high,highest} - // CHECK-NEXT: [[R1:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["HIGHEST", "DEFAULT"]} + // CHECK-NEXT: [[R1:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["HIGHEST", "DEFAULT"]} dot.4 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1}, operand_precision={highest,default} - // CHECK-NEXT: [[R2:%.+]] 
= "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} + // CHECK-NEXT: [[R2:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} %dot.5 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1}, operand_precision={default,default} // TODO(b/129709049) consider making this default precision config inferred. - // CHECK-NEXT: "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} + // CHECK-NEXT: "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} ROOT %dot.6 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1} } @@ -376,7 +376,7 @@ add { %test_exponential (arg0.1: f32[16]) -> f32[16] { %arg0.1 = f32[16] parameter(0) - // CHECK-NEXT: "mhlo.exponential"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK-NEXT: "mhlo.exponential"(%arg0) : (tensor<16xf32>) -> tensor<16xf32> ROOT %exp.2 = f32[16] exponential(f32[16] %arg0.1) } @@ -384,7 +384,7 @@ add { %test_expm1 (arg0.1: f32[16]) -> f32[16] { %arg0.1 = f32[16] parameter(0) - // CHECK: "mhlo.exponential_minus_one"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK: "mhlo.exponential_minus_one"(%arg0) : (tensor<16xf32>) -> tensor<16xf32> ROOT %expm1.2 = f32[16] exponential-minus-one(f32[16] %arg0.1) } @@ -400,7 +400,7 @@ add { %test_floor (arg0.1: f32[16]) -> f32[16] { %arg0.1 = f32[16] parameter(0) - // CHECK-NEXT: "mhlo.floor"([[A0]]) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK-NEXT: "mhlo.floor"([[A0]]) : (tensor<16xf32>) -> tensor<16xf32> ROOT %floor.2 = f32[16] floor(f32[16] %arg0.1) } @@ -430,7 +430,7 @@ add { // CHECK-SAME: ([[ARG:%.*]]: tensor<4x2xf32>) %test_get_dimension_size (Arg_0.1: f32[4,2]) -> s32[] { %Arg_0.1 = f32[4,2] parameter(0) - // CHECK-NEXT: "mhlo.get_dimension_size"([[ARG]]) {dimension = 1 : i32, name = "{{.*}}"} : (tensor<4x2xf32>) -> tensor + // CHECK-NEXT: "mhlo.get_dimension_size"([[ARG]]) {dimension = 1 : i32} : (tensor<4x2xf32>) -> tensor ROOT %get-dimension-size.2 = s32[] get-dimension-size(f32[4,2] %Arg_0.1), dimensions={1} } @@ -438,7 +438,7 @@ add { %test_imag (Arg_0.1: c64[4]) -> f32[4] { %Arg_0.1 = c64[4] parameter(0) - // CHECK-NEXT: "mhlo.imag"(%arg0) {name = "{{.*}}"} : (tensor<4xcomplex>) -> tensor<4xf32> + // CHECK-NEXT: "mhlo.imag"(%arg0) : (tensor<4xcomplex>) -> tensor<4xf32> ROOT %imag.3 = f32[4] imag(c64[4] %Arg_0.1) } @@ -468,7 +468,7 @@ add { %test_log (arg0.1: f32[16]) -> f32[16] { %arg0.1 
= f32[16] parameter(0) - // CHECK-NEXT: "mhlo.log"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK-NEXT: "mhlo.log"(%arg0) : (tensor<16xf32>) -> tensor<16xf32> ROOT %log.2 = f32[16] log(f32[16] %arg0.1) } @@ -476,7 +476,7 @@ add { %test_log1p (arg0.1: f32[16]) -> f32[16] { %arg0.1 = f32[16] parameter(0) - // CHECK: "mhlo.log_plus_one"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK: "mhlo.log_plus_one"(%arg0) : (tensor<16xf32>) -> tensor<16xf32> ROOT %log1p.2 = f32[16] log-plus-one(f32[16] %arg0.1) } @@ -507,7 +507,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: mhlo.maximum %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> + // CHECK-NEXT: mhlo.maximum %arg0, %arg1 : tensor<4xf32> ROOT %maximum.3 = f32[4] maximum(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } @@ -516,7 +516,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: mhlo.minimum %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> + // CHECK-NEXT: mhlo.minimum %arg0, %arg1 : tensor<4xf32> ROOT %minimum.3 = f32[4] minimum(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } @@ -525,7 +525,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: %0 = mhlo.multiply %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> + // CHECK-NEXT: %0 = mhlo.multiply %arg0, %arg1 : tensor<4xf32> ROOT %multiply.3 = f32[4] multiply(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } @@ -533,7 +533,7 @@ add { %test_negate (arg0.1: f32[16]) -> f32[16] { %arg0.1 = f32[16] parameter(0) - // CHECK-NEXT: "mhlo.negate"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK-NEXT: "mhlo.negate"(%arg0) : (tensor<16xf32>) -> tensor<16xf32> ROOT %negate.2 = f32[16] negate(f32[16] %arg0.1) } @@ -541,7 +541,7 @@ add { %test_not (arg0.1: pred[16]) -> pred[16] { %arg0.1 = pred[16] parameter(0) - // CHECK: "mhlo.not"(%arg0) {name = "{{.*}}"} : (tensor<16xi1>) -> tensor<16xi1> + // CHECK: "mhlo.not"(%arg0) : (tensor<16xi1>) -> tensor<16xi1> ROOT %not.2 = pred[16] not(pred[16] %arg0.1) } @@ -595,7 +595,7 @@ add { %test_popcnt (arg0.1: s32[16]) -> s32[16] { %arg0.1 = s32[16] parameter(0) - // CHECK: "mhlo.popcnt"(%arg0) {name = "{{.*}}"} : (tensor<16xi32>) -> tensor<16xi32> + // CHECK: "mhlo.popcnt"(%arg0) : (tensor<16xi32>) -> tensor<16xi32> ROOT %popcnt.2 = s32[16] popcnt(s32[16] %arg0.1) } @@ -604,7 +604,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: mhlo.power %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> + // CHECK-NEXT: mhlo.power %arg0, %arg1 : tensor<4xf32> ROOT %power.3 = f32[4] power(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } @@ -632,7 +632,7 @@ add { %test_real (Arg_0.1: c64[4]) -> f32[4] { %Arg_0.1 = c64[4] parameter(0) - // CHECK-NEXT: "mhlo.real"(%arg0) {name = "{{.*}}"} : (tensor<4xcomplex>) -> tensor<4xf32> + // CHECK-NEXT: "mhlo.real"(%arg0) : (tensor<4xcomplex>) -> tensor<4xf32> ROOT %real.3 = f32[4] real(c64[4] %Arg_0.1) } @@ -687,7 +687,7 @@ add { // CHECK: {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor %reduce.4 = f32[] reduce(%reduce.2, %Arg_2.3), dimensions={0}, to_apply=%reduce_helper.3 - // CHECK: %4 = mhlo.subtract [[VAL2]], [[VAL4]] {name = "{{.*}}"} : tensor + // CHECK: %4 = mhlo.subtract [[VAL2]], [[VAL4]] : tensor %sub.5 = f32[] subtract(%reduce.3, %reduce.4) ROOT %tuple.6 = ((f32[], f32[]), f32[]) tuple(%reduce.1, %sub.5) @@ -741,7 +741,7 @@ add { %test_rsqrt (arg0.1: f32[16]) -> f32[16] { %arg0.1 = f32[16] parameter(0) - // CHECK: "mhlo.rsqrt"([[ARG0]]) 
{name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK: "mhlo.rsqrt"([[ARG0]]) : (tensor<16xf32>) -> tensor<16xf32> ROOT %rsqrt.2 = f32[16] rsqrt(f32[16] %arg0.1) } @@ -788,7 +788,7 @@ add { %Arg_1.2 = s32[2,3] parameter(1) %Arg_2.3 = s32[2,3] parameter(2) - // CHECK: "mhlo.select"(%arg0, %arg1, %arg2) {name = "{{.*}}"} : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + // CHECK: "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> ROOT %select.4 = s32[2,3] select(pred[2,3] %Arg_0.1, s32[2,3] %Arg_1.2, s32[2,3] %Arg_2.3) } @@ -835,7 +835,7 @@ add { %test_set_dimension_size (Arg_0.1: f32[4,4], Arg_1.2: s32[]) -> f32[4,<=4] { %Arg_0.1 = f32[4,4] parameter(0) %Arg_1.2 = s32[] parameter(1) - // CHECK-NEXT: "mhlo.set_dimension_size"([[ARG]], [[SIZE]]) {dimension = 1 : i32, name = "{{.*}}"} : (tensor<4x4xf32>, tensor) -> tensor<4x4xf32> + // CHECK-NEXT: "mhlo.set_dimension_size"([[ARG]], [[SIZE]]) {dimension = 1 : i32} : (tensor<4x4xf32>, tensor) -> tensor<4x4xf32> ROOT %set-dimension-size.2 = f32[4,<=4] set-dimension-size(f32[4,4] %Arg_0.1, s32[] %Arg_1.2), dimensions={1} } @@ -843,7 +843,7 @@ add { %test_sine (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] { %arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"} - // CHECK-NEXT: "mhlo.sine"(%arg0) {name = "{{.*}}"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> + // CHECK-NEXT: "mhlo.sine"(%arg0) : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> ROOT %sine.3 = f32[1,16,16,3]{3,2,1,0} sine(f32[1,16,16,3]{3,2,1,0} %arg0.1) } @@ -862,7 +862,7 @@ add { // CHECK-SAME: [[ARG:%.*]]: tensor<1024xf32>) -> tensor<1024xf32> // CHECK: "mhlo.sort"([[ARG]]) ( { // CHECK: ^bb0([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tensor): -// CHECK: [[CMP:%.*]] = "mhlo.compare"([[ARG0]], [[ARG1]]) {comparison_direction = "LT", name = "lt"} : (tensor, tensor) -> tensor +// CHECK: [[CMP:%.*]] = "mhlo.compare"([[ARG0]], [[ARG1]]) {comparison_direction = "LT"} : (tensor, tensor) -> tensor // CHECK: "mhlo.return"([[CMP]]) : (tensor) -> () // CHECK: }) {dimension = 0 : i64, is_stable = true} : (tensor<1024xf32>) -> tensor<1024xf32> @@ -871,7 +871,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: mhlo.subtract %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> + // CHECK-NEXT: mhlo.subtract %arg0, %arg1 : tensor<4xf32> ROOT %subtract.3 = f32[4] subtract(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } @@ -879,7 +879,7 @@ add { %test_tanh (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] { %arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"} - // CHECK-NEXT: "mhlo.tanh"(%arg0) {name = "{{.*}}"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> + // CHECK-NEXT: "mhlo.tanh"(%arg0) : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> ROOT %tanh.3 = f32[1,16,16,3]{3,2,1,0} tanh(f32[1,16,16,3]{3,2,1,0} %arg0.1), metadata={op_type="Tanh" op_name="embedded_inference/tanh_model/Tanh"} } @@ -887,7 +887,7 @@ add { %test_transpose { %Arg_0.1 = s32[1,2,3,4] parameter(0) - // CHECK: "mhlo.transpose"(%arg0) {name = "{{.*}}", permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> + // CHECK: "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> ROOT %transpose.2 = s32[2,1,4,3] transpose(s32[1,2,3,4] %Arg_0.1), dimensions={1,0,3,2} } @@ -909,10 +909,10 @@ add { %Arg_0.1 = s32[1] parameter(0) %Arg_1.2 = f32[1, 2] parameter(1) - // CHECK-NEXT: 
%0 = "mhlo.tuple"(%arg0) {name = "{{.*}}"} : (tensor<1xi32>) -> tuple> + // CHECK-NEXT: %0 = "mhlo.tuple"(%arg0) : (tensor<1xi32>) -> tuple> %tuple.3 = (s32[1]) tuple(%Arg_0.1) - // CHECK: "mhlo.tuple"(%arg0, %arg1) {name = "{{.*}}"} : (tensor<1xi32>, tensor<1x2xf32>) -> tuple, tensor<1x2xf32>> + // CHECK: "mhlo.tuple"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xf32>) -> tuple, tensor<1x2xf32>> ROOT %tuple.4 = (s32[1], f32[1,2]) tuple(%Arg_0.1, %Arg_1.2) } @@ -934,11 +934,11 @@ add { %arg0.1 = s64[] parameter(0), metadata={op_name="HLO_Args"} // CHECK-NEXT: "mhlo.while"(%arg0) ( { // CHECK-NEXT: ^bb0(%arg1: tensor): // no predecessors - // CHECK-NEXT: [[CMP:%.*]] = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "{{.*}}"} : (tensor, tensor) -> tensor + // CHECK-NEXT: [[CMP:%.*]] = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor // CHECK-NEXT: "mhlo.return"([[CMP]]) : (tensor) -> () // CHECK-NEXT: }, { // CHECK-NEXT: ^bb0(%arg1: tensor): // no predecessors - // CHECK-NEXT: [[ADD:%.*]] = mhlo.add %arg1, %arg1 {name = "{{.*}}"} : tensor + // CHECK-NEXT: [[ADD:%.*]] = mhlo.add %arg1, %arg1 : tensor // CHECK-NEXT: "mhlo.return"([[ADD]]) : (tensor) -> () // CHECK-NEXT: }) : (tensor) -> tensor ROOT %while.2 = s64[] while(%arg0.1), body=%loop, condition=%cond @@ -992,8 +992,8 @@ add { %Arg_1.2 = c128[2] parameter(1) %abs.4 = f64[2] abs(c128[2] %Arg_1.2) - // CHECK: "mhlo.abs"(%[[ARG0]]) {name = "{{.*}}"} : (tensor<2xcomplex>) -> tensor<2xf32> - // CHECK: "mhlo.abs"(%[[ARG1]]) {name = "{{.*}}"} : (tensor<2xcomplex>) -> tensor<2xf64> + // CHECK: "mhlo.abs"(%[[ARG0]]) : (tensor<2xcomplex>) -> tensor<2xf32> + // CHECK: "mhlo.abs"(%[[ARG1]]) : (tensor<2xcomplex>) -> tensor<2xf64> ROOT %tuple.5 = (f32[2], f64[2]) tuple(f32[2] %abs.3, f64[2] %abs.4) } @@ -1002,7 +1002,7 @@ add { %unsigned_int(Arg_0.1: u16[4]) -> u16[4] { %Arg_0.1 = u16[4] parameter(0) - // CHECK: "mhlo.not"(%[[ARG0]]) {name = "{{.*}}"} : (tensor<4xui16>) -> tensor<4xui16> + // CHECK: "mhlo.not"(%[[ARG0]]) : (tensor<4xui16>) -> tensor<4xui16> ROOT %not.2 = u16[4] not(u16[4] %Arg_0.1) } @@ -1014,3 +1014,26 @@ add { ROOT %rng-bit-generator.2 = (u64[3], u32[2,2]) rng-bit-generator(u64[3] %Arg_0.1), algorithm=rng_philox } +// CHECK-LABEL: func @cbrt +// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x4xf32>) +%cbrt (Arg_0.1: f32[3,4]) -> f32[3,4] { + %Arg_0.1 = f32[3,4] parameter(0) + // CHECK: "mhlo.cbrt"(%[[ARG0]]) : (tensor<3x4xf32>) -> tensor<3x4xf32> + ROOT %cbrt = f32[3,4] cbrt(f32[3,4] %Arg_0.1) +} + +// CHECK-LABEL: func @bitcast +// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x4xf32>) -> tensor<3x4x1xf32> +%bitcast (Arg_0.1: f32[3,4]) -> f32[3,4,1] { + %Arg_0.1 = f32[3,4] parameter(0) + // CHECK: "mhlo.bitcast"(%[[ARG0]]) : (tensor<3x4xf32>) -> tensor<3x4x1xf32> + ROOT %bitcast = f32[3,4,1] bitcast(f32[3,4] %Arg_0.1) +} + +// CHECK-LABEL: func @reduce_precision +// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x4xf32>) +%reduce_precision (Arg_0.1: f32[3,4]) -> f32[3,4] { + %Arg_0.1 = f32[3,4] parameter(0) + // CHECK: "mhlo.reduce_precision"(%[[ARG0]]) {exponent_bits = 8 : i32, mantissa_bits = 10 : i32} : (tensor<3x4xf32>) -> tensor<3x4xf32> + ROOT %reduce_precision = f32[3,4] reduce-precision(f32[3,4] %Arg_0.1), exponent_bits=8, mantissa_bits=10 +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/layouts_and_names.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/layouts_and_names.hlotxt new file mode 100644 index 00000000000..da07dc0a76b --- /dev/null +++ 
b/tensorflow/compiler/mlir/xla/tests/translate/layouts_and_names.hlotxt @@ -0,0 +1,11 @@ +// RUN: tf-mlir-translate -mlir-print-debuginfo -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule Test + +// CHECK-LABEL: func @main +ENTRY A { + %input = f16[128,224,224,4] parameter(0) + %filter = f16[64,7,7,4] parameter(1) + // %0 = "mhlo.convolution"{{.*}}minor_to_major = dense<[1, 3, 2, 0]> : tensor<4xindex>{{.*}} loc("root.42") + ROOT %root.42 = f16[128,64,112,112]{1,3,2,0} convolution(%input, %filter), dim_labels=b01f_o01i->bf01, window={size=7x7 stride=2x2 pad=3_3x3_3} +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/layouts_and_names.mlir b/tensorflow/compiler/mlir/xla/tests/translate/layouts_and_names.mlir new file mode 100644 index 00000000000..2ef0aaf3f50 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/layouts_and_names.mlir @@ -0,0 +1,34 @@ +// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text-with-layouts %s | FileCheck %s + +// Checks exporting layouts + +// CHECK: HloModule +func @main(%arg0: tensor<128x224x224x4xf16>, %arg1: tensor<64x7x7x4xf16>) -> tensor<128x64x112x112xf16> { + // CHECK: %convolution.{{.*}} = f16[128,64,112,112]{1,3,2,0} convolution{{.*}}op_name="root.42" + %0 = "mhlo.convolution"(%arg0, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = { + input_batch_dimension = 0 : i64, + input_feature_dimension = 3 : i64, + input_spatial_dimensions = dense<[ 1, 2 ]> : tensor<2xi64>, + kernel_input_feature_dimension = 3 : i64, + kernel_output_feature_dimension = 0 : i64, + kernel_spatial_dimensions = dense<[ 1, 2 ]> : tensor<2xi64>, + output_batch_dimension = 0 : i64, + output_feature_dimension = 1 : i64, + output_spatial_dimensions = dense<[ 2, 3 ]> : tensor<2xi64> + }, + feature_group_count = 1 : i64, + lhs_dilations = dense<1> : tensor<2xi64>, + minor_to_major = dense<[ 1, 3, 2, 0 ]> : tensor<4xindex>, + padding = dense<3> : tensor<2x2xi64>, + precision_config = [ "DEFAULT", "DEFAULT" ], + rhs_dilations = dense<1> : tensor<2xi64>, + window_strides = dense<2> : tensor<2xi64> + } : (tensor<128x224x224x4xf16>, tensor<64x7x7x4xf16>)-> tensor<128x64x112x112xf16> loc("root.42") + + // CHECK: s32[1,1]{0,1} constant({ {42} }) + %cst_1 = "std.constant"() {value = dense<[[42]]> : tensor<1x1xi32>, minor_to_major = dense<[0, 1]> : tensor<2xindex>} : () -> tensor<1x1xi32> + + return %0 : tensor<128x64x112x112xf16> +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/location_to_op_metadata.mlir b/tensorflow/compiler/mlir/xla/tests/translate/location_to_op_metadata.mlir new file mode 100644 index 00000000000..2182ce6106d --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/location_to_op_metadata.mlir @@ -0,0 +1,43 @@ +// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s --dump-input=always + +// CHECK-LABEL: %main +func @main(%arg0: !mhlo.token) -> !mhlo.token { + %0 = "mhlo.after_all"(%arg0) : (!mhlo.token) -> !mhlo.token loc(unknown) + return %0 : !mhlo.token +} + +// CHECK: after-all +// CHECK-NOT: metadata + +// ----- + +// CHECK-LABEL: %main +func @main(%arg0: !mhlo.token) -> !mhlo.token { + %0 = "mhlo.after_all"(%arg0) : (!mhlo.token) -> !mhlo.token loc("AfterAll") + return %0 : !mhlo.token +} + +// CHECK: after-all +// CHECK-SAME: metadata={op_name="AfterAll"} + +// ----- + +// CHECK-LABEL: %main +func @main(%arg0: !mhlo.token) -> !mhlo.token { + %0 = "mhlo.after_all"(%arg0) : (!mhlo.token) -> !mhlo.token loc("name@function") + return %0 : !mhlo.token +} + +// CHECK: after-all +// 
CHECK-SAME: metadata={op_name="name"} + +// ----- + +// CHECK-LABEL: %main +func @main(%arg0: !mhlo.token) -> !mhlo.token { + %0 = "mhlo.after_all"(%arg0) : (!mhlo.token) -> !mhlo.token loc("file_name":2:8) + return %0 : !mhlo.token +} + +// CHECK: after-all +// CHECK-SAME: metadata={source_file="file_name" source_line=2} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/simple.hlo b/tensorflow/compiler/mlir/xla/tests/translate/simple.hlo index d97c5150335..4c288aee956 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/simple.hlo +++ b/tensorflow/compiler/mlir/xla/tests/translate/simple.hlo @@ -139,8 +139,8 @@ dynamic_parameter_binding { } # CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor { -# CHECK-NEXT: %0 = mhlo.add %arg0, %arg1 {name = "add.3"} : tensor<4xf32> +# CHECK-NEXT: %0 = mhlo.add %arg0, %arg1 : tensor<4xf32> # TODO(b/129709049) consider making this default precision config inferred. -# CHECK-NEXT: %1 = "mhlo.dot"(%0, %arg1) {name = "dot.4", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor +# CHECK-NEXT: %1 = "mhlo.dot"(%0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor # CHECK-NEXT: return %1 : tensor # CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/types.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/types.hlotxt index 855b1c4bcd5..f7e1ba9ff15 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/types.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/types.hlotxt @@ -4,25 +4,25 @@ HloModule tfcompile.1 // CHECK-LABEL: func @main() -> tensor { ENTRY %tfcompile.1 { - // CHECK-NEXT: %cst = constant {name = "constant.0"} dense<1.000000e+00> : tensor + // CHECK-NEXT: %cst = constant dense<1.000000e+00> : tensor %constant.0 = f32[] constant(1) - // CHECK-NEXT: %cst_0 = constant {name = "constant.1"} dense<1.000000e+00> : tensor + // CHECK-NEXT: %cst_0 = constant dense<1.000000e+00> : tensor %constant.1 = f64[] constant(1) - // CHECK-NEXT: %cst_1 = constant {name = "constant.2"} dense<1> : tensor + // CHECK-NEXT: %cst_1 = constant dense<1> : tensor %constant.2 = s8[] constant(1) - // CHECK-NEXT: %cst_2 = constant {name = "constant.3"} dense<1> : tensor + // CHECK-NEXT: %cst_2 = constant dense<1> : tensor %constant.3 = s16[] constant(1) - // CHECK-NEXT: %cst_3 = constant {name = "constant.4"} dense<1> : tensor + // CHECK-NEXT: %cst_3 = constant dense<1> : tensor %constant.4 = s32[] constant(1) - // CHECK-NEXT: %cst_4 = constant {name = "constant.5"} dense<1> : tensor + // CHECK-NEXT: %cst_4 = constant dense<1> : tensor %constant.5 = s64[] constant(1) - // CHECK-NEXT: %cst_5 = constant {name = "constant.6"} dense : tensor + // CHECK-NEXT: %cst_5 = constant dense : tensor // CHECK-NEXT: return %cst_5 : tensor ROOT %constant.6 = pred[] constant(1) } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/while.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/while.hlotxt index 126bc88ec7a..f989104323a 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/while.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/while.hlotxt @@ -26,4 +26,4 @@ ENTRY %foo (arg0.1: s64[]) -> s64[] { // CHECK: "mhlo.return" // CHECK: }) : (tensor) -> tensor ROOT %while.2 = s64[] while(%arg0.1), body=%loop, condition=%cond -} \ No newline at end of file +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/while.mlir b/tensorflow/compiler/mlir/xla/tests/translate/while.mlir index 61d7aadb23f..f852ef06421 100644 --- 
a/tensorflow/compiler/mlir/xla/tests/translate/while.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/while.mlir @@ -10,11 +10,11 @@ module { // CHECK: %[[A0]] = s64[] parameter(0) // CHECK: ROOT %compare.7 = pred[] compare(s64[] %[[A0]], s64[] %[[A0]]), direction=LT ^bb0(%arg1: tensor): - %1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "compare.2"} : (tensor, tensor) -> tensor + %1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor "mhlo.return"(%1) : (tensor) -> () }, { ^bb0(%arg1: tensor): - %1 = mhlo.add %arg1, %arg1 {name = "compare.0"} : tensor + %1 = mhlo.add %arg1, %arg1 : tensor "mhlo.return"(%1) : (tensor) -> () }) : (tensor) -> tensor diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index 2c733bb5ca2..9c85242dca8 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -15,6 +15,7 @@ limitations under the License. // This file implements logic for lowering TensorFlow dialect to XLA dialect. +#include #include #include #include @@ -42,6 +43,7 @@ limitations under the License. #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" @@ -70,7 +72,8 @@ constexpr char kShardingAttr[] = "mhlo.sharding"; class LegalizeTF : public PassWrapper { void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry.insert(); } public: @@ -116,9 +119,9 @@ class LegalizeTF : public PassWrapper { static bool IsDefaultDataFormat(StringRef format) { return format == "NHWC"; } /// Returns the feature dimension for the given format and input type. -static size_t GetFeatureDimension(StringAttr format, +static size_t GetFeatureDimension(StringRef format, RankedTensorType inputType) { - return IsDefaultDataFormat(format.getValue()) ? inputType.getRank() - 1 : 1; + return IsDefaultDataFormat(format) ? inputType.getRank() - 1 : 1; } // Gets all integer values from the given attribute and push them to `values`. @@ -728,12 +731,33 @@ static void CreateWhile32(Location loc, int num_iterations, // BatchNorm op utilities. //===----------------------------------------------------------------------===// -static IntegerAttr getFeatureDimensionAttr(Builder &b, StringAttr format, +static IntegerAttr getFeatureDimensionAttr(Builder &b, StringRef format, Value input) { return b.getI64IntegerAttr( GetFeatureDimension(format, input.getType().cast())); } +//===----------------------------------------------------------------------===// +// FFT op utilities. +//===----------------------------------------------------------------------===// +// Returns the 1D i64 elements attribute populated with the inner-most dim of +// the value. +static DenseIntElementsAttr GetInnerDimFromValue(ShapedType type, + Builder *builder) { + if (type.getRank() == 0) { + return builder->getI64TensorAttr({}); + } + return builder->getI64TensorAttr(type.getShape().back()); +} + +// Returns True if the inner-most dim is static. 
+bool CheckInnerDimStatic(ShapedType type, Builder *builder) { + if (!type.hasRank()) { + return false; + } + return !type.isDynamicDim(type.getShape().size() - 1); +} + //===----------------------------------------------------------------------===// // MatMul op utilities. //===----------------------------------------------------------------------===// @@ -1104,7 +1128,7 @@ class ConvertBiasAddOp : public OpRewritePattern { PatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto feature_dim = GetFeatureDimension( - op.data_formatAttr(), op.value().getType().cast()); + op.data_format(), op.value().getType().cast()); auto bias_broadcast = Broadcast1DToFeatureDim(loc, op.value(), op.bias(), feature_dim, rewriter); rewriter.replaceOpWithNewOp(op, op.value(), bias_broadcast); @@ -1683,6 +1707,80 @@ class ConvertEinsumOp : public OpRewritePattern { } }; +template +class ConvertFFTOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + auto input_ty = op.input().getType().template cast(); + if (!input_ty.hasRank()) { + return failure(); + } + auto input_shape = input_ty.getShape(); + DenseIntElementsAttr fft_length_attr; + if (!matchPattern(op.fft_length(), m_Constant(&fft_length_attr))) { + return failure(); + } + int64_t fft_length; + if (fft_length_attr.getNumElements() != 0) { + fft_length = fft_length_attr.getValue(0).getInt(); + } else { + return failure(); + } + + std::string fft_string = "RFFT"; + if (typeid(OpTy) == typeid(TF::IRFFTOp)) { + fft_length = fft_length / 2 + 1; + fft_string = "IRFFT"; + } + auto loc = op.getLoc(); + + // The inner-most dim cannot be dynamic. + if (input_ty.isDynamicDim(input_shape.size() - 1)) { + return failure(); + } + + auto expected_shape = llvm::to_vector<4>(input_shape.drop_back()); + expected_shape.push_back(fft_length); + + // Zero pad or truncate the last axis + Value reshaped = op.input(); + SmallVector begin_indices(input_shape.size(), 0); + SmallVector strides(input_shape.size(), 1); + + // Last dim larger than fft_length, slice the input + if (input_shape.back() > fft_length) { + reshaped = rewriter.create( + op.getLoc(), + RankedTensorType::get(expected_shape, input_ty.getElementType()), + op.input(), GetI64ElementsAttr(begin_indices, &rewriter), + GetI64ElementsAttr(expected_shape, &rewriter), + GetI64ElementsAttr(strides, &rewriter)); + + // Last dim smaller than fft_length, zero-pad the input + } else if (input_ty.getShape().back() < fft_length) { + SmallVector no_padding(input_shape.size(), 0); + SmallVector padding(input_shape.size() - 1, 0); + padding.push_back(fft_length - input_shape.back()); + Value zero = + GetScalarConstOfType(input_ty.getElementType(), loc, 0, &rewriter); + reshaped = rewriter.create( + loc, RankedTensorType::get(expected_shape, input_ty.getElementType()), + op.input(), zero, GetI64ElementsAttr(no_padding, &rewriter), + GetI64ElementsAttr(padding, &rewriter), + GetI64ElementsAttr(no_padding, &rewriter)); + } + + rewriter.replaceOpWithNewOp(op, op.getType(), reshaped, fft_string, + rewriter.getI64TensorAttr(fft_length)); + return success(); + } +}; + +using ConvertRFFTOp = ConvertFFTOp; +using ConvertIRFFTOp = ConvertFFTOp; + // The base class to convert TensorFlow FusedBatchNormGrad*Op to HLO // BatchNormGradOp for training and a sequence of binary ops for inference. // TODO(b/145536565): move to legalize_tf_patterns.td if it applies. 
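The ConvertFFTOp pattern added above first zero-pads or slices the innermost dimension of the input so it matches fft_length, and only then emits a single mhlo.fft. A minimal sketch of the RFFT case, with hypothetical shapes; op and attribute spellings are approximate for this revision of MHLO:

func @rfft_pad(%x: tensor<3x5xf32>) -> tensor<3x5xcomplex<f32>> {
  %len = "tf.Const"() {value = dense<8> : tensor<1xi32>} : () -> tensor<1xi32>
  // fft_length = 8 but the input's last dimension is 5, so the pattern zero-pads it to 8.
  %0 = "tf.RFFT"(%x, %len) : (tensor<3x5xf32>, tensor<1xi32>) -> tensor<3x5xcomplex<f32>>
  return %0 : tensor<3x5xcomplex<f32>>
}
// Lowers to, roughly:
//   %zero = mhlo.constant dense<0.000000e+00> : tensor<f32>
//   %pad  = "mhlo.pad"(%x, %zero) {edge_padding_low = dense<0> : tensor<2xi64>,
//             edge_padding_high = dense<[0, 3]> : tensor<2xi64>,
//             interior_padding = dense<0> : tensor<2xi64>}
//           : (tensor<3x5xf32>, tensor<f32>) -> tensor<3x8xf32>
//   %fft  = "mhlo.fft"(%pad) {fft_type = "RFFT", fft_length = dense<8> : tensor<1xi64>}
//           : (tensor<3x8xf32>) -> tensor<3x5xcomplex<f32>>
// When the last dimension is larger than fft_length, an mhlo.slice is emitted instead of the pad.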
@@ -1716,7 +1814,7 @@ class ConvertFusedBatchNormGradBase act = rewriter.create(loc, act, kernel_type); auto feature_dim_attr = - getFeatureDimensionAttr(rewriter, op.data_formatAttr(), act); + getFeatureDimensionAttr(rewriter, op.data_format(), act); auto feature_dim = feature_dim_attr.getValue().getSExtValue(); // Gets the result values. @@ -1731,7 +1829,7 @@ class ConvertFusedBatchNormGradBase auto training_op = rewriter.create( loc, result_type, act, scale, mean, var, grad, op.epsilon(), - feature_dim_attr.getValue()); + feature_dim); x_backprop = rewriter.create(loc, training_op.getResult(), 0); @@ -1783,11 +1881,27 @@ class ConvertFusedBatchNormGradBase } x_backprop = rewriter.create(loc, x_backprop, act_ele_type); - // It doesn't matter what values we provide for the last 2 results. - rewriter.replaceOp(op, - {/*x_backprop=*/x_backprop, - /*scale_backprop=*/scale_backprop, - /*offset_backprop=*/offset_backprop, op.x(), op.x()}); + Value last_val[2]; + if (op.getResult(3).use_empty() && op.getResult(4).use_empty()) { + // It doesn't matter what values we provide for the last 2 results. + last_val[0] = last_val[1] = op.x(); + } else { + auto const_val = rewriter.create( + op.getLoc(), + DenseElementsAttr::get( + RankedTensorType::get({0}, getElementTypeOrSelf(op.getResult(3))), + 0.0)); + auto maybe_cast = [&](Value val, Type t) -> Value { + if (val.getType() == t) return val; + return rewriter.create(op.getLoc(), t, val); + }; + last_val[0] = maybe_cast(const_val, op.getResult(3).getType()); + last_val[1] = maybe_cast(const_val, op.getResult(4).getType()); + } + rewriter.replaceOp( + op, {/*x_backprop=*/x_backprop, + /*scale_backprop=*/scale_backprop, + /*offset_backprop=*/offset_backprop, last_val[0], last_val[1]}); return success(); } }; @@ -1810,7 +1924,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { LogicalResult matchAndRewrite(FusedBatchNormOpT op, PatternRewriter &rewriter) const override { auto feature_dim = - getFeatureDimensionAttr(rewriter, op.data_formatAttr(), op.x()); + getFeatureDimensionAttr(rewriter, op.data_format(), op.x()); auto input_type_tensor = op.x().getType().template cast(); auto input_element_type = input_type_tensor.getElementType(); @@ -1851,7 +1965,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { auto bn_train_op = rewriter.create( op.getLoc(), result_type, bn_train_input, op.scale(), op.offset(), - op.epsilon(), feature_dim.getValue()); + op.epsilon(), feature_dim.getInt()); // HLO op outputs a tuple of tensors. Extract those results. auto bn_train_op_result = bn_train_op.getResult(); Value y_out = rewriter.create( @@ -1938,7 +2052,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { op.getLoc(), /*result_type=*/bn_train_input_type_tensor, bn_train_input, op.scale(), op.offset(), op.mean(), op.variance(), op.epsilon(), - feature_dim.getValue()); + feature_dim.getInt()); // Convert back to input type to stay aligned with expected output type // for TF op. 
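The FusedBatchNorm lowerings touched above derive the feature dimension from data_format: the last axis for the default "NHWC", axis 1 otherwise. A small sketch of the training path with hypothetical shapes; the tuple-typed result follows this revision of MHLO:

func @fused_bn(%x: tensor<8x8x8x8xf32>, %scale: tensor<8xf32>, %offset: tensor<8xf32>,
               %mean: tensor<8xf32>, %var: tensor<8xf32>) -> tensor<8x8x8x8xf32> {
  %y, %bm, %bv, %r1, %r2, %r3 = "tf.FusedBatchNormV3"(%x, %scale, %offset, %mean, %var)
      {data_format = "NHWC", epsilon = 1.000000e-03 : f32, is_training = true}
      : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
      -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %y : tensor<8x8x8x8xf32>
}
// The pattern emits, roughly:
//   %bn = "mhlo.batch_norm_training"(%x, %scale, %offset)
//           {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}   // NHWC: rank - 1
//         : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>)
//         -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
// followed by mhlo.get_tuple_element to unpack y, batch_mean, and batch_variance.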
@@ -2376,6 +2490,12 @@ class ConvertMaxPoolOp : public OpRewritePattern { Type element_type = op.input().getType().template cast().getElementType(); if (!element_type.isSignlessIntOrFloat()) return failure(); + tensorflow::Padding padding; + if (!GetPaddingFromString(op.padding().str(), &padding).ok()) + return failure(); + if (padding == tensorflow::Padding::EXPLICIT) { + return failure(); + } Location loc = op.getLoc(); ConstOp init = GetScalarLimitConstOfType(element_type, loc, hlo::kInfinityLowest, &rewriter); @@ -3087,7 +3207,7 @@ class ConvertStridedSliceOp : public OpRewritePattern { // axis. For instance, if there are 4 dims, we can support a // shrink_axis_mask of 0001 (1), 0011 (3), 0111 (7), or 1111 (15), but no // other. - bool shrink_axis_mask_ok = op.shrink_axis_mask().isMask(); + bool shrink_axis_mask_ok = llvm::isMask_64(op.shrink_axis_mask()); if (!shrink_axis_mask_ok) return rewriter.notifyMatchFailure( op, @@ -3096,27 +3216,27 @@ class ConvertStridedSliceOp : public OpRewritePattern { // When begin/end values are dynamic, the ellipsis mask, if set, must refer // to the last dimension. - int ellipsis_mask = op.ellipsis_mask().getZExtValue(); + int ellipsis_mask = op.ellipsis_mask(); if (!(ellipsis_mask == 0 || ellipsis_mask == (1 << last_dim))) return rewriter.notifyMatchFailure( op, "requires that ellipsis_mask, if set, refer to the last dimension of " "input (when begin/end values are dynamic)"); - APInt begin_mask = op.begin_mask(); - if (!begin_mask.isNullValue()) + uint64_t begin_mask = op.begin_mask(); + if (begin_mask) return rewriter.notifyMatchFailure( op, "requires that begin_mask is either set to 0 or not set when " "begin/end values are dynamic"); - APInt end_mask = op.end_mask(); - if (!end_mask.isNullValue()) + uint64_t end_mask = op.end_mask(); + if (end_mask) return rewriter.notifyMatchFailure( op, "requires that end_mask is either set to 0 or not set when begin/end " "values are dynamic"); - APInt new_axis_mask = op.new_axis_mask(); - if (!new_axis_mask.isNullValue()) + uint64_t new_axis_mask = op.new_axis_mask(); + if (new_axis_mask) return rewriter.notifyMatchFailure( op, "requires that new_axis_mask is either set to 0 or not set when " @@ -3148,11 +3268,12 @@ class ConvertStridedSliceOp : public OpRewritePattern { SmallVector slice_begin_indices; // For the dimensions that are to be sliced, all have slice sizes of 1. SmallVector slice_sizes(slicing_dim_size, 1); - auto input_element_ty = input_ty.getElementType(); + auto begin_element_ty = + op.begin().getType().cast().getElementType(); // Scalar tensor type. - TensorType type = RankedTensorType::get(/*shape=*/{}, input_element_ty); + TensorType type = RankedTensorType::get(/*shape=*/{}, begin_element_ty); Location loc = op.getLoc(); - auto zero = GetScalarConstOfType(input_element_ty, loc, 0, &rewriter); + auto zero = GetScalarConstOfType(begin_element_ty, loc, 0, &rewriter); for (int d = 0; d < slicing_dim_size; ++d) { auto index = rewriter.create( loc, op.begin(), GetI64ElementsAttr({d}, &rewriter), @@ -3163,7 +3284,7 @@ class ConvertStridedSliceOp : public OpRewritePattern { // If the index is negative, wrap it around with dimension size. 
auto index_negative = rewriter.create(loc, reshaped_index, zero); - auto input_val = GetScalarConstOfType(input_element_ty, loc, + auto input_val = GetScalarConstOfType(begin_element_ty, loc, input_shape[d], &rewriter); auto wrapped_index = rewriter.create(loc, input_val, reshaped_index); @@ -3502,6 +3623,13 @@ class ConvertLinSpaceOp : public OpRewritePattern { /// `is_accumulation` controls whether it uses higher precision for the actual /// reduction. This is set to false for ops like max where there is no precision /// concerns. +// +// The Derived class should have a static method to return the initial value to +// use for reduction: +// static Value GetInitialValue(Type reduce_element_type, Location loc, +// PatternRewriter *rewriter); +// The reduce_element_type is guaranteed to be a float, int, or complex type +// suitable for use with GetScalarConstOfType or GetScalarLimitConstOfType. template class GenericConvertReductionOp : public OpRewritePattern { @@ -3535,6 +3663,14 @@ class GenericConvertReductionOp : public OpRewritePattern { Location loc = op.getLoc(); Type element_type = input_ty.getElementType(); + + // Only float, int, and complex types are currently supported. + if (!element_type.isa() && !element_type.isa() && + !element_type.isa()) { + return rewriter.notifyMatchFailure( + op, "element type must be float, int, or complex type"); + } + // Convert to an accumulation type to not lose precision when doing // repeated arithmetic operations. Type reduce_element_type = @@ -4372,7 +4508,7 @@ class ConvertOneHotOp : public OpRewritePattern { } int64_t depth = depth_attr.getValue({}).getSExtValue(); - int64_t axis = op.axis().getSExtValue(); + int64_t axis = op.axis(); if (axis == -1) axis = indices_shape.size(); llvm::SmallVector broadcast_dims(indices_shape.size()); @@ -4602,10 +4738,8 @@ class ConvertTopKV2Op : public OpRewritePattern { &rewriter); // Get the sorted input and index tuple element. - auto tuple_first_element = - rewriter.create(op.getLoc(), sort_op, 0); - auto tuple_second_element = - rewriter.create(op.getLoc(), sort_op, 1); + auto tuple_first_element = sort_op.getResult(0); + auto tuple_second_element = sort_op.getResult(1); SmallVector begin_indices(input_rank, 0); auto end_indices = llvm::to_vector<4>(input_type.getShape()); @@ -4648,7 +4782,7 @@ class ConvertUnpackOp : public OpRewritePattern { if (!value_type) return failure(); int64_t value_rank = value_type.getRank(); - int64_t axis = op.axis().getSExtValue(); + int64_t axis = op.axis(); if (axis < 0) axis += value_rank; // Parameters for constructing each slice. @@ -4891,8 +5025,7 @@ class ConvertRandomShuffleOp : public OpRewritePattern { BuildSortComparisonBody({i32_type, input_type.getElementType()}, /*direction=*/"LT", &sorted.comparator(), &rewriter); - current = rewriter.create(op.getLoc(), - sorted.getResult(), 1); + current = sorted.getResult(1); } rewriter.replaceOp(op, current); return success(); @@ -5090,6 +5223,46 @@ class ConvertXlaDynamicUpdateSliceOp } }; +// Converts ClipByValue to XLA's clamp operation. Includes the broadcasting +// semantics for static and dynamic cases. 
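The new ConvertClipByValueOp below maps tf.ClipByValue onto mhlo.clamp; when clip_value_min and clip_value_max do not already have the operand's type, they are first broadcast to it using the shape computed by the pattern. A minimal sketch for the already-matching case, with hypothetical shapes:

func @clip(%t: tensor<2x3xf32>, %lo: tensor<2x3xf32>, %hi: tensor<2x3xf32>) -> tensor<2x3xf32> {
  %0 = "tf.ClipByValue"(%t, %lo, %hi)
      : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
  return %0 : tensor<2x3xf32>
}
// Lowers to a single clamp; note the operand order is (min, operand, max):
//   %0 = "mhlo.clamp"(%lo, %t, %hi)
//       : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
// Scalar min/max instead go through the broadcast path first.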
+class ConvertClipByValueOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::ClipByValueOp op, + PatternRewriter &rewriter) const override { + Value input = op.t(); + Value min = op.clip_value_min(); + Value max = op.clip_value_max(); + + auto input_ty = input.getType().cast(); + auto min_ty = min.getType().cast(); + auto max_ty = max.getType().cast(); + + if (!input_ty.hasRank() || !min_ty.hasRank() || !max_ty.hasRank()) { + return failure(); + } + + auto shape = rewriter.create( + op.getLoc(), + RankedTensorType::get({input_ty.getRank()}, rewriter.getI32Type()), + input); + + if (min_ty != input_ty) { + min = + rewriter.create(op.getLoc(), input_ty, min, shape); + } + + if (max_ty != input_ty) { + max = + rewriter.create(op.getLoc(), input_ty, max, shape); + } + + rewriter.replaceOpWithNewOp(op, input_ty, min, input, max); + return success(); + } +}; + // Converts the Cumsum or Cumprod TensorFlow op to the HLO ReduceWindow op by // setting appropriate window dimensions, with the given aggregation op as the // reduction function. The input tensor needs to have a static shape, and 'axis' @@ -5229,6 +5402,101 @@ class ConvertShapeOp : public OpRewritePattern { } }; +class ConvertDynamicReshapeOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::ReshapeOp op, + PatternRewriter &rewriter) const override { + auto tensor = op.tensor(); + auto shape = op.shape(); + + auto tensor_ty = tensor.getType().cast(); + auto shape_ty = shape.getType().cast(); + auto result_ty = op.getType().cast(); + + if (!result_ty.hasRank() || !tensor_ty.hasRank() || !shape_ty.hasRank()) { + return failure(); + } + + // Handle with the static case. + if (result_ty.hasStaticShape()) { + return failure(); + } + + rewriter.replaceOpWithNewOp(op, result_ty, tensor, + shape); + return success(); + } +}; + +class ConvertDynamicExpandDimsOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::ExpandDimsOp op, + PatternRewriter &rewriter) const override { + auto input = op.input(); + auto input_ty = input.getType().cast(); + auto result_ty = op.getType().cast(); + if (!result_ty.hasRank() || !input_ty.hasRank() || + result_ty.hasStaticShape()) { + return failure(); + } + + DenseIntElementsAttr expand_dims_attr; + if (!matchPattern(op.dim(), m_Constant(&expand_dims_attr))) { + return failure(); + } + + auto shape = rewriter.create( + op.getLoc(), + RankedTensorType::get({input_ty.getRank()}, rewriter.getIndexType()), + input); + auto expand_dims = llvm::to_vector<6>(expand_dims_attr.getIntValues()); + + llvm::SmallVector dims; + dims.resize(result_ty.getRank()); + + auto inserted_dim = expand_dims_attr.getValue({}) + .cast() + .getValue() + .getSExtValue(); + + // Handle the negative value use case. + if (inserted_dim < 0) { + inserted_dim += result_ty.getRank(); + // This means the value is completely incorrect, just return. + if (inserted_dim < 0) { + return failure(); + } + } + + dims[inserted_dim] = rewriter.create(op.getLoc(), 1); + + for (int i = 0; i < dims.size() - 1; i++) { + // Add the extracted dim. + auto index = rewriter.create(op.getLoc(), i); + auto dim = rewriter.create( + op.getLoc(), rewriter.getIndexType(), shape, index); + + dims[i >= inserted_dim ? 
i + 1 : i] = dim; + } + + auto from_extents = rewriter.create( + op.getLoc(), shape::ShapeType::get(op.getContext()), dims); + + auto to_extent_tensor = rewriter.create( + op.getLoc(), + RankedTensorType::get({result_ty.getRank()}, rewriter.getIndexType()), + from_extents); + + rewriter.replaceOpWithNewOp(op, result_ty, input, + to_extent_tensor); + return success(); + } +}; + // Converts a TF QR op to HLO. class ConvertQrOp : public OpRewritePattern { public: @@ -5728,7 +5996,7 @@ class ConvertQrOp : public OpRewritePattern { void EmitLegalizationErrors(Operation *op, const DenseSet &nonlegalized_ops) { // Track the legalization failures by mapping op name to information about - // that failure: the number of unlegalized occurances of the op, and one + // that failure: the number of unlegalized occurrences of the op, and one // example operation that failed. std::map> op_name_to_error_info; DenseSet error_ops; @@ -5823,12 +6091,6 @@ LogicalResult legalizeTF( ConversionTarget target(*context); if (legalize_chlo) { target.addIllegalDialect(); - - // Mark ConstantLikeOp as dynamically legal only when it doesn't have a - // static result type so that it gets canonicalized to MHLO constant. - target.addDynamicallyLegalOp([](Operation *op) { - return !op->getResultTypes().front().cast().hasStaticShape(); - }); } else { target.addLegalDialect(); } @@ -5858,14 +6120,16 @@ LogicalResult legalizeTF( void PopulateLegalizeTfPatterns(MLIRContext *context, OwningRewritePatternList *patterns) { - populateWithGenerated(context, patterns); + populateWithGenerated(context, *patterns); patterns->insert< ConvertAllOp, ConvertAnyOp, ConvertArgMaxOp, ConvertBatchMatMulV2Op, ConvertBiasAddOp, ConvertBroadcastToOp, ConvertBF16FloorDivOp, - ConvertConv2DOp, ConvertConv3DOp, ConvertDepthConv2DOp, - ConvertConv2DBackpropFilterOp, ConvertConv3DBackpropFilterOp, - ConvertConv2DBackpropInputOp, ConvertConv3DBackpropInputOp, - ConvertCumprodOp, ConvertCumsumOp, ConvertDiagPartOp, ConvertEinsumOp, + ConvertClipByValueOp, ConvertConv2DOp, ConvertConv3DOp, + ConvertDepthConv2DOp, ConvertConv2DBackpropFilterOp, + ConvertConv3DBackpropFilterOp, ConvertConv2DBackpropInputOp, + ConvertConv3DBackpropInputOp, ConvertCumprodOp, ConvertCumsumOp, + ConvertDiagPartOp, ConvertDynamicExpandDimsOp, ConvertDynamicReshapeOp, + ConvertEinsumOp, ConvertRFFTOp, ConvertIRFFTOp, ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op, ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV2Op, ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp, diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc index 1f884b1bdea..6320ad2032b 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc @@ -60,6 +60,10 @@ const char kXlaHostTransferOriginalTypeAttr[] = // ops other than certain control flow ops (`mhlo.if`, `mhlo.while`). 
class LegalizeTFCommunication : public PassWrapper> { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + public: void runOnOperation() override; }; diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc index 760252331e0..692b2af7cff 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc @@ -20,30 +20,24 @@ limitations under the License. #include #include #include +#include #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" -#include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/SmallSet.h" -#include "llvm/ADT/StringSet.h" -#include "llvm/ADT/iterator_range.h" +#include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project #include "mlir/IR/Function.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/core/util/tensor_format.h" using mlir::PassRegistration; @@ -64,7 +58,7 @@ createLegalizeTFControlFlowPass() { namespace { -void Detuple(Value tuple, Operation::result_range replace, OpBuilder* builder) { +void Detuple(Value tuple, ValueRange replace, OpBuilder* builder) { // De-tuple the results of the xla hlo if result. for (auto result_it : llvm::enumerate(replace)) { auto get_tuple_value = builder->create( @@ -102,7 +96,7 @@ void ImportXlaRegion(mlir::FuncOp func, Region* dest_region, Location loc, } } -void LowerIf(TF::IfOp op, ModuleOp module) { +void LowerIf(TF::IfOp op) { Location loc = op.getLoc(); OpBuilder builder(op); @@ -111,7 +105,7 @@ void LowerIf(TF::IfOp op, ModuleOp module) { SmallVector inputs(op.input()); auto tuple_input = builder.create(loc, inputs); - // Create the new if op with tuple inputs. + // Create the new `mhlo.if` op with tuple inputs. auto result_type = builder.getTupleType(op.getResultTypes()); auto if_op = builder.create(loc, result_type, op.cond(), tuple_input, tuple_input); @@ -119,15 +113,15 @@ void LowerIf(TF::IfOp op, ModuleOp module) { // Import the regions for both the true and false cases. These regions // must be updated to tuple the return results together and use the xla hlo // return op. - ImportXlaRegion(op.then_func(), &if_op.true_branch(), loc); - ImportXlaRegion(op.else_func(), &if_op.false_branch(), loc); + ImportXlaRegion(op.then_function(), &if_op.true_branch(), loc); + ImportXlaRegion(op.else_function(), &if_op.false_branch(), loc); - // De-tuple the results of the xla hlo if result. + // De-tuple the results of the `mhlo.if`. 
Detuple(if_op.getResult(), op.getResults(), &builder); op.erase(); } -void LowerCase(TF::CaseOp op, ModuleOp module) { +void LowerCase(TF::CaseOp op) { Location loc = op.getLoc(); OpBuilder builder(op); @@ -137,17 +131,16 @@ void LowerCase(TF::CaseOp op, ModuleOp module) { auto tuple_input = builder.create(loc, inputs); // Create replica of input tuple for each branch - SmallVector n_tuple_inputs(op.branches().size(), tuple_input); + SmallVector n_tuple_inputs(op.num_branches(), tuple_input); - // Create the new case op with tuple inputs. + // Create the new `mhlo.case` op with tuple inputs. auto case_op = builder.create(loc, op.getResultTypes(), op.branch_index(), n_tuple_inputs, op.branches().size()); // Import the regions for all branches. - for (unsigned i = 0; i < op.branches().size(); ++i) { - mlir::FuncOp branch_func = module.lookupSymbol( - op.branches()[i].cast()); + for (unsigned i = 0; i < op.num_branches(); ++i) { + mlir::FuncOp branch_func = op.branch_function(i); ImportXlaRegion(branch_func, &case_op.branches()[i], loc, /*tuple_return=*/false); } @@ -156,7 +149,7 @@ void LowerCase(TF::CaseOp op, ModuleOp module) { op.erase(); } -void LowerWhile(TF::WhileOp op, ModuleOp module) { +void LowerWhile(TF::WhileOp op) { Location loc = op.getLoc(); OpBuilder builder(op); @@ -166,36 +159,238 @@ void LowerWhile(TF::WhileOp op, ModuleOp module) { builder.setInsertionPoint(op); Value tuple_input = builder.create(loc, inputs); - // Create the new while op with tuple inputs. + // Create the new `mhlo.while` op with tuple inputs. auto while_op = builder.create( loc, builder.getTupleType(op.getResultTypes()), tuple_input); // Import the regions for both the cond and body. These regions must be // updated to tuple the return results together and use the xla hlo return op. - ImportXlaRegion(op.body_func(), &while_op.body(), loc); - ImportXlaRegion(op.cond_func(), &while_op.cond(), loc, + ImportXlaRegion(op.body_function(), &while_op.body(), loc); + ImportXlaRegion(op.cond_function(), &while_op.cond(), loc, /*tuple_return=*/false); - // De-tuple the results of the xla hlo while. + // De-tuple the results of the `mhlo.while`. + Detuple(while_op.getResult(), op.getResults(), &builder); + op.erase(); +} + +// Replaces all block arguments of a block with a single block arg of Tuple +// type `tuple_type`. Single block arguments are removed and remapped to +// get_tuple_element(tuple_arg, index). +void ReplaceBlockArgs(Block* block, Type tuple_type, OpBuilder* builder) { + auto tuple_arg = block->addArgument(tuple_type); + Detuple(tuple_arg, block->getArguments().drop_back(1), builder); + for (int i = block->getNumArguments() - 2; i >= 0; --i) + block->eraseArgument(i); +} + +// Replaces implicitly captured value uses with tuple block argument. +// get_tuple_element's are created to extract specific values. Values from +// get_tuple_element's are returned in the order of `implicit_inputs`. 
+llvm::SmallVector ReplaceImplicitInputs( + Block* block, int offset, ArrayRef implicit_inputs, + OpBuilder* builder) { + llvm::SmallVector implicit_input_elements; + implicit_input_elements.reserve(implicit_inputs.size()); + + Region* region = block->getParent(); + assert(block->getNumArguments() == 1); + + BlockArgument tuple_arg = block->getArgument(0); + for (auto& implicit_input : llvm::enumerate(implicit_inputs)) { + Value implicit_input_value = implicit_input.value(); + auto get_tuple_element = builder->create( + implicit_input_value.getLoc(), tuple_arg, + implicit_input.index() + offset); + implicit_input_elements.emplace_back(get_tuple_element.getResult()); + for (auto& use : + llvm::make_early_inc_range(implicit_input_value.getUses())) { + if (!region->isAncestor(use.getOwner()->getParentRegion())) continue; + use.set(get_tuple_element.getResult()); + } + } + + return implicit_input_elements; +} + +// Finds and replaces implicitly captured value uses with tuple block argument. +// A tuple of implicitly captured values is also created and returned, for use +// as an operand to the associated mhlo control flow op. +Value TupleImplicitInputs(Region& region, Location loc, OpBuilder* builder) { + llvm::SetVector implicit_inputs; + getUsedValuesDefinedAbove(region, region, implicit_inputs); + llvm::ArrayRef implicit_inputs_ref = implicit_inputs.getArrayRef(); + Value tuple_input = builder->create(loc, implicit_inputs_ref); + Block& block = region.front(); + // `tf.CaseRegion`/`tf.IfRegion` are expected to have no block arguments and + // instead all inputs used by their branch regions are implicitly captured + // from above. + assert(block.getNumArguments() == 0); + block.addArgument(tuple_input.getType()); + builder->setInsertionPointToStart(&block); + ReplaceImplicitInputs(&block, /*offset=*/0, implicit_inputs_ref, builder); + return tuple_input; +} + +// Replaces block terminator (tf.Yield) with `mhlo.return`. Additional results +// can be returned if `extra_results` is not empty. If `tuple_return` is +// set, a tuple of the return values will be set as the terminator operand. +void ReplaceTerminator(Block* block, ArrayRef extra_results, + OpBuilder* builder, bool tuple_return = true) { + Operation* terminator = block->getTerminator(); + assert(isa(terminator)); + Location loc = terminator->getLoc(); + + builder->setInsertionPoint(terminator); + auto results = llvm::to_vector<4>(terminator->getOperands()); + results.append(extra_results.begin(), extra_results.end()); + if (tuple_return) { + auto tuple_results = builder->create(loc, results); + builder->create(loc, tuple_results.getResult()); + } else { + builder->create(loc, results); + } + + terminator->erase(); +} + +void LowerIfRegion(TF::IfRegionOp op) { + Location loc = op.getLoc(); + OpBuilder builder(op); + + // Tuple implicit inputs per region and update terminators to return tuples. + builder.setInsertionPoint(op); + Value then_input = TupleImplicitInputs(op.then_branch(), loc, &builder); + ReplaceTerminator(&op.then_branch().front(), /*extra_results=*/{}, &builder); + + builder.setInsertionPoint(op); + Value else_input = TupleImplicitInputs(op.else_branch(), loc, &builder); + ReplaceTerminator(&op.else_branch().front(), /*extra_results=*/{}, &builder); + + // Create the new `mhlo.if` op with tuple inputs and take ownership of regions + // from `tf.IfRegion` op. 
+ builder.setInsertionPoint(op); + auto result_type = builder.getTupleType(op.getResultTypes()); + auto if_op = builder.create(loc, result_type, op.cond(), + then_input, else_input); + if_op.true_branch().takeBody(op.then_branch()); + if_op.false_branch().takeBody(op.else_branch()); + + // De-tuple the results of the `mhlo.if`. + Detuple(if_op.getResult(), op.getResults(), &builder); + op.erase(); +} + +void LowerCaseRegion(TF::CaseRegionOp op) { + Location loc = op.getLoc(); + OpBuilder builder(op); + + llvm::SmallVector branch_inputs; + branch_inputs.reserve(op.branches().size()); + // Tuple implicit inputs per region and update terminators. + for (Region& region : op.branches()) { + builder.setInsertionPoint(op); + Value branch_input = TupleImplicitInputs(region, loc, &builder); + branch_inputs.emplace_back(branch_input); + ReplaceTerminator(®ion.front(), /*extra_results=*/{}, &builder, + /*tuple_return=*/false); + } + + // Create the new `mhlo.case` op with tuple inputs and take ownership of + // regions from `tf.CaseRegion` op. + builder.setInsertionPoint(op); + auto case_op = + builder.create(loc, op.getResultTypes(), op.branch_index(), + branch_inputs, branch_inputs.size()); + for (auto region : llvm::zip(case_op.branches(), op.branches())) + std::get<0>(region).takeBody(std::get<1>(region)); + + op.replaceAllUsesWith(case_op.getResults()); + op.erase(); +} + +void LowerWhileRegion(TF::WhileRegionOp op) { + Location loc = op.getLoc(); + OpBuilder builder(op); + + // XLA prefers tuple arguments for control flow due to XLA not supporting + // multiple return values. + SmallVector inputs(op.input()); + const int inputs_size = inputs.size(); + llvm::SetVector implicit_inputs; + getUsedValuesDefinedAbove(op.getOperation()->getRegions(), implicit_inputs); + inputs.append(implicit_inputs.begin(), implicit_inputs.end()); + + builder.setInsertionPoint(op); + Value tuple_input = builder.create(loc, inputs); + + // Create the new `mhlo.while` op with tuple inputs. Implicit inputs are also + // returned. + auto while_result_types = llvm::to_vector<4>(op.getResultTypes()); + while_result_types.reserve(while_result_types.size() + + implicit_inputs.size()); + for (const auto& implicit_input : implicit_inputs) + while_result_types.emplace_back(implicit_input.getType()); + auto while_op = builder.create( + loc, builder.getTupleType(while_result_types), tuple_input); + + // Rewrite cond and associated block arguments and terminator. Ownership of + // cond region is transfered over from `tf.WhileRegion` to `mhlo.while`. + Region& cond = while_op.cond(); + cond.takeBody(op.cond()); + Block& cond_block = cond.front(); + builder.setInsertionPointToStart(&cond_block); + ReplaceBlockArgs(&cond_block, tuple_input.getType(), &builder); + ReplaceImplicitInputs(&cond_block, inputs_size, implicit_inputs.getArrayRef(), + &builder); + // Cond always returns a single result of bool type. + ReplaceTerminator(&cond_block, /*extra_results=*/{}, &builder, + /*tuple_return=*/false); + + // Rewrite body and associated block arguments and terminator. Ownership of + // body region is transfered over from `tf.WhileRegion` to `mhlo.while`. + Region& body = while_op.body(); + body.takeBody(op.body()); + Block& body_block = body.front(); + builder.setInsertionPointToStart(&body_block); + ReplaceBlockArgs(&body_block, tuple_input.getType(), &builder); + // Capture implicit inputs that were added as a tuple block arguments. These + // are to be returned by the body in addition to explicit inputs. 
+ auto implicit_input_elements = ReplaceImplicitInputs( + &body_block, inputs_size, implicit_inputs.getArrayRef(), &builder); + ReplaceTerminator(&body_block, implicit_input_elements, &builder); + + // De-tuple the results of the `mhlo.while`. + builder.setInsertionPoint(op); Detuple(while_op.getResult(), op.getResults(), &builder); op.erase(); } } // namespace void LegalizeTFControlFlow::runOnOperation() { - auto module = getOperation(); - - module.walk([&](Operation* op) { + getOperation().walk([&](Operation* op) { if (auto while_op = dyn_cast(op)) { - LowerWhile(while_op, module); + LowerWhile(while_op); + return; + } + if (auto while_region_op = dyn_cast(op)) { + LowerWhileRegion(while_region_op); return; } if (auto if_op = dyn_cast(op)) { - LowerIf(if_op, module); + LowerIf(if_op); + return; + } + if (auto if_region_op = dyn_cast(op)) { + LowerIfRegion(if_region_op); return; } if (auto case_op = dyn_cast(op)) { - LowerCase(case_op, module); + LowerCase(case_op); + return; + } + if (auto case_region_op = dyn_cast(op)) { + LowerCaseRegion(case_region_op); return; } }); diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index 73ce305091c..52bbbf6f9da 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -31,7 +31,7 @@ def IEEEFloatTensor : TensorOf<[F16, F32, F64]>; //===----------------------------------------------------------------------===// def FeatureDimension : NativeCodeCall< - "getFeatureDimensionAttr($_builder, $0, $1)">; + "getFeatureDimensionAttr($_builder, $0.getValue(), $1)">; def FalseBoolAttr : AttrConstraint>; def TrueBoolAttr : AttrConstraint>; @@ -86,7 +86,7 @@ def AreBroadcastCompatible : Constraint, "types must be broadcastable">; class DirectBinaryPat - : Pat<(FromOp AnyRankedTensor:$l, AnyRankedTensor:$r), + : Pat<(FromOp AnyTensor:$l, AnyTensor:$r), (ToOp $l, $r, (BinBroadcastDimensions $l, $r))>; foreach fromToBinPair = [[TF_AddOp, HLOClient_BroadcastAddOp], @@ -285,9 +285,19 @@ def : Pat<(TF_AllToAllOp AnyRankedTensor:$input, (TF_ConstOp $group_assignment), // FFT op patterns. //===----------------------------------------------------------------------===// -def : Pat<(TF_RFFTOp $input, (TF_ConstOp I32ElementsAttr:$fft_length)), - (HLO_FftOp $input, HLO_FFT_TYPE_RFFT, - (CastElementsToI64Elements $fft_length))>; +def GetInnerDimFromValue : NativeCodeCall< + "GetInnerDimFromValue($0.getType().cast(), &$_builder)">; + +def CheckInnerDimStatic + : Constraint(), &$_builder)">>; + +def : Pat<(TF_FFTOp:$res $input), + (HLO_FftOp $input, HLO_FFT_TYPE_FFT, (GetInnerDimFromValue $res)), + [(CheckInnerDimStatic $input)]>; + +def : Pat<(TF_IFFTOp:$res $input), + (HLO_FftOp $input, HLO_FFT_TYPE_IFFT, (GetInnerDimFromValue $res)), + [(CheckInnerDimStatic $input)]>; //===----------------------------------------------------------------------===// // GatherV2 op patterns. 
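The region-based control-flow lowerings added in legalize_tf_control_flow.cc above (LowerIfRegion, LowerCaseRegion, LowerWhileRegion) all follow the same recipe: implicitly captured values are packed into a tuple operand, branch terminators are rewritten to mhlo.return, and the tuple result is unpacked with mhlo.get_tuple_element. A minimal sketch for tf.IfRegion; op spellings are approximate for this revision:

func @if_region(%pred: tensor<i1>, %a: tensor<f32>, %b: tensor<f32>) -> tensor<f32> {
  %0 = "tf.IfRegion"(%pred) ( {
    "tf.Yield"(%a) : (tensor<f32>) -> ()
  },  {
    "tf.Yield"(%b) : (tensor<f32>) -> ()
  }) {is_stateless = true} : (tensor<i1>) -> tensor<f32>
  return %0 : tensor<f32>
}
// After LowerIfRegion, roughly: each branch receives a tuple of its implicit captures.
//   %then_in = "mhlo.tuple"(%a) : (tensor<f32>) -> tuple<tensor<f32>>
//   %else_in = "mhlo.tuple"(%b) : (tensor<f32>) -> tuple<tensor<f32>>
//   %res = "mhlo.if"(%pred, %then_in, %else_in) ( {
//   ^bb0(%arg: tuple<tensor<f32>>):
//     %v = "mhlo.get_tuple_element"(%arg) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
//     %t = "mhlo.tuple"(%v) : (tensor<f32>) -> tuple<tensor<f32>>
//     "mhlo.return"(%t) : (tuple<tensor<f32>>) -> ()
//   },  {
//     ... analogous block for the else branch, fed by %else_in ...
//   }) : (tensor<i1>, tuple<tensor<f32>>, tuple<tensor<f32>>) -> tuple<tensor<f32>>
//   %out = "mhlo.get_tuple_element"(%res) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>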
@@ -562,12 +572,14 @@ def : Pat<(TF_ReverseV2Op AnyRankedTensor:$values, (TF_ConstOp $axis)), foreach Mapping = [ [TF_AbsOp, HLO_AbsOp], [TF_AcosOp, HLOClient_AcosOp], + [TF_AtanOp, HLOClient_AtanOp], [TF_CeilOp, HLO_CeilOp], [TF_ComplexAbsOp, HLO_AbsOp], [TF_CosOp, HLO_CosOp], [TF_ExpOp, HLO_ExpOp], [TF_FloorOp, HLO_FloorOp], [TF_ImagOp, HLO_ImagOp], + [TF_InvertOp, HLO_NotOp], [TF_IsFiniteOp, HLO_IsFiniteOp], [TF_LogOp, HLO_LogOp], [TF_Log1pOp, HLO_Log1pOp], @@ -576,27 +588,16 @@ foreach Mapping = [ [TF_RealOp, HLO_RealOp], [TF_RsqrtOp, HLO_RsqrtOp], [TF_SigmoidOp, HLO_LogisticOp], + [TF_SinhOp, HLOClient_SinhOp], [TF_SinOp, HLO_SinOp], [TF_SqrtOp, HLO_SqrtOp], [TF_TanhOp, HLO_TanhOp], + [TF_TanOp, HLOClient_TanOp], ] in { def : Pat<(Mapping[0] HLO_Tensor:$input), (Mapping[1] $input)>; } -// Expand acos to MHLO dialect as follows: -// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) if x != -1 -// = pi if x == -1 -def : Pat<(HLOClient_AcosOp $input), (HLO_SelectOp - (HLO_CompareOp $input, (HLO_ConstantLike<"0"> $input), - HLO_COMPARISON_DIRECTION_NE), - (HLO_MulOp (HLO_ConstantLike<"2.0f"> $input), - (HLO_Atan2Op - (HLO_SqrtOp (HLO_SubOp - (HLO_ConstantLike<"1"> $input), (HLO_MulOp $input, $input))), - (HLO_AddOp (HLO_ConstantLike<"1"> $input), $input))), - (HLO_ConstantLike<"M_PI"> $input))>; - // TODO(bixia): Lower Cast with a Complex type source operand or with // Truncate=True for floating point value conversions. def : Pat<(TF_CastOp HLO_Tensor:$arg, ConstBoolAttrFalse), diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc index 2f73d1a54df..b392e91e22f 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -146,13 +146,12 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), - TypeID::get(), TypeID::get(), TypeID::get(), - TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -162,7 +161,6 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), - TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -179,6 +177,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -200,6 +199,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -228,6 +228,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -240,6 +241,8 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get() diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc index ef362d95b97..33cd2c66c45 100644 --- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc @@ -29,7 +29,9 @@ limitations under the License. 
#include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project @@ -40,6 +42,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h" #include "tensorflow/compiler/mlir/xla/hlo_utils.h" #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" +#include "tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h" #include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" @@ -110,7 +113,7 @@ Status ConvertModule(std::unique_ptr hlo_module, ModuleOp module, // Run all HLO passes to produce an optimized module. auto result_or = backend->compiler()->RunHloPassesAndBufferAssignement( std::move(hlo_module), backend->default_stream_executor(), - backend->memory_allocator()); + backend->memory_allocator(), optimize_xla_hlo); TF_RETURN_WITH_CONTEXT_IF_ERROR(result_or.status(), "running XLA pass pipeline"); std::unique_ptr optimized_hlo_module = @@ -276,27 +279,137 @@ Status LhloDialectEmitter::HandleSort(HloInstruction* instr) { return EmitSortOp(instr).status(); } -Status LhloDialectEmitter::CreateView(const HloInstruction* instr, - const Shape& current_shape, - ::xla::ShapeIndex* current_shape_index, - SmallVectorImpl* values) { - if (current_shape.IsTuple()) { - for (int i = 0; i < current_shape.tuple_shapes().size(); i++) { - current_shape_index->push_back(i); - TF_RETURN_IF_ERROR(CreateView(instr, current_shape.tuple_shapes(i), - current_shape_index, values)); - current_shape_index->pop_back(); +// Walks MHLO::TupleOp recursively. +Status WalkTuplePostOrder(Value v, + const std::function& visitor) { + if (auto* op = v.getDefiningOp()) { + if (auto tuple = dyn_cast(op)) { + for (Value sub_v : tuple.val()) { + TF_RETURN_IF_ERROR(WalkTuplePostOrder(sub_v, visitor)); + } + return Status::OK(); } - return Status::OK(); } + return visitor(v); +} + +// This function removes all uses of a fused region argument, and rewire those +// uses to a `tensor_load %memref`, where %memref is caller argument. +// +// It also flattens all input/output tuples into more region arguments / +// results. 
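The tuple handling above amounts to a post-order traversal that flattens nested tuples into a flat list of leaf values. A stand-alone Python sketch of the same traversal, using plain tuples instead of mhlo::TupleOp (names are illustrative, not from the patch):

# Sketch: post-order walk over a nested tuple structure, visiting leaves
# left to right, mirroring what WalkTuplePostOrder does for mhlo tuples.
def walk_tuple_post_order(value, visit):
    if isinstance(value, tuple):
        for sub in value:
            walk_tuple_post_order(sub, visit)
        return
    visit(value)


leaves = []
walk_tuple_post_order(("a", ("b", "c"), "d"), leaves.append)
assert leaves == ["a", "b", "c", "d"]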
+StatusOr LhloDialectEmitter::RewriteFusionOperand( + const HloInstruction* root, const Shape& shape, + ::xla::ShapeIndex* shape_index, OpBuilder* b, Location loc) { + if (shape.IsTuple()) { + llvm::SmallVector values; + for (int i = 0; i < shape.tuple_shapes_size(); i++) { + shape_index->push_back(i); + TF_ASSIGN_OR_RETURN( + auto v, RewriteFusionOperand(root, shape.tuple_shapes(i), shape_index, + b, loc)); + values.push_back(v); + shape_index->pop_back(); + } + return Value(b->create(loc, values)); + } + TF_ASSIGN_OR_RETURN(Value memref, + GetOrCreateArrayView(root, shape, *shape_index)); + auto load = b->create(loc, memref); + if (shape.layout() != + xla::LayoutUtil::MakeDescendingLayout(shape.dimensions().size())) { + llvm::SmallVector minor_to_major( + shape.layout().minor_to_major().begin(), + shape.layout().minor_to_major().end()); + load.setAttr("minor_to_major", b->getIndexTensorAttr(minor_to_major)); + } + return load.getResult(); +} + +StatusOr LhloDialectEmitter::EmitFusionOp( + HloInstruction* instr) { + Location loc = getLocation(instr); + + auto* fusion_instr = ::xla::Cast<::xla::HloFusionInstruction>(instr); + + auto fusion = builder_.create(getLocation(instr), + ArrayRef{}); + auto after_fusion = builder_.saveInsertionPoint(); + builder_ = mlir::OpBuilder(fusion); + + auto region_builder = OpBuilder::atBlockBegin(&fusion.region().front()); + + llvm::SmallVector arguments; + for (int i = 0; i < instr->operands().size(); i++) { + const HloInstruction* operand = instr->operand(i); + xla::ShapeIndex shape_index; + TF_ASSIGN_OR_RETURN( + auto arg, RewriteFusionOperand(operand, operand->shape(), &shape_index, + ®ion_builder, loc)); + arguments.push_back(arg); + } + + TF_ASSIGN_OR_RETURN(Value result, + ::xla::HloFunctionImporter::ImportInstructions( + *fusion_instr->fused_instructions_computation(), + arguments, ®ion_builder)); + + { + int i = 0; + llvm::SmallVector output; + TF_RETURN_IF_ERROR(GetOrCreateView(instr, &output)); + TF_RETURN_IF_ERROR(WalkTuplePostOrder(result, [&](Value v) mutable { + region_builder.create(loc, v, output[i++]); + return Status::OK(); + })); + if (i != output.size()) { + return ::xla::InternalError("output sizes don't match"); + } + } + + // Fold GTE/Tuple pairs. + // + // Since the fused region refers to values in its parent region, we can't + // call applyPatternAndFoldGreedily. We optimize it manually. + // + // Only walk once, because post-ordering is exactly what we need for GTE + // optimizations. + fusion.region().walk([](mhlo::GetTupleElementOp gte) { + SmallVector folded_values; + if (succeeded(OpBuilder(gte).tryFold(gte, folded_values))) { + gte.replaceAllUsesWith(folded_values[0]); + } + }); + + // Effectively a DCE on the region. + { + llvm::SmallVector ops; + fusion.region().walk([&](mlir::Operation* op) { ops.push_back(op); }); + // Visit the user first. 
+ std::reverse(ops.begin(), ops.end()); + for (auto op : ops) { + if (isOpTriviallyDead(op)) op->erase(); + } + } + + builder_.restoreInsertionPoint(after_fusion); + return fusion; +} + +Status LhloDialectEmitter::HandleFusion(HloInstruction* instr) { + return EmitFusionOp(instr).status(); +} + +StatusOr LhloDialectEmitter::GetOrCreateArrayView( + const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape, + const ::xla::ShapeIndex& shape_index) { TF_ASSIGN_OR_RETURN(Type out_type, ::xla::ConvertShapeToType( current_shape, builder_)); TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice, - assignment_.GetUniqueSlice(instr, *current_shape_index)); + assignment_.GetUniqueSlice(instr, shape_index)); Value alloc = allocations_[slice.allocation()]; - if (alloc.getType() == out_type) { - values->push_back(alloc); - return Status::OK(); + if (alloc.getType() == out_type && slice.offset() == 0) { + return alloc; } auto out_memref_type = out_type.dyn_cast(); @@ -304,6 +417,13 @@ Status LhloDialectEmitter::CreateView(const HloInstruction* instr, return tensorflow::errors::Internal( "Expected memref type when creating a view for leaf type of a tuple."); + // Cache generated ViewOp and StaticMemRefCastOp by (instruction, + // shape_index). + auto& cached_value = slices_[std::make_pair(instr, shape_index)]; + if (cached_value) { + return cached_value; + } + Value byte_shift = builder_.create(alloc.getLoc(), slice.offset()); @@ -327,7 +447,24 @@ Status LhloDialectEmitter::CreateView(const HloInstruction* instr, if (physical_out_type != out_type) result = builder_.create(loc, out_memref_type, result); - values->push_back(result); + return cached_value = result; +} + +Status LhloDialectEmitter::GetOrCreateViewImpl( + const HloInstruction* instr, const Shape& current_shape, + ::xla::ShapeIndex* current_shape_index, SmallVectorImpl* values) { + if (current_shape.IsTuple()) { + for (int i = 0; i < current_shape.tuple_shapes().size(); i++) { + current_shape_index->push_back(i); + TF_RETURN_IF_ERROR(GetOrCreateViewImpl( + instr, current_shape.tuple_shapes(i), current_shape_index, values)); + current_shape_index->pop_back(); + } + return Status::OK(); + } + TF_ASSIGN_OR_RETURN( + auto v, GetOrCreateArrayView(instr, current_shape, *current_shape_index)); + values->push_back(v); return Status::OK(); } @@ -336,25 +473,8 @@ Status LhloDialectEmitter::CreateView(const HloInstruction* instr, // create another view to adjust the slice for the shape of the instruction. Status LhloDialectEmitter::GetOrCreateView(const HloInstruction* instr, SmallVectorImpl* values) { - // Cache generated ViewOp and StaticMemRefCastOp by instruction. We could have - // gone fancier to do the following cacheing: - // %range = ViewOp(%allocation, %offset) : memref - // %typed_range = ViewOp(%range) : memref - // - // where %range is cached. This in theory gives easier time for alias - // analysis, since the identity of %range defines alias. However, - // %typed_range can't be cached, as different buffers with different types and - // shapes may still alias. Creating two ViewOps doesn't seem to worth the - // effort for a slightly easier aliasing, so we don't over optimize here. 
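With this change the emitter memoizes array views per (HLO instruction, shape index) pair instead of caching one vector of views per instruction. A rough Python sketch of that memoization pattern, with hypothetical names (the real cache is the slices_ map keyed by std::make_pair(instr, shape_index)):

# Sketch: memoize expensive view creation by a composite key, so repeated
# requests for the same (instruction, shape_index) reuse the first result.
_view_cache = {}


def get_or_create_view(instr_id, shape_index, create_fn):
    key = (instr_id, tuple(shape_index))
    cached = _view_cache.get(key)
    if cached is not None:
        return cached
    view = create_fn()
    _view_cache[key] = view
    return view


calls = []
v1 = get_or_create_view("fusion.3", [0, 1], lambda: calls.append(1) or "view0")
v2 = get_or_create_view("fusion.3", [0, 1], lambda: calls.append(1) or "view0")
assert v1 == v2 and len(calls) == 1  # the second lookup hits the cache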
- auto result = slices_.try_emplace(instr, llvm::SmallVector{}); - llvm::SmallVectorImpl& new_values = result.first->second; - if (result.second) { - ::xla::ShapeIndex shape_index; - TF_RETURN_IF_ERROR( - CreateView(instr, instr->shape(), &shape_index, &new_values)); - } - values->insert(values->end(), new_values.begin(), new_values.end()); - return Status::OK(); + ::xla::ShapeIndex shape_index; + return GetOrCreateViewImpl(instr, instr->shape(), &shape_index, values); } Status LhloDialectEmitter::Initialize() { @@ -373,7 +493,7 @@ Status LhloDialectEmitter::Initialize() { if (computation_.IsEntryComputation()) { // Sort the rather arbitrarily ordered allocations to match the input/output - // parameters. Specifically We want to sort buffer allocations in the + // parameters. Specifically we want to sort buffer allocations in the // following order: // * Parameters always order before non-parameters. // * Different parameters order by parameter number. @@ -436,8 +556,8 @@ Status LhloDialectEmitter::Initialize() { } } - FunctionType function_type = builder_.getFunctionType( - llvm::to_vector<8>(block->getArgumentTypes()), {}); + FunctionType function_type = + builder_.getFunctionType(block->getArgumentTypes(), {}); func_op.setType(function_type); func_op.setAllArgAttrs(args_attrs); diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h index 89514116254..a57db3cb67e 100644 --- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h +++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h @@ -43,6 +43,7 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault { i8_type_(builder_.getIntegerType(8)) {} ::xla::StatusOr EmitSortOp(::xla::HloInstruction* instr); + ::xla::StatusOr EmitFusionOp(::xla::HloInstruction* instr); private: template @@ -57,21 +58,31 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault { } tensorflow::Status HandleSort(::xla::HloInstruction* instr) final; + tensorflow::Status HandleFusion(::xla::HloInstruction* instr) final; // Helper function that recursively visits the tuple structure in // `current_shape`, and reconstruct a matching lmhlo::TupleOp. // Each leaf node is converted to an std.view op with corresponding offsets. // If no tuple presents, it simply returns a view of the buffer. - tensorflow::Status CreateView(const ::xla::HloInstruction* instr, - const ::xla::Shape& current_shape, - ::xla::ShapeIndex* current_shape_index, - SmallVectorImpl* values); + tensorflow::Status GetOrCreateViewImpl(const ::xla::HloInstruction* instr, + const ::xla::Shape& current_shape, + ::xla::ShapeIndex* current_shape_index, + SmallVectorImpl* values); // Helper function to create view/tuple of views to a buffer for a given // instruction result. tensorflow::Status GetOrCreateView(const ::xla::HloInstruction* instr, SmallVectorImpl* values); + ::xla::StatusOr GetOrCreateArrayView( + const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape, + const ::xla::ShapeIndex& current_shape_index); + + ::xla::StatusOr RewriteFusionOperand(const ::xla::HloInstruction* root, + const ::xla::Shape& shape, + ::xla::ShapeIndex* shape_index, + OpBuilder* b, Location loc); + // Return an MLIR location for an HLO instruction. 
Location getLocation(::xla::HloInstruction* inst) { return NameLoc::get(builder_.getIdentifier(inst->name()), @@ -102,7 +113,8 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault { // // `slices_` is populated lazily in the `GetOrCreateView()` helper as we // process every instruction. - llvm::DenseMap> + absl::flat_hash_map, + Value> slices_; // The BufferAssignment computed by XLA ahead of time. diff --git a/tensorflow/compiler/mlir/xla/type_to_shape.cc b/tensorflow/compiler/mlir/xla/type_to_shape.cc index b725f56b455..3822e10089b 100644 --- a/tensorflow/compiler/mlir/xla/type_to_shape.cc +++ b/tensorflow/compiler/mlir/xla/type_to_shape.cc @@ -102,8 +102,7 @@ Shape TypeToShape(mlir::Type type) { if (ptype != PrimitiveType::PRIMITIVE_TYPE_INVALID) return ShapeUtil::MakeShape(ptype, {}); - if (type.isBF16() || type.isF32() || type.isF64() || - type.isa()) { + if (type.isIntOrFloat()) { auto* context = type.getContext(); mlir::emitError(mlir::UnknownLoc::get(context)) << "lowering should have been handled by primitive type lowering for " @@ -140,7 +139,8 @@ Shape TypeToShape(mlir::Type type) { for (const auto& e : llvm::enumerate(strides)) { strides_with_indices.push_back({e.value(), e.index()}); } - std::sort(strides_with_indices.begin(), strides_with_indices.end()); + std::stable_sort(strides_with_indices.begin(), + strides_with_indices.end()); llvm::SmallVector minor_to_major; int64_t stride = 1; @@ -149,7 +149,7 @@ Shape TypeToShape(mlir::Type type) { // Either the affine map is not perfectly strided, or the dimensions // recovered from strides don't match the actual dimensions in shapes. - if (stride != pr.first) return {}; + if (stride != pr.first && m.getShape()[pr.second] != 1) return {}; stride *= m.getShape()[pr.second]; } diff --git a/tensorflow/compiler/mlir/xla/type_to_shape_test.cc b/tensorflow/compiler/mlir/xla/type_to_shape_test.cc index a4a2bc42d99..97417748b64 100644 --- a/tensorflow/compiler/mlir/xla/type_to_shape_test.cc +++ b/tensorflow/compiler/mlir/xla/type_to_shape_test.cc @@ -196,5 +196,22 @@ TEST(TypeToShapeTest, ConvertMemRefToShape) { EXPECT_TRUE(ShapeUtil::Equal(converted, shape)); } +TEST(TypeToShapeTest, ConvertMemRefToShape2) { + Shape shape = ShapeUtil::MakeShapeWithLayout(PrimitiveType::C64, {2, 4, 3, 3}, + {2, 3, 1, 0}); + MLIRContext context; + mlir::Builder builder(&context); + + StatusOr mlir_type = + ConvertShapeToType(shape, builder); + ASSERT_TRUE(mlir_type.ok()); + mlir::Type type = mlir_type.ConsumeValueOrDie(); + Shape converted = TypeToShape(type); + EXPECT_TRUE(ShapeUtil::Equal( + converted, ShapeUtil::MakeShapeWithLayout(PrimitiveType::C64, + {2, 4, 3, 3}, {2, 3, 1, 0}))); + EXPECT_TRUE(ShapeUtil::Equal(converted, shape)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc b/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc index d5c598615b7..3ee70db1813 100644 --- a/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc +++ b/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc @@ -25,6 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h" #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" #include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h" +#include "tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h" #include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" @@ -33,19 +34,6 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/protobuf.h" -// NOLINTNEXTLINE -static llvm::cl::opt emit_use_tuple_arg( - "emit-use-tuple-args", - llvm::cl::desc( - "Emit HLO modules using tuples as args for the entry computation"), - llvm::cl::init(false)); - -// NOLINTNEXTLINE -static llvm::cl::opt emit_return_tuple( - "emit-return-tuple", - llvm::cl::desc("Emit HLO modules with entry computations returning tuple"), - llvm::cl::init(false)); - namespace xla { namespace { @@ -136,13 +124,16 @@ static StatusOr> HloModuleFromProto( return HloModule::CreateFromProto(module_proto, module_config); } -static mlir::LogicalResult MlirHloToHloTextTranslateFunction( - mlir::ModuleOp module, llvm::raw_ostream& output) { +static mlir::LogicalResult MlirHloToHloTextTranslateFunctionImpl( + mlir::ModuleOp module, llvm::raw_ostream& output, bool with_layouts) { if (!module) return mlir::failure(); HloProto hloProto; + mlir::MlirToHloConversionOptions options; + options.propagate_layouts = with_layouts; Status status = mlir::ConvertMlirHloToHlo( - module, &hloProto, emit_use_tuple_arg, emit_return_tuple); + module, &hloProto, emit_use_tuple_arg, emit_return_tuple, + /*shape_representation_fn=*/nullptr, options); if (!status.ok()) { LOG(ERROR) << "Module conversion failed: " << status; return mlir::failure(); @@ -158,9 +149,8 @@ static mlir::LogicalResult MlirHloToHloTextTranslateFunction( HloModule* hlo_module = statusOrHloModule.ValueOrDie().get(); - // We don't interpret or use layouts output << hlo_module->ToString( - HloPrintOptions().set_include_layout_in_shapes(false)); + HloPrintOptions().set_include_layout_in_shapes(with_layouts)); // Output alias information as comments in the HLO text. 
hlo_module->input_output_alias_config().ForEachAlias( @@ -174,6 +164,18 @@ static mlir::LogicalResult MlirHloToHloTextTranslateFunction( return mlir::success(); } +static mlir::LogicalResult MlirHloToHloTextTranslateFunction( + mlir::ModuleOp module, llvm::raw_ostream& output) { + return MlirHloToHloTextTranslateFunctionImpl(module, output, + /*with_layouts=*/false); +} + +static mlir::LogicalResult MlirHloToHloTextWithLayoutsTranslateFunction( + mlir::ModuleOp module, llvm::raw_ostream& output) { + return MlirHloToHloTextTranslateFunctionImpl(module, output, + /*with_layouts=*/true); +} + } // namespace xla static void RegisterInputDialects(mlir::DialectRegistry& registry) { @@ -188,6 +190,10 @@ static mlir::TranslateFromMLIRRegistration MlirHloToHloTextTranslate( "mlir-hlo-to-hlo-text", xla::MlirHloToHloTextTranslateFunction, RegisterInputDialects); +static mlir::TranslateFromMLIRRegistration MlirHloToHloTextWithLayoutsTranslate( + "mlir-hlo-to-hlo-text-with-layouts", + xla::MlirHloToHloTextWithLayoutsTranslateFunction, RegisterInputDialects); + static mlir::TranslateToMLIRRegistration HloToHloMlirTranslate( "hlo-to-mlir-hlo", xla::HloToMlirHloTranslateFunction); diff --git a/tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.cc b/tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.cc new file mode 100644 index 00000000000..7eb1fb40f5e --- /dev/null +++ b/tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.cc @@ -0,0 +1,35 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h" + +// NOLINTNEXTLINE +llvm::cl::opt emit_use_tuple_arg( + "emit-use-tuple-args", + llvm::cl::desc( + "Emit HLO modules using tuples as args for the entry computation"), + llvm::cl::init(false)); + +// NOLINTNEXTLINE +llvm::cl::opt emit_return_tuple( + "emit-return-tuple", + llvm::cl::desc("Emit HLO modules with entry computations returning tuple"), + llvm::cl::init(false)); + +// NOLINTNEXTLINE +llvm::cl::opt optimize_xla_hlo( + "optimize-xla-hlo", + llvm::cl::desc("Enable optimizations when translating XLA HLO -> LHLO"), + llvm::cl::init(true)); diff --git a/tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h b/tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h new file mode 100644 index 00000000000..14a2878dff8 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h @@ -0,0 +1,29 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_XLA_MLIR_TRANSLATE_CL_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_XLA_MLIR_TRANSLATE_CL_H_ + +#include "llvm/Support/CommandLine.h" + +// This file contains command-line options aimed to provide the parameters +// required by the MLIR module to XLA HLO conversion. It is only intended to be +// included by binaries. + +extern llvm::cl::opt emit_use_tuple_arg; +extern llvm::cl::opt emit_return_tuple; +extern llvm::cl::opt optimize_xla_hlo; + +#endif // TENSORFLOW_COMPILER_MLIR_XLA_XLA_MLIR_TRANSLATE_CL_H_ diff --git a/tensorflow/compiler/plugin/BUILD b/tensorflow/compiler/plugin/BUILD index c2ba5cb3ecd..dc1c2391e94 100644 --- a/tensorflow/compiler/plugin/BUILD +++ b/tensorflow/compiler/plugin/BUILD @@ -13,6 +13,8 @@ # limitations under the License. # ============================================================================== +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + """Configuration file for an XLA plugin. please don't check in changes to this file. to prevent changes appearing diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 30b8a7e5561..eb0cde57591 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") # buildifier: disable=same-origin-load load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") @@ -326,7 +327,6 @@ tf_xla_py_test( name = "self_adjoint_eig_op_test", size = "medium", srcs = ["self_adjoint_eig_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -392,7 +392,6 @@ tf_xla_py_test( size = "small", timeout = "moderate", srcs = ["matrix_inverse_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -415,7 +414,6 @@ tf_xla_py_test( size = "small", timeout = "moderate", srcs = ["matrix_solve_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -638,7 +636,6 @@ tf_xla_py_test( name = "extract_image_patches_op_test", size = "small", srcs = ["extract_image_patches_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -695,7 +692,6 @@ tf_xla_py_test( name = "fft_test", size = "medium", srcs = ["fft_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 6, tags = [ @@ -757,6 +753,7 @@ tf_xla_py_test( python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "notsan", # TODO(b/171000704): data race ], deps = [ ":xla_test", @@ -1017,7 +1014,6 @@ tf_xla_py_test( "cpu", "cpu_ondemand", ], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 5, tags = [ @@ -1280,7 +1276,9 @@ tf_xla_py_test( python_version = "PY3", shard_count = 10, tags = [ + "no_oss", # b/170479349 "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "notap", # b/170479349 "optonly", ], deps = [ @@ 
-1350,7 +1348,6 @@ tf_xla_py_test( python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - "notap", # b/162025277 ], deps = [ ":xla_test", @@ -1388,7 +1385,6 @@ tf_xla_py_test( name = "unary_ops_test", size = "medium", srcs = ["unary_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1483,6 +1479,7 @@ tf_xla_py_test( "//tensorflow/compiler/tf2xla/python:xla", "//tensorflow/python:array_ops", "//tensorflow/python:framework", + "//tensorflow/python:image_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", ], diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 07a41d67520..59c8c544347 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -474,7 +474,6 @@ class BinaryOpsTest(xla_test.XLATestCase): expected=np.array([1 << 32, 1 << 36, 1 << 32, 1 << 36], dtype=np.int64)) - @test_util.disable_mlir_bridge("Enable tf.NextAfter Compilation") def testNextAfter(self): for dtype in self.numeric_types: if dtype in [np.float32, np.float64]: diff --git a/tensorflow/compiler/tests/cholesky_op_test.py b/tensorflow/compiler/tests/cholesky_op_test.py index 4bd2dfd9244..41877d39381 100644 --- a/tensorflow/compiler/tests/cholesky_op_test.py +++ b/tensorflow/compiler/tests/cholesky_op_test.py @@ -28,7 +28,6 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -61,7 +60,7 @@ class CholeskyOpTest(xla_test.XLATestCase): dtypes.as_dtype(x.dtype), shape=x.shape) with self.test_scope(): chol = linalg_ops.cholesky(placeholder) - verification = math_ops.matmul(chol, chol, adjoint_b=True) + verification = test_util.matmul_without_tf32(chol, chol, adjoint_b=True) self._verifyCholeskyBase(sess, placeholder, x, chol, verification, atol) def testBasic(self): diff --git a/tensorflow/compiler/tests/data_format_ops_test.py b/tensorflow/compiler/tests/data_format_ops_test.py index 08d44256b50..ca833326a50 100644 --- a/tensorflow/compiler/tests/data_format_ops_test.py +++ b/tensorflow/compiler/tests/data_format_ops_test.py @@ -63,6 +63,22 @@ class XlaDataFormatDimMapTest(xla_test.XLATestCase): self._test([-4, -3, -2, -1, 0, 1, 2, 3], "qwer", "rewq", [3, 2, 1, 0, 3, 2, 1, 0]) + self._test(0, "NDHWC", "NCDHW", 0) + self._test(1, "NDHWC", "NCDHW", 2) + self._test(2, "NDHWC", "NCDHW", 3) + self._test(3, "NDHWC", "NCDHW", 4) + self._test(4, "NDHWC", "NCDHW", 1) + self._test([1, 4], "NDHWC", "NCDHW", [2, 1]) + self._test([1, 4, -2], "NDHWC", "NCDHW", [2, 1, 4]) + self._test([1, -3, -2], "NDHWC", "NCDHW", [2, 3, 4]) + self._test([[1, -4], [1, -1]], "NDHWC", "NCDHW", [[2, 2], [2, 1]]) + + self._test([1, -3, -2], "NDHWC", "NCDHW", [2, 3, 4]) + self._test([-5, -4, -3, -2, -1, 0, 1, 2, 3, 4], "NDHWC", "DHWNC", + [3, 0, 1, 2, 4, 3, 0, 1, 2, 4]) + self._test([-5, -4, -3, -2, -1, 0, 1, 2, 3, 4], "NDHWC", "WHDCN", + [4, 2, 1, 0, 3, 4, 2, 1, 0, 3]) + class XlaPermuteOpTest(xla_test.XLATestCase): diff --git a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py index 9d278cfbb28..08aad66abe1 100644 --- 
a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py +++ b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py @@ -29,7 +29,6 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -65,7 +64,8 @@ class MatrixTriangularSolveOpTest(xla_test.XLATestCase): with self.test_scope(): x = linalg_ops.matrix_triangular_solve( placeholder_a, placeholder_b, lower=lower, adjoint=adjoint) - verification = math_ops.matmul(placeholder_ca, x, adjoint_a=adjoint) + verification = test_util.matmul_without_tf32( + placeholder_ca, x, adjoint_a=adjoint) self._VerifyTriangularSolveBase(sess, placeholder_a, placeholder_ca, placeholder_b, a, clean_a, b, verification, atol) diff --git a/tensorflow/compiler/tests/qr_op_test.py b/tensorflow/compiler/tests/qr_op_test.py index 5fcf254db82..de318e9dfde 100644 --- a/tensorflow/compiler/tests/qr_op_test.py +++ b/tensorflow/compiler/tests/qr_op_test.py @@ -24,12 +24,18 @@ from absl.testing import parameterized import numpy as np from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test +@test_util.run_all_without_tensor_float_32( + "XLA QR op calls matmul. Also, matmul used for verification. Also with " + 'TensorFloat-32, mysterious "Unable to launch cuBLAS gemm" error ' + "occasionally occurs") +# TODO(b/165435566): Fix "Unable to launch cuBLAS gemm" error class QrOpTest(xla_test.XLATestCase, parameterized.TestCase): def AdjustedNorm(self, x): @@ -64,16 +70,26 @@ class QrOpTest(xla_test.XLATestCase, parameterized.TestCase): xx = math_ops.matmul(x, x, adjoint_a=True) identity = array_ops.matrix_band_part(array_ops.ones_like(xx), 0, 0) precision = self.AdjustedNorm(xx.eval() - self.evaluate(identity)) - self.assertTrue(np.all(precision < 5.0)) + self.assertTrue(np.all(precision < 6.0)) - def _test(self, dtype, shape, full_matrices): + def _random_matrix(self, dtype, shape): np.random.seed(1) - x_np = np.random.uniform( - low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype) + def rng(): + return np.random.uniform( + low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype) + + x_np = rng() + if np.issubdtype(dtype, np.complexfloating): + x_np += rng() * dtype(1j) + return x_np + + def _test(self, x_np, full_matrices, full_rank=True): + dtype = x_np.dtype + shape = x_np.shape with self.session() as sess: x_tf = array_ops.placeholder(dtype) - with self.test_scope(): + with self.device_scope(): q_tf, r_tf = linalg_ops.qr(x_tf, full_matrices=full_matrices) q_tf_val, r_tf_val = sess.run([q_tf, r_tf], feed_dict={x_tf: x_np}) @@ -91,24 +107,39 @@ class QrOpTest(xla_test.XLATestCase, parameterized.TestCase): np_q_reshape[i, :, :], _ = np.linalg.qr( x_reshape[i, :, :], mode="reduced") np_q = np.reshape(np_q_reshape, q_dims) - self.CompareOrthogonal(np_q, q_tf_val, min(shape[-2:])) + if full_rank: + # Q is unique up to sign/phase if the matrix is full-rank. 
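On the "unique up to sign/phase" comment above: for a matrix of full column rank, two QR factorizations can differ only by a per-column sign (real case) or unit phase (complex case), which is why CompareOrthogonal is skipped for the deliberately rank-deficient inputs added below. A small NumPy illustration of the real case (illustrative only, not part of the test):

# Sketch: two valid QR factorizations of a full-rank matrix differ only by
# per-column signs; normalizing against R's diagonal recovers a unique Q.
import numpy as np

np.random.seed(0)
a = np.random.randn(5, 3)
q, r = np.linalg.qr(a)

# Flip the sign of one column of Q and the matching row of R: still a
# valid factorization of the same matrix.
q2, r2 = q.copy(), r.copy()
q2[:, 1] *= -1.0
r2[1, :] *= -1.0
assert np.allclose(q2 @ r2, a)

def normalize(q_mat, r_mat):
    # Scale each column so the corresponding diagonal entry of R is positive.
    phases = np.sign(np.diag(r_mat))
    return q_mat * phases, (r_mat.T * phases).T

assert np.allclose(normalize(q, r)[0], normalize(q2, r2)[0])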
+ self.CompareOrthogonal(np_q, q_tf_val, min(shape[-2:])) self.CheckApproximation(x_np, q_tf_val, r_tf_val) self.CheckUnitary(q_tf_val) - SIZES = [1, 2, 5, 10, 32, 100, 300] - DTYPES = [np.float32] + SIZES = [1, 2, 5, 10, 32, 100, 300, 603] + DTYPES = [np.float32, np.complex64] PARAMS = itertools.product(SIZES, SIZES, DTYPES) @parameterized.parameters(*PARAMS) def testQR(self, rows, cols, dtype): - # TODO(b/111317468): Test other types. for full_matrices in [True, False]: # Only tests the (3, 2) case for small numbers of rows/columns. for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10): - self._test(dtype, batch_dims + (rows, cols), full_matrices) + x_np = self._random_matrix(dtype, batch_dims + (rows, cols)) + self._test(x_np, full_matrices) def testLarge2000x2000(self): - self._test(np.float32, (2000, 2000), full_matrices=True) + x_np = self._random_matrix(np.float32, (2000, 2000)) + self._test(x_np, full_matrices=True) + + @parameterized.parameters((23, 25), (513, 23)) + def testZeroColumn(self, rows, cols): + x_np = self._random_matrix(np.complex64, (rows, cols)) + x_np[:, 7] = 0. + self._test(x_np, full_matrices=True) + + @parameterized.parameters((4, 4), (514, 20)) + def testRepeatedColumn(self, rows, cols): + x_np = self._random_matrix(np.complex64, (rows, cols)) + x_np[:, 1] = x_np[:, 2] + self._test(x_np, full_matrices=True, full_rank=False) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/special_math_test.py b/tensorflow/compiler/tests/special_math_test.py index bd105bb5e95..5e7f8763743 100644 --- a/tensorflow/compiler/tests/special_math_test.py +++ b/tensorflow/compiler/tests/special_math_test.py @@ -31,6 +31,7 @@ import six from tensorflow.compiler.tests import xla_test from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op +from tensorflow.python.framework import test_util from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import gen_random_ops from tensorflow.python.ops import gradient_checker_v2 @@ -54,6 +55,16 @@ def _igammac(a, x): return math_ops.igammac(a, x) +@def_function.function(experimental_compile=True) +def _polygamma(n, x): + return math_ops.polygamma(n, x) + + +@def_function.function(experimental_compile=True) +def _zeta(a, q): + return math_ops.zeta(a, q) + + # This is df/da / df/dx, where f = igamma. def implicit_reparameterization_grad(a, x): log_prob = math_ops.xlogy(a - 1., x) - math_ops.lgamma(a) - x @@ -136,6 +147,208 @@ class Log1pTest(xla_test.XLATestCase, parameterized.TestCase): self._test_range(0., 3., dtype, rtol, atol, is_negative=False) +class ZetaTest(xla_test.XLATestCase, parameterized.TestCase): + + def setUp(self): + if flags.FLAGS.vary_seed: + entropy = os.urandom(64) + if six.PY2: + answer = int(entropy.encode('hex'), 16) + else: + answer = int.from_bytes(entropy, 'big') + np.random.seed(answer % (2**32 - 1)) + super(ZetaTest, self).setUp() + + def adjust_tolerance_for_tpu(self, dtype, rtol, atol): + if self.device not in ['TPU']: + return rtol, atol + + if dtype == np.float32: + return 2e-2, 1e-7 + return 2e-4, 1e-20 + + @test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR') + def testBadValues(self): + q = np.random.uniform(low=0.3, high=20., size=[10]) + with self.session() as sess: + with self.test_scope(): + y = _zeta(np.float64(1.), q) + actual = sess.run(y) + # When x == 1, this is the Harmonic series. 
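For reference on the ZetaTest expectations above: the Hurwitz zeta zeta(x, q) diverges at x == 1 (a shifted harmonic series) and is undefined for x < 1, and SciPy's Cephes-backed implementation reports inf and nan accordingly. A quick check, assuming scipy is available (this mirrors, but is not part of, the test):

# Sketch: the SciPy reference behavior the bad-value cases compare against.
import numpy as np
import scipy.special as sps

q = np.array([0.5, 2.0, 7.3])
assert np.all(np.isinf(sps.zeta(1.0, q)))  # x == 1: the series diverges
assert np.all(np.isnan(sps.zeta(0.1, q)))  # x < 1: Hurwitz zeta is undefined
assert np.isfinite(sps.zeta(1.1, 10.0))    # ordinary convergent case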
+ self.assertTrue(np.all(np.isinf(actual))) + + with self.session() as sess: + with self.test_scope(): + y = _zeta(np.float64(0.1), q) + actual = sess.run(y) + # When x < 1, this is undefined. + self.assertTrue(np.all(np.isnan(actual))) + + with self.session() as sess: + with self.test_scope(): + y = _zeta([1., 1.1], [-1.1, -1.]) + actual = sess.run(y) + + # When q is negative, zeta is not defined + # if q is an integer or x is not an integer. + self.assertTrue(np.all(np.isinf(actual))) + + @parameterized.parameters((np.float32, 1e-2, 1e-11), + (np.float64, 1e-4, 1e-30)) + @test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR') + def testLargeXSmallQ(self, dtype, rtol, atol): + rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol) + if self.device not in ['XLA_GPU', 'XLA_CPU'] and dtype == np.float64: + # TODO(b/165739664): Figure out why on TPU F64 Zeta sometimes returns + # infs. + self.skipTest( + 'Skipping test because some F64 operations are numerically ' + 'unstable on TPU.') + + x = np.random.uniform(low=100., high=200., size=[NUM_SAMPLES]).astype(dtype) + q = np.random.uniform(low=0.3, high=1., size=[NUM_SAMPLES]).astype(dtype) + + expected_values = sps.zeta(x, q) + with self.session() as sess: + with self.test_scope(): + y = _zeta(x, q) + actual = sess.run(y) + + self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol) + + @parameterized.parameters((np.float32, 1e-2, 1e-11), + (np.float64, 1e-4, 1e-30)) + @test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR') + def testSmallValues(self, dtype, rtol, atol): + rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol) + # Test values near zero. + x = np.random.uniform(low=1.1, high=10., size=[NUM_SAMPLES]).astype(dtype) + q = np.random.uniform( + low=np.finfo(dtype).tiny, high=1., size=[NUM_SAMPLES]).astype(dtype) + + expected_values = sps.zeta(x, q) + with self.session() as sess: + with self.test_scope(): + actual = sess.run(_zeta(x, q)) + self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol) + + @parameterized.parameters((np.float32, 1e-2, 1e-11), + (np.float64, 1e-4, 1e-30)) + @test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR') + def testMediumValues(self, dtype, rtol, atol): + rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol) + x = np.random.uniform(low=1.1, high=100., size=[NUM_SAMPLES]).astype(dtype) + q = np.random.uniform(low=1., high=1e1, size=[NUM_SAMPLES]).astype(dtype) + + expected_values = sps.zeta(x, q) + with self.session() as sess: + with self.test_scope(): + actual = sess.run(_zeta(x, q)) + self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol) + + @parameterized.parameters((np.float32, 2e-2, 1e-5), (np.float64, 1e-4, 1e-30)) + @test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR') + def testLargeValues(self, dtype, rtol, atol): + x = np.random.uniform( + low=100., high=int(1e3), size=[NUM_SAMPLES]).astype(dtype) + q = np.random.uniform( + low=1., high=int(1e1), size=[NUM_SAMPLES]).astype(dtype) + + expected_values = sps.zeta(x, q) + with self.session() as sess: + with self.test_scope(): + actual = sess.run(_zeta(x, q)) + self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol) + + +class PolygammaTest(xla_test.XLATestCase, parameterized.TestCase): + + def setUp(self): + if flags.FLAGS.vary_seed: + entropy = os.urandom(64) + if six.PY2: + answer = int(entropy.encode('hex'), 16) + else: + answer = int.from_bytes(entropy, 'big') + np.random.seed(answer % (2**32 - 1)) + 
super(PolygammaTest, self).setUp() + + def adjust_tolerance_for_tpu(self, dtype, rtol, atol): + if self.device not in ['TPU']: + return rtol, atol + + if dtype == np.float32: + return 2e-2, 1e-7 + return 2e-4, 1e-20 + + @test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR') + def testBadValues(self): + x = np.random.uniform(low=0.3, high=20., size=[10]) + with self.session() as sess: + with self.test_scope(): + y = _polygamma(np.float64(-1.), x) + actual = sess.run(y) + # Not defined for negative numbers. + self.assertTrue(np.all(np.isnan(actual))) + + with self.session() as sess: + with self.test_scope(): + y = _polygamma(np.float64(0.1), x) + actual = sess.run(y) + # Not defined for non-integers. + self.assertTrue(np.all(np.isnan(actual))) + + @parameterized.parameters((np.float32, 1e-2, 1e-11), + (np.float64, 1e-4, 1e-30)) + @test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR') + def testRecoverDigamma(self, dtype, rtol, atol): + rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol) + if self.device not in ['XLA_GPU', 'XLA_CPU'] and dtype == np.float64: + self.skipTest( + 'Skipping test because some F64 operations are ' + 'numerically unstable on TPU.' + ) + + x = np.random.uniform(low=0.1, high=50., size=[NUM_SAMPLES]).astype(dtype) + expected_values = sps.digamma(x) + with self.session() as sess: + with self.test_scope(): + y = _polygamma(dtype(0.), x) + actual = sess.run(y) + + self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol) + + @parameterized.parameters((np.float32, 1e-2, 1e-11), + (np.float64, 1e-4, 1e-30)) + @test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR') + def testSmallN(self, dtype, rtol, atol): + rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol) + # Test values near zero. 
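As background for testRecoverDigamma above: polygamma of order 0 is exactly the digamma function, and higher orders are its successive derivatives. A short SciPy sanity check, assuming scipy is available (illustrative, separate from the XLA kernel under test):

# Sketch: polygamma(0, x) == digamma(x); order 1 is the trigamma function.
import numpy as np
import scipy.special as sps

x = np.linspace(0.1, 50.0, 7)
assert np.allclose(sps.polygamma(0, x), sps.digamma(x))
# Trigamma at 1 is the classic pi^2 / 6.
assert np.allclose(sps.polygamma(1, 1.0), np.pi**2 / 6.0)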
+ n = np.random.randint(low=1, high=5, size=[NUM_SAMPLES]).astype(dtype) + x = np.random.uniform( + low=np.finfo(dtype).tiny, high=1., size=[NUM_SAMPLES]).astype(dtype) + + expected_values = sps.polygamma(n, x) + with self.session() as sess: + with self.test_scope(): + actual = sess.run(_polygamma(n, x)) + self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol) + + @parameterized.parameters((np.float32, 1e-2, 1e-11), + (np.float64, 1e-4, 1e-30)) + @test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR') + def testMediumLargeN(self, dtype, rtol, atol): + rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol) + n = np.random.randint(low=5, high=10, size=[NUM_SAMPLES]).astype(dtype) + x = np.random.uniform(low=1., high=1e1, size=[NUM_SAMPLES]).astype(dtype) + + expected_values = sps.polygamma(n, x) + with self.session() as sess: + with self.test_scope(): + actual = sess.run(_polygamma(n, x)) + self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol) + + class IgammaTest(xla_test.XLATestCase, parameterized.TestCase): def setUp(self): diff --git a/tensorflow/compiler/tests/stateful_random_ops_test.py b/tensorflow/compiler/tests/stateful_random_ops_test.py index 343969c40d7..239b99de19e 100644 --- a/tensorflow/compiler/tests/stateful_random_ops_test.py +++ b/tensorflow/compiler/tests/stateful_random_ops_test.py @@ -25,7 +25,9 @@ import numpy as np from tensorflow.compiler.tests import xla_test from tensorflow.python.client import device_lib +from tensorflow.python.compat import compat from tensorflow.python.eager import def_function +from tensorflow.python.framework import config from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops @@ -156,6 +158,10 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase): def testNewStateThreeFry(self): """Tests that the new state is correct (for ThreeFry). """ + if compat.forward_compatible(2020, 10, 25): + self.skipTest("The expected values in this test is inconsistent with " + "CPU/GPU. testXLAEqualsCPU has the correct checks of the " + "new states for the new version.") with ops.device(xla_device_name()): counter = 57 key = 0x1234 @@ -171,6 +177,10 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase): def testNewStatePhilox(self): """Tests that the new state is correct (for Philox). """ + if compat.forward_compatible(2020, 10, 25): + self.skipTest("The expected values in this test is inconsistent with " + "CPU/GPU. 
testXLAEqualsCPU has the correct checks of the " + "new states for the new version.") with ops.device(xla_device_name()): counter_low = 57 counter_high = 283 @@ -204,13 +214,39 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase): """Tests that XLA and CPU kernels generate the same integers.""" seed = 1234 shape = [315, 49] - with ops.device("/device:CPU:0"): - cpu = (random.Generator.from_seed(seed=seed, alg=random.RNG_ALG_PHILOX) - .uniform_full_int(shape=shape, dtype=dtype)) - with ops.device(xla_device_name()): - xla = (random.Generator.from_seed(seed=seed, alg=random.RNG_ALG_PHILOX) - .uniform_full_int(shape=shape, dtype=dtype)) - self.assertAllEqual(cpu, xla) + if compat.forward_compatible(2020, 10, 25): + with ops.device("/device:CPU:0"): + cpu_gen = random.Generator.from_seed( + seed=seed, alg=random.RNG_ALG_PHILOX) + with ops.device(xla_device_name()): + xla_gen = random.Generator.from_seed( + seed=seed, alg=random.RNG_ALG_PHILOX) + # Repeat multiple times to make sure that the state after + # number-generation are the same between CPU and XLA. + for _ in range(5): + with ops.device("/device:CPU:0"): + # Test both number-generation and skip + cpu = cpu_gen.uniform_full_int(shape=shape, dtype=dtype) + cpu_gen.skip(100) + with ops.device(xla_device_name()): + xla = xla_gen.uniform_full_int(shape=shape, dtype=dtype) + xla_gen.skip(100) + self.assertAllEqual(cpu, xla) + self.assertAllEqual(cpu_gen.state, xla_gen.state) + else: + # The old version doesn't guarantee that CPU and XLA are in the same state + # after number-generation, which is a bug. + with ops.device("/device:CPU:0"): + cpu = ( + random.Generator.from_seed( + seed=seed, alg=random.RNG_ALG_PHILOX).uniform_full_int( + shape=shape, dtype=dtype)) + with ops.device(xla_device_name()): + xla = ( + random.Generator.from_seed( + seed=seed, alg=random.RNG_ALG_PHILOX).uniform_full_int( + shape=shape, dtype=dtype)) + self.assertAllEqual(cpu, xla) def _testRngIsNotConstant(self, rng, dtype): # Tests that 'rng' does not always return the same value. @@ -364,4 +400,5 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase): if __name__ == "__main__": ops.enable_eager_execution() + config.set_soft_device_placement(False) test.main() diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index f9d792806b0..23e827f18e8 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -21,7 +21,11 @@ from __future__ import print_function import numpy as np from tensorflow.compiler.tests import xla_test +from tensorflow.python.compiler.xla import xla +from tensorflow.python.eager import def_function +from tensorflow.python.framework import config from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util from tensorflow.python.kernel_tests.random import util as \ random_test_util from tensorflow.python.ops import array_ops @@ -39,6 +43,26 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): allowed_types.update({dtypes.int32, dtypes.int64}) return self.all_tf_types & allowed_types + @test_util.run_v2_only + def testForcedCompile(self): + """Tests whole-function forced-compilation. + + This test checks that stateless_random_* can be used in forced-compilation + scenarios (e.g. TPU). 
The new version of stateless_random_* requires the + intermediate tensor `alg` to be compile-time constant, so we need to check + that this requirement is met. We use xla.compile instead of tf.function's + experimental_compile because the latter doesn't throw an error even if the + compile-time-constant constraint is not met. + """ + if config.list_logical_devices('TPU'): + self.skipTest('To accommodate OSS, xla.compile support for TPU is not ' + 'linked in.') + @def_function.function + def f(x): + return xla.compile( + lambda x: stateless.stateless_random_normal([], seed=x), [x]) + f([1, 2]) + def testDeterminism(self): # Stateless values should be equal iff the seeds are equal (roughly) with self.session(), self.test_scope(): @@ -138,7 +162,7 @@ class StatelessRandomOpsBenchmark(test.Benchmark): def _benchmarkUniform(self, name, dtype, use_xla_jit): - def BuilderFn(): + def builder_fn(): shape = (10, 1000, 1000) seed_var = variables.Variable((312, 456), dtype=dtypes.int32, @@ -147,7 +171,7 @@ class StatelessRandomOpsBenchmark(test.Benchmark): shape, seed=seed_var, dtype=dtype) return '%s.shape%s' % (name, shape), [random_t] - xla_test.Benchmark(self, BuilderFn, use_xla_jit=use_xla_jit, device='cpu') + xla_test.Benchmark(self, builder_fn, use_xla_jit=use_xla_jit, device='cpu') def benchmarkUniformF32(self): self._benchmarkUniform( @@ -167,4 +191,5 @@ class StatelessRandomOpsBenchmark(test.Benchmark): if __name__ == '__main__': + config.set_soft_device_placement(False) test.main() diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index b5f82bcff12..f3f6fa8ae52 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -542,7 +542,7 @@ class UnaryOpsTest(xla_test.XLATestCase): for dtype in self.float_types: def quantize_and_dequantize_v2(x): - return array_ops.quantize_and_dequantize_v2( + return array_ops.quantize_and_dequantize( x, -127, 127, signed_input=True, num_bits=8) self._assertOpOutputMatchesExpected( @@ -551,7 +551,7 @@ class UnaryOpsTest(xla_test.XLATestCase): expected=np.array([-1., -0.5, 0., 0.296875], dtype=dtype)) def quantize_and_dequantize_v2_round_half_up(x): - return array_ops.quantize_and_dequantize_v2( + return array_ops.quantize_and_dequantize( x, -1, 1.0, @@ -575,7 +575,7 @@ class UnaryOpsTest(xla_test.XLATestCase): dtype=dtype)) def quantize_and_dequantize_v2_round_half_to_even(x): - return array_ops.quantize_and_dequantize_v2( + return array_ops.quantize_and_dequantize( x, -1.0, 1.0, diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py index 3e9f5e8c5dd..b80b6263992 100644 --- a/tensorflow/compiler/tests/xla_ops_test.py +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -18,12 +18,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools + from absl.testing import parameterized import numpy as np from tensorflow.compiler.tests import xla_test from tensorflow.compiler.tf2xla.python import xla from tensorflow.compiler.xla import xla_data_pb2 +from tensorflow.python.eager import def_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import function @@ -299,6 +302,78 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase): args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),), expected=np.array([0, 45, 120, 231], dtype=dtype)) + 
@test_util.disable_mlir_bridge('Not supported yet') + def testVariadicReduce(self): + for dtype in set(self.numeric_types).intersection( + set([np.float32, np.complex64])): + + @def_function.function + def kahan_sum_reducer(t0, t1): + (s0, c0), (s1, c1) = t0, t1 + s0minusc = s0 - (c0 + c1) + t = s1 + s0minusc + c = (t - s1) - s0minusc + s = t + return s, c + + def kahan_sum_reduction(dims, output_idx): + + def fn(x): + arg = array_ops.zeros([], dtype) # pylint: disable=cell-var-from-loop + reducer = kahan_sum_reducer.get_concrete_function( + (arg, arg), (arg, arg)) + + return xla.variadic_reduce( + (x, array_ops.zeros_like(x)), + init_value=(arg, arg), + dimensions_to_reduce=dims, + reducer=reducer)[output_idx] + + return fn + + xs = np.array([1e5, np.pi, -1e5, np.exp(1.)]) + xs = np.array([xs, xs[::-1] / 3, xs / 7], dtype) + self._assertOpOutputMatchesExpected( + kahan_sum_reduction(dims=[], output_idx=0), + args=(xs,), expected=xs) + self._assertOpOutputMatchesExpected( + kahan_sum_reduction(dims=[], output_idx=1), + args=(xs,), expected=np.zeros_like(xs)) + shuffle_indices = np.argsort(np.random.randn(xs.shape[0])) + self._assertOpOutputMatchesExpected( + kahan_sum_reduction(dims=[0], output_idx=0), + args=(xs[shuffle_indices],), + expected=np.array([np.exp(1) / 3 + 1e5 * 8 / 7, + np.pi * 8 / 7 - 1e5 / 3, + -1e5 * 8 / 7 + np.pi / 3, + np.exp(1) * 8 / 7 + 1e5 / 3], dtype=dtype)) + error_term_equality = functools.partial(self.assertAllClose, atol=.005) + self._assertOpOutputMatchesExpected( + kahan_sum_reduction(dims=[0], output_idx=1), + args=(xs[shuffle_indices],), expected=np.zeros_like(xs[0]), + equality_fn=error_term_equality) + shuffle_indices = np.argsort(np.random.randn(xs.shape[1])) + self._assertOpOutputMatchesExpected( + kahan_sum_reduction(dims=[1], output_idx=0), + args=(xs[:, shuffle_indices],), + expected=np.array([np.pi + np.exp(1.), + (np.pi + np.exp(1.)) / 3, + (np.pi + np.exp(1.)) / 7], dtype=dtype)) + self._assertOpOutputMatchesExpected( + kahan_sum_reduction(dims=[1], output_idx=1), + args=(xs[:, shuffle_indices],), expected=np.zeros_like(xs[:, 0]), + equality_fn=error_term_equality) + # Now, shuffle both dims. + xs = xs[np.argsort(np.random.randn(xs.shape[0]))] + xs = xs[:, np.argsort(np.random.randn(xs.shape[1]))] + self._assertOpOutputMatchesExpected( + kahan_sum_reduction(dims=[0, 1], output_idx=0), + args=(xs,), expected=dtype((np.pi + np.exp(1.)) * 31 / 21)) + self._assertOpOutputMatchesExpected( + kahan_sum_reduction(dims=[0, 1], output_idx=1), + args=(xs,), expected=dtype(0), + equality_fn=error_term_equality) + @test_util.disable_mlir_bridge('Not supported yet') def testSelectAndScatter(self): for dtype in set(self.numeric_types).intersection( diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py index 8c31629c234..de97c6ff210 100644 --- a/tensorflow/compiler/tests/xla_test.py +++ b/tensorflow/compiler/tests/xla_test.py @@ -237,8 +237,8 @@ class XLATestCase(test.TestCase): 'test_session not supported on XLATestCase, please use session') @contextlib.contextmanager - def test_scope(self): - """Test scope that runs tests on `self.device`. + def device_scope(self): + """Scope that runs tests on `self.device`. Yields: A scope to apply to the operators under test. @@ -246,6 +246,15 @@ class XLATestCase(test.TestCase): with ops.device('device:{}:0'.format(self.device)): yield + def test_scope(self): + """Deprecated alias of `device_scope`. 
+ + This should be avoided as the name starts with `test`, so test runners + treat it as a test. This interferes with class decorators that operate on + each test method. + """ + return self.device_scope() + def Benchmark(tf_bench, builder_fn, diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index 44fb5513886..a82c1c485b9 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -3,13 +3,13 @@ # and provide TensorRT operators and converter package. # APIs are meant to change over time. +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_copts", "tf_cuda_library", "tf_custom_op_library_additional_deps", - "tf_gen_op_libs", "tf_gen_op_wrapper_py", ) @@ -21,6 +21,9 @@ load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "pybind_extension") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") load( "//tensorflow/core/platform:build_config.bzl", "tf_additional_all_protos", @@ -33,8 +36,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - cc_library( name = "tensorrt_stub", srcs = if_tensorrt([ @@ -69,7 +70,7 @@ tf_cuda_cc_test( deps = [ "//tensorflow/core/common_runtime/gpu:gpu_init", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor", + "//tensorflow/core/platform:stream_executor", "//tensorflow/core:test", "//tensorflow/core:test_main", ] + if_tensorrt([ @@ -107,18 +108,20 @@ cc_library( ":common_utils", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", - "@local_config_cuda//cuda:cuda_headers", "//tensorflow/core:framework", "//tensorflow/core:gpu_headers_lib", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:lib_proto_parsing", - "//tensorflow/core:stream_executor", + "//tensorflow/core/platform:stream_executor", "//tensorflow/core:stream_executor_headers_lib", "//tensorflow/core/common_runtime:core_cpu_lib_no_ops", "//tensorflow/core/grappler/costs:graph_properties", "//tensorflow/stream_executor/lib", - ] + if_tensorrt([":tensorrt_lib"]) + tf_custom_op_library_additional_deps(), + ] + if_tensorrt([ + ":tensorrt_lib", + "@local_config_cuda//cuda:cuda_headers", + ]) + tf_custom_op_library_additional_deps(), alwayslink = 1, ) @@ -480,7 +483,7 @@ tf_cuda_cc_test( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:tensor_testutil", + "//tensorflow/core/framework:tensor_testutil", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", @@ -492,15 +495,32 @@ tf_cuda_cc_test( # Library for the segmenting portion of TensorRT operation creation cc_library( - name = "segment", - srcs = ["segment/segment.cc"], + name = "union_find", + srcs = ["segment/union_find.cc"], hdrs = [ - "segment/segment.h", "segment/union_find.h", ], copts = tf_copts(), + deps = [ + ":utils", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", + ], +) + +cc_library( + name = "segment", + srcs = ["segment/segment.cc"], + hdrs = [ + "segment/segment.h", + ], + copts = tf_copts(), deps = [ ":common_utils", + ":union_find", ":utils", "//tensorflow/core:graph", "//tensorflow/core:lib", diff --git 
a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc index c4fc3e4f5da..28c08cd2ddc 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc @@ -733,6 +733,8 @@ Status ConvertAfterShapes(const ConversionParams& params) { } segment_options.minimum_segment_size = params.minimum_segment_size; segment_options.use_implicit_batch = params.use_implicit_batch; + if (segment_options.use_implicit_batch) + segment_options.maximum_batch_size = params.max_batch_size; segment_options.allow_dynamic_non_batch_dim = AllowDynamicNonBatchDimension(params); @@ -753,13 +755,10 @@ Status ConvertAfterShapes(const ConversionParams& params) { // Get the EngineInfo for each segment. std::unordered_map node_map; TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map)); - float total_num_nodes_in_segments = 0.; std::vector engine_segments; engine_segments.reserve(initial_segments.size()); std::vector reverse_topo_order; GetPostOrder(graph, &reverse_topo_order); - size_t total_engine_bytes_size = 0; - std::vector engine_bytes_size; segment::SegmentNodesVector converted_segments; converted_segments.reserve(initial_segments.size()); string engine_name_prefix = @@ -791,9 +790,6 @@ Status ConvertAfterShapes(const ConversionParams& params) { continue; } - engine_bytes_size.push_back(curr_engine.segment_graph_def.ByteSizeLong()); - total_engine_bytes_size += engine_bytes_size.back(); - total_num_nodes_in_segments += curr_segment.size(); engine_segments.push_back(std::move(curr_engine)); converted_segments.push_back(std::move(curr_segment)); @@ -832,13 +828,9 @@ Status ConvertAfterShapes(const ConversionParams& params) { engine_nodes.resize(engine_segments.size()); for (int i = 0; i < engine_segments.size(); ++i) { auto& engine = engine_segments.at(i); - // Partition the workspace size by the average of node ratio and segment - // graphdef size - engine.max_workspace_size_bytes = - params.max_workspace_size_bytes * - (engine_bytes_size.at(i) / total_engine_bytes_size + - converted_segments.at(i).size() / total_num_nodes_in_segments) / - 2.0; + // TODO(b/170762693): implement the heuristic to calculate + // max_workspace_size_bytes. + engine.max_workspace_size_bytes = params.max_workspace_size_bytes; VLOG(1) << "Assigned " << engine.max_workspace_size_bytes << " bytes to " << engine.engine_name; auto status = CreateTRTNode(params, engine_segments, i, diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc index 3b0553426c0..be3bb51dbed 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc @@ -151,7 +151,8 @@ TEST(ConvertGraphTest, GetDeviceAndAllocator) { class ConvertAfterShapesTest : public ::testing::Test { public: - Status RunConvertAfterShape(Scope s, GraphDef* output_graph_def) { + Status RunConvertAfterShape(Scope s, GraphDef* output_graph_def, + int maximum_batch_size = 1000) { // Create GraphProperties. 
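For context on the convert_graph.cc change above: the removed heuristic split params.max_workspace_size_bytes across engines by averaging each segment's share of serialized graph bytes and of node count; until the TODO lands, every engine now simply receives the full budget. A rough Python sketch of the old partitioning (illustrative only):

# Sketch: the removed per-engine workspace partitioning, roughly.
def old_workspace_split(max_workspace_bytes, engine_bytes, engine_nodes):
    total_bytes = float(sum(engine_bytes))
    total_nodes = float(sum(engine_nodes))
    return [
        max_workspace_bytes * (b / total_bytes + n / total_nodes) / 2.0
        for b, n in zip(engine_bytes, engine_nodes)
    ]


# Two segments: the per-engine shares sum back to the overall budget.
split = old_workspace_split(8 << 20, engine_bytes=[600, 200], engine_nodes=[30, 10])
assert abs(sum(split) - (8 << 20)) < 1e-6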
grappler::GrapplerItem item; TF_EXPECT_OK(s.ToGraphDef(&item.graph)); @@ -162,6 +163,7 @@ class ConvertAfterShapesTest : public ::testing::Test { const std::vector output_names{"output"}; ConversionParams params; params.output_names = &output_names; + params.max_batch_size = maximum_batch_size; params.max_workspace_size_bytes = 8 << 20; params.output_graph_def = output_graph_def; params.minimum_segment_size = 1; diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index c0c3f25177e..d09485c35c7 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -26,6 +26,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" @@ -429,11 +430,52 @@ Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l, return Status::OK(); } +std::string GetLayerNameSuffix(absl::string_view sub_op_name, + absl::optional sub_op_instance) { + std::string op_suffix(sub_op_name); + if (sub_op_instance.has_value()) { + op_suffix = + absl::StrCat(op_suffix, "_", std::to_string(sub_op_instance.value())); + } + return op_suffix; +} + +// Sets the name of an ILayer using the name of the node_def. If the operation +// represented by the ILayer is generated by the converter to support the +// conversion of node_def, callers need to specify a non-empty sub_op_name +// to be appended to the name of node_def to avoid layer name conflicts. If the +// operation is generated multiple times, callers also need to specify +// sub_op_instance to be appended to the name of the layers to avoid layer name +// conflicts. +void SetLayerName(nvinfer1::ILayer* layer, const NodeDef& node_def, + absl::string_view sub_op_name = "", + absl::optional sub_op_instance = absl::nullopt) { + std::string sub_op_suffix = GetLayerNameSuffix(sub_op_name, sub_op_instance); + if (sub_op_suffix.empty()) { + layer->setName(node_def.name().c_str()); + } else { + layer->setName(absl::StrCat(node_def.name(), "-", sub_op_suffix).c_str()); + } +} + +// Sets the name of an ILayer using the format of +// "main_op_name"_"sub_op_name"_"sub_op_instance". 
+void SetLayerName(nvinfer1::ILayer* layer, absl::string_view main_op_name, + absl::string_view sub_op_name, + absl::optional sub_op_instance = absl::nullopt) { + std::string layer_name_suffix = + GetLayerNameSuffix(sub_op_name, sub_op_instance); + layer->setName(absl::StrCat(main_op_name, "-", layer_name_suffix).c_str()); +} + nvinfer1::ITensor* Converter::CreateConstantLayer( const TRT_ShapedWeights& weights, const nvinfer1::Dims& dims) { nvinfer1::Weights trt_weights = weights.GetTrtWeights(); nvinfer1::IConstantLayer* layer = network()->addConstant(dims, trt_weights); if (!layer) return nullptr; + SetLayerName(layer, "_tftrt_constant_", + std::to_string(next_constant_layer_id_)); + next_constant_layer_id_++; nvinfer1::ITensor* trt_tensor = layer->getOutput(0); #if !IS_TRT_VERSION_GE(5, 1, 3, 0) // TODO(laigd): there is a bug in TensorRT 5.0 library that, if we don't set @@ -1313,6 +1355,7 @@ Status Converter::AddInputTensor(const string& name, nvinfer1::DataType dtype, Status Converter::RenameAndMarkOutputTensors( const std::vector& output_tensors) { + int output_index = 0; for (const auto& output : output_tensors) { TRT_TensorOrWeights tensor_or_weights; TF_RETURN_IF_ERROR( @@ -1341,6 +1384,7 @@ Status Converter::RenameAndMarkOutputTensors( nvinfer1::IShuffleLayer* layer = network()->addShuffle(*tensor); TFTRT_RETURN_ERROR_IF_NULLPTR( layer, StrCat("Output Copy for ", tensor->getName())); + SetLayerName(layer, tensor->getName(), "shuffle", output_index); MarkQuantizationRangesAsInferrable(tensor, layer->getOutput(0)); tensor = layer->getOutput(0); } @@ -1349,6 +1393,7 @@ Status Converter::RenameAndMarkOutputTensors( // Set type after marking as output. TRT only supports setType for engine // outputs and inputs (type is inferred otherwise). tensor->setType(output.trt_dtype); + output_index++; VLOG(1) << "Marking output TRT tensor " << output.source_tensor_name << " with data type " << DebugString(output.trt_dtype) << ", which feeds TF node " << output.dest_node_name; @@ -1475,8 +1520,9 @@ Status Converter::GetTensorOrWeights(const string& name, Status Converter::TransposeTensor(nvinfer1::ITensor* input_tensor, const std::vector& order_with_batch_dim, - absl::string_view name, - nvinfer1::ITensor** output_tensor) { + nvinfer1::ITensor** output_tensor, + const NodeDef& node_def, + absl::string_view sub_op_name) { const auto dims = input_tensor->getDimensions(); const int order_size = use_implicit_batch_ ? order_with_batch_dim.size() - 1 : order_with_batch_dim.size(); @@ -1491,7 +1537,8 @@ Status Converter::TransposeTensor(nvinfer1::ITensor* input_tensor, nvinfer1::IShuffleLayer* layer = this->network()->addShuffle(*input_tensor); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, "TF-TRT Internal Transpose"); - layer->setName(std::basic_string(name).c_str()); + SetLayerName(layer, node_def, sub_op_name); + MarkQuantizationRangesAsInferrable(input_tensor, layer->getOutput(0)); nvinfer1::Permutation permutation; @@ -1555,7 +1602,9 @@ Status Converter::GetWeightRange(const TRT_ShapedWeights& weights, Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input, const nvinfer1::Dims& dims, const bool validation_only, - nvinfer1::ITensor** tensor) { + nvinfer1::ITensor** tensor, + const NodeDef& node_def, + absl::optional op_instance) { const nvinfer1::Dims input_dims = input.GetTrtDims(); // If one of input_dims and dims doesn't have static shape, it means some of // the dims are unknown or need to be inferred. 
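Both SetLayerName overloads added above funnel through GetLayerNameSuffix: an optional instance number is attached to the sub-op name with an underscore, and the resulting suffix is joined to the node (or main-op) name with a dash. A small self-contained sketch of that naming scheme, using std::optional in place of the absl::optional used in the converter and example node names that are purely illustrative:

    #include <iostream>
    #include <optional>
    #include <string>

    // Mirrors GetLayerNameSuffix + SetLayerName: "<main>-<sub>[_<instance>]",
    // falling back to the bare main/node name when there is no suffix.
    std::string LayerName(const std::string& main_op_name,
                          const std::string& sub_op_name,
                          std::optional<int> sub_op_instance = std::nullopt) {
      std::string suffix = sub_op_name;
      if (sub_op_instance.has_value()) {
        suffix += "_" + std::to_string(*sub_op_instance);
      }
      return suffix.empty() ? main_op_name : main_op_name + "-" + suffix;
    }

    int main() {
      std::cout << LayerName("conv1", "to_NCHW") << "\n";       // conv1-to_NCHW
      std::cout << LayerName("biasadd", "shuffle", 1) << "\n";  // biasadd-shuffle_1
      std::cout << LayerName("_tftrt_constant_", "3") << "\n";  // _tftrt_constant_-3
      std::cout << LayerName("my_matmul", "") << "\n";          // my_matmul
      return 0;
    }

Giving sibling layers of one node distinct sub_op_name or sub_op_instance values (for example the 0/1 instances used for the two BiasAdd reshapes later in this file) is what keeps the generated names unique, which the engine-construction check added further below relies on.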
And we don't do further checks @@ -1586,6 +1635,7 @@ Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input, nvinfer1::IShuffleLayer* layer = this->network()->addShuffle(*input.tensor()); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, "TF-TRT Internal Reshape"); + SetLayerName(layer, node_def, "shuffle", op_instance); layer->setReshapeDimensions(dims); MarkQuantizationRangesAsInferrable(input.tensor(), layer->getOutput(0)); *tensor = layer->getOutput(0); @@ -2086,6 +2136,7 @@ Status Conv2DPaddingHelper(OpConverterParams* params, const TFAttrs& attrs, *tensor, nvinfer1::DimsHW((*padding)[0].first, (*padding)[1].first), nvinfer1::DimsHW((*padding)[0].second, (*padding)[1].second)); TFTRT_RETURN_ERROR_IF_NULLPTR(pad_layer, params->node_def.name()); + SetLayerName(pad_layer, params->node_def, "pad"); params->converter->MarkQuantizationRangesAsInferrable( tensor, pad_layer->getOutput(0)); *padding = {{0, 0}, {0, 0}}; @@ -2186,7 +2237,7 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group, const bool need_transpose = (data_format == "NHWC"); if (need_transpose) { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - tensor, {0, 3, 1, 2}, StrCat(node_def.name(), "_to_NCHW"), &tensor)); + tensor, {0, 3, 1, 2}, &tensor, node_def, "to_NCHW")); } // Dimensions of transposed tensor. const auto tensor_dim = tensor->getDimensions(); @@ -2252,7 +2303,6 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group, #else layer->setPadding(nvinfer1::DimsHW{padding[0].first, padding[1].first}); #endif - layer->setName(node_def.name().c_str()); layer->setNbGroups(num_groups); conv_layer = layer; } else { @@ -2269,11 +2319,11 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group, #else layer->setPadding(nvinfer1::DimsHW{padding[0].first, padding[1].first}); #endif - layer->setName(node_def.name().c_str()); layer->setNbGroups(num_groups); layer->setDilation(dilation); conv_layer = layer; } + SetLayerName(conv_layer, node_def, "conv"); nvinfer1::ITensor* output_tensor = conv_layer->getOutput(0); // Add an extra padding for Deconv because TRT doesn't accept the // argument output_shape and thus the TRT output shape could be wrong @@ -2306,13 +2356,13 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group, params->converter->network()->addPadding(*output_tensor, pre_padding, post_padding); output_tensor = padding_layer->getOutput(0); + SetLayerName(padding_layer, node_def, "pad"); } } // Restore transpose. if (need_transpose) { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - output_tensor, {0, 2, 3, 1}, StrCat(node_def.name(), "_to_NHWC"), - &output_tensor)); + output_tensor, {0, 2, 3, 1}, &output_tensor, node_def, "to_NHWC")); } params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); @@ -2370,7 +2420,7 @@ Status ConvertTranspose(OpConverterParams* params) { // Start conversion. 
nvinfer1::ITensor* output_tensor = nullptr; TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - input_tensor, perm, params->node_def.name(), &output_tensor)); + input_tensor, perm, &output_tensor, params->node_def)); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); } @@ -2401,6 +2451,7 @@ Status ConvertShape(OpConverterParams* params) { nvinfer1::IShapeLayer* shape_layer = params->converter->network()->addShape(*inputs.at(0).tensor()); TFTRT_RETURN_ERROR_IF_NULLPTR(shape_layer, params->node_def.name()); + SetLayerName(shape_layer, params->node_def, "shape"); params->outputs->push_back(TRT_TensorOrWeights(shape_layer->getOutput(0))); return Status::OK(); #else @@ -2471,7 +2522,7 @@ Status ConvertReshape(OpConverterParams* params) { nvinfer1::ITensor* output_tensor = nullptr; TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( input_tensor, output_nonbatch_dims, params->validation_only, - &output_tensor)); + &output_tensor, params->node_def)); if (params->validation_only) return Status::OK(); // Record the conversion result. @@ -2514,7 +2565,8 @@ Status ConvertExpandDims(OpConverterParams* params) { nvinfer1::Dims new_dims; TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims)); TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - input_tensor, new_dims, /*validation_only=*/false, &output_tensor)); + input_tensor, new_dims, /*validation_only=*/false, &output_tensor, + params->node_def)); } params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); @@ -2524,7 +2576,8 @@ Status Converter::DynamicReshape(nvinfer1::ITensor* input, std::vector> slices, OpConverterParams* params, nvinfer1::ITensor** output, - std::vector size_for_added_dims) { + std::vector size_for_added_dims, + absl::optional op_instance) { *output = nullptr; // DynamicReshape relies on INetworkDefinition::addShape that was introduced // in TensorRT 6. @@ -2536,9 +2589,11 @@ Status Converter::DynamicReshape(nvinfer1::ITensor* input, nvinfer1::ITensor* shape = network()->addShape(*input)->getOutput(0); // Build new shape = shape[:trt_axis] + [1] + shape[trt_axis:] std::vector concat_inputs; - for (int i = 0; i < std::max(slices.size(), size_for_added_dims.size()); - i++) { + int max_num_slices = std::max(slices.size(), size_for_added_dims.size()); + int op_instance_value = op_instance.has_value() ? 
op_instance.value() : 0; + for (int i = 0; i < max_num_slices; i++) { nvinfer1::ITensor* tensor; + int slice_instance = i * max_num_slices + op_instance_value; // maybe_add_a_dimension(i); if (i < size_for_added_dims.size() && size_for_added_dims[i] >= 0) { TF_RETURN_IF_ERROR( @@ -2546,11 +2601,11 @@ Status Converter::DynamicReshape(nvinfer1::ITensor* input, concat_inputs.push_back(tensor); } if (i < slices.size()) { - concat_inputs.push_back( - network() - ->addSlice(*shape, {1, {slices[i].first}}, - {1, {slices[i].second - slices[i].first}}, {1, {1}}) - ->getOutput(0)); + nvinfer1::ISliceLayer* slice_layer = network()->addSlice( + *shape, {1, {slices[i].first}}, + {1, {slices[i].second - slices[i].first}}, {1, {1}}); + concat_inputs.push_back(slice_layer->getOutput(0)); + SetLayerName(slice_layer, params->node_def, "slice", slice_instance); } } nvinfer1::IConcatenationLayer* concat_layer = network()->addConcatenation( @@ -2560,6 +2615,7 @@ Status Converter::DynamicReshape(nvinfer1::ITensor* input, nvinfer1::ITensor* new_shape = concat_layer->getOutput(0); // Reshape input using new shape nvinfer1::IShuffleLayer* shuffle = network()->addShuffle(*input); + SetLayerName(shuffle, params->node_def, "shuffle", op_instance); shuffle->setInput(1, *new_shape); *output = shuffle->getOutput(0); return Status::OK(); @@ -2572,7 +2628,8 @@ Status Converter::DynamicReshape(nvinfer1::ITensor* input, Status Converter::DynamicExpandDims(nvinfer1::ITensor* input, const nvinfer1::Dims& dims, int axis, OpConverterParams* params, - nvinfer1::ITensor** output) { + nvinfer1::ITensor** output, + absl::optional op_instance) { if (params->validation_only) { *output = nullptr; return errors::Internal( @@ -2588,7 +2645,7 @@ Status Converter::DynamicExpandDims(nvinfer1::ITensor* input, if (axis != dims.nbDims) { slices.push_back(std::pair{axis, dims.nbDims}); } - return DynamicReshape(input, slices, params, output, extra_dims); + return DynamicReshape(input, slices, params, output, extra_dims, op_instance); } Status Converter::SqueezeTensor(nvinfer1::ITensor* input, @@ -2616,7 +2673,8 @@ Status Converter::SqueezeTensor(nvinfer1::ITensor* input, VLOG(2) << "input_dims" << input_dims; TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(*input_dims, &new_dims)); TF_RETURN_IF_ERROR(PrepareTensorForShape(TRT_TensorOrWeights(input), new_dims, - /*validation_only=*/false, output)); + /*validation_only=*/false, output, + params->node_def)); return Status::OK(); } @@ -2680,11 +2738,11 @@ Status ConvertSqueeze(OpConverterParams* params) { } template -Status ConvertStridedSliceHelper(OpConverterParams* params, - const TRT_TensorOrWeights& input, - Container begin, Container size, - const Container& stride, - const nvinfer1::Dims* final_shape = nullptr) { +Status ConvertStridedSliceHelper( + OpConverterParams* params, const TRT_TensorOrWeights& input, + Container begin, Container size, const Container& stride, + const nvinfer1::Dims* final_shape = nullptr, + absl::optional op_instance = absl::nullopt) { const auto& node_def = params->node_def; // Get input dims. nvinfer1::Dims dims = input.GetTrtDims(); @@ -2709,6 +2767,7 @@ Status ConvertStridedSliceHelper(OpConverterParams* params, node_def.op(), ", at ", node_def.name()); } } + // TRT 5.1 adds ISliceLayer. For older versions, we attempt to use the // padding layer with negative padding. 
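The DynamicReshape change above gives every slice layer it emits an instance number computed as i * max_num_slices + op_instance, so repeated calls for the same node do not reuse a suffix. A small sketch of the values that arithmetic produces, with a made-up max_num_slices of 3:

    #include <cstdio>

    int main() {
      // Hypothetical reshape that emits max_num_slices = 3 slice layers per call.
      const int max_num_slices = 3;
      const int op_instances[] = {0, 1};  // two DynamicReshape calls for one node
      for (int op_instance_value : op_instances) {
        for (int i = 0; i < max_num_slices; ++i) {
          // Same arithmetic as slice_instance in DynamicReshape above.
          const int slice_instance = i * max_num_slices + op_instance_value;
          std::printf("op_instance=%d i=%d -> slice_%d\n", op_instance_value, i,
                      slice_instance);
        }
      }
      return 0;
    }

As long as the caller-supplied op_instance stays below max_num_slices, the resulting slice_N suffixes are distinct across calls.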
#if IS_TRT_VERSION_GE(5, 1, 3, 1) @@ -2723,12 +2782,13 @@ Status ConvertStridedSliceHelper(OpConverterParams* params, nvinfer1::ISliceLayer* layer = params->converter->network()->addSlice( *input.tensor(), begin_dims, size_dims, stride_dims); + SetLayerName(layer, params->node_def, "slice", op_instance); nvinfer1::ITensor* tensor = layer->getOutput(0); // Reshape for shrink_axis. if (final_shape) { TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( TRT_TensorOrWeights(tensor), *final_shape, /*validation_only=*/false, - &tensor)); + &tensor, node_def, op_instance)); } params->outputs->push_back(TRT_TensorOrWeights(tensor)); return Status::OK(); @@ -2782,6 +2842,7 @@ Status ConvertStridedSliceHelper(OpConverterParams* params, if (params->validation_only) return Status::OK(); nvinfer1::IShuffleLayer* layer = params->converter->network()->addShuffle(*input.tensor()); + SetLayerName(layer, params->node_def, "shuffle", op_instance); params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); return Status::OK(); } else if (pad_dims.size() == 1) { @@ -2830,30 +2891,32 @@ Status ConvertStridedSliceHelper(OpConverterParams* params, nvinfer1::ITensor* tensor = input.tensor(); if (need_reshape) { TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - input, reshape_dims, /*validation_only=*/false, &tensor)); + input, reshape_dims, /*validation_only=*/false, &tensor, node_def, + op_instance)); } if (need_transpose) { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - tensor, transpose_order, StrCat(node_def.name(), "_for_pad"), &tensor)); + tensor, transpose_order, &tensor, node_def, "for_pad", op_instance)); } // Add padding layer nvinfer1::IPaddingLayer* layer = params->converter->network()->addPadding( *tensor, pre_padding, post_padding); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, params->node_def, "pad"); params->converter->MarkQuantizationRangesAsInferrable(tensor, layer->getOutput(0)); tensor = layer->getOutput(0); // Restore transpose if (need_transpose) { - TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - tensor, inv_transpose_order, StrCat(node_def.name(), "_after_pad"), - &tensor)); + TF_RETURN_IF_ERROR( + params->converter->TransposeTensor(tensor, inv_transpose_order, &tensor, + node_def, "after_pad", op_instance)); } // Reshape for shrink_axis. if (final_shape) { TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( TRT_TensorOrWeights(tensor), *final_shape, /*validation_only=*/false, - &tensor)); + &tensor, node_def, op_instance)); } else if (need_reshape) { // Restore reshape. 
// Calculate output dimensions @@ -2876,7 +2939,7 @@ Status ConvertStridedSliceHelper(OpConverterParams* params, /*ignore_first_dim=*/true)); TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( TRT_TensorOrWeights(tensor), new_dims, /*validation_only=*/false, - &tensor)); + &tensor, node_def, op_instance)); } params->outputs->push_back(TRT_TensorOrWeights(tensor)); @@ -3166,8 +3229,7 @@ Status ConvertConv3DHelper(OpConverterParams* params, int group, const bool need_transpose = is_ndhwc; if (need_transpose) { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - tensor, {0, 4, 1, 2, 3}, StrCat(node_def.name(), "_to_NCDHW"), - &tensor)); + tensor, {0, 4, 1, 2, 3}, &tensor, node_def, "to_NCDHW")); } // group == 0 signifies that this is a depthwise convolution, so set @@ -3206,7 +3268,6 @@ Status ConvertConv3DHelper(OpConverterParams* params, int group, layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER); } - layer->setName(node_def.name().c_str()); layer->setNbGroups(num_groups); conv_layer = layer; } else { @@ -3222,18 +3283,17 @@ Status ConvertConv3DHelper(OpConverterParams* params, int group, layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER); } - layer->setName(node_def.name().c_str()); layer->setNbGroups(num_groups); layer->setDilationNd(dilation_dhw); conv_layer = layer; } + SetLayerName(conv_layer, node_def, "conv"); nvinfer1::ITensor* output_tensor = conv_layer->getOutput(0); // Restore transpose. if (need_transpose) { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - output_tensor, {0, 2, 3, 4, 1}, StrCat(node_def.name(), "_to_NDHWC"), - &output_tensor)); + output_tensor, {0, 2, 3, 4, 1}, &output_tensor, node_def, "to_NDHWC")); } params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); @@ -3302,8 +3362,7 @@ Status ConvertPool3D(OpConverterParams* params) { if (data_format == "NDHWC") { // NDHWC => NCDHW TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - tensor, {0, 4, 1, 2, 3}, StrCat(node_def.name(), "_to_NCDHW"), - &tensor)); + tensor, {0, 4, 1, 2, 3}, &tensor, node_def, "to_NCDHW")); } const nvinfer1::Dims3 stride(tf_stride[d_index], tf_stride[h_index], @@ -3324,14 +3383,13 @@ Status ConvertPool3D(OpConverterParams* params) { // SAME_UPPER means that post padding is preferred. 
layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER); } - layer->setName(node_def.name().c_str()); + SetLayerName(layer, node_def, "pooling"); nvinfer1::ITensor* output_tensor = layer->getOutput(0); if (data_format == "NDHWC") { // NCDHW => NDHWC TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - output_tensor, {0, 2, 3, 4, 1}, StrCat(node_def.name(), "_to_NDHWC"), - &output_tensor)); + output_tensor, {0, 2, 3, 4, 1}, &output_tensor, node_def, "to_NDHWC")); } params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); @@ -3426,7 +3484,7 @@ Status ConvertFusedConv2DBiasActivation(OpConverterParams* params) { const bool need_transpose = (data_format == "NHWC"); if (need_transpose) { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - tensor, {0, 3, 1, 2}, StrCat(node_def.name(), "_to_NCHW"), &tensor)); + tensor, {0, 3, 1, 2}, &tensor, node_def, "to_NCHW")); } nvinfer1::DimsHW kernel_size; @@ -3482,7 +3540,7 @@ Status ConvertFusedConv2DBiasActivation(OpConverterParams* params) { #else conv_layer->setPadding(nvinfer1::DimsHW{padding[0].first, padding[1].first}); #endif - conv_layer->setName(node_def.name().c_str()); + SetLayerName(conv_layer, node_def, "conv"); conv_layer->setNbGroups(1); conv_layer->setDilation(dilation); nvinfer1::ITensor* output_tensor = conv_layer->getOutput(0); @@ -3493,13 +3551,13 @@ Status ConvertFusedConv2DBiasActivation(OpConverterParams* params) { params->converter->network()->addActivation(*output_tensor, op_pair->second); TFTRT_RETURN_ERROR_IF_NULLPTR(activation_layer, node_def.name()); + SetLayerName(activation_layer, node_def, "activation"); output_tensor = activation_layer->getOutput(0); } // Restore transpose. if (need_transpose) { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - output_tensor, {0, 2, 3, 1}, StrCat(node_def.name(), "_to_NHWC"), - &output_tensor)); + output_tensor, {0, 2, 3, 1}, &output_tensor, node_def, "to_NHWC")); } params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); @@ -3541,7 +3599,7 @@ Status ConvertPool(OpConverterParams* params) { h_index = 1; w_index = 2; TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - tensor, {0, 3, 1, 2}, StrCat(node_def.name(), "_to_NCHW"), &tensor)); + tensor, {0, 3, 1, 2}, &tensor, node_def, "to_NCHW")); } const auto tf_stride = attrs.get>("strides"); @@ -3575,6 +3633,7 @@ Status ConvertPool(OpConverterParams* params) { *tensor, nvinfer1::DimsHW(padding[0].first, padding[1].first), nvinfer1::DimsHW(padding[0].second, padding[1].second)); TFTRT_RETURN_ERROR_IF_NULLPTR(pad_layer, node_def.name()); + SetLayerName(pad_layer, node_def, "pad"); params->converter->MarkQuantizationRangesAsInferrable( tensor, pad_layer->getOutput(0)); padding = {{0, 0}, {0, 0}}; @@ -3604,13 +3663,12 @@ Status ConvertPool(OpConverterParams* params) { #else layer->setPadding(nvinfer1::DimsHW{padding[0].first, padding[1].first}); #endif - layer->setName(node_def.name().c_str()); + SetLayerName(layer, node_def, "pooling"); nvinfer1::ITensor* output_tensor = layer->getOutput(0); if (data_format == "NHWC") { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - output_tensor, {0, 2, 3, 1}, StrCat(node_def.name(), "_to_NHWC"), - &output_tensor)); + output_tensor, {0, 2, 3, 1}, &output_tensor, node_def, "to_NHWC")); } params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); @@ -3633,6 +3691,7 @@ Status ConvertLeakyRelu(OpConverterParams* params) { params->converter->network()->addActivation( *inputs.at(0).tensor(), 
nvinfer1::ActivationType::kLEAKY_RELU); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def, "activation"); layer->setAlpha(alpha); params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); return Status::OK(); @@ -3655,12 +3714,14 @@ Status ConvertLeakyRelu(OpConverterParams* params) { params->converter->network()->addElementWise( *tensor, *const_alpha_tensor, nvinfer1::ElementWiseOperation::kPROD); TFTRT_RETURN_ERROR_IF_NULLPTR(mul_layer, node_def.name()); + SetLayerName(mul_layer, node_def, "mul"); // max(x, alpha * x) nvinfer1::IElementWiseLayer* max_layer = params->converter->network()->addElementWise( *tensor, *mul_layer->getOutput(0), nvinfer1::ElementWiseOperation::kMAX); TFTRT_RETURN_ERROR_IF_NULLPTR(max_layer, node_def.name()); + SetLayerName(max_layer, node_def, "max"); nvinfer1::ITensor* output_tensor = max_layer->getOutput(0); params->converter->MarkQuantizationRangesAsInferrable( output_tensor, mul_layer->getOutput(0)); @@ -3705,6 +3766,7 @@ Status ConvertClipByValue(OpConverterParams* params) { layer->setAlpha(clip_value_min); layer->setBeta(clip_value_max); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def, "activation"); nvinfer1::ITensor* output_tensor = layer->getOutput(0); params->converter->ProvideQuantizationRange(output_tensor, clip_value_min, clip_value_max); @@ -3748,7 +3810,7 @@ Status ConvertActivation(OpConverterParams* params) { params->converter->network()->addActivation(*inputs.at(0).tensor(), op_pair->second); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); - layer->setName(node_def.name().c_str()); + SetLayerName(layer, node_def, "activation"); // Set parameters. #if IS_TRT_VERSION_GE(5, 1, 2, 0) if (node_def.op() == "Elu") { @@ -3852,7 +3914,7 @@ Status ConvertRelu6(OpConverterParams* params) { TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); layer->setAlpha(0.0f); layer->setBeta(6.0f); - layer->setName(node_def.name().c_str()); + SetLayerName(layer, node_def, "activation"); nvinfer1::ITensor* output_tensor = layer->getOutput(0); params->converter->ProvideQuantizationRange(output_tensor, 0.0f, 6.0f); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); @@ -3867,6 +3929,7 @@ Status ConvertRelu6(OpConverterParams* params) { params->converter->network()->addActivation( *tensor, nvinfer1::ActivationType::kRELU); TFTRT_RETURN_ERROR_IF_NULLPTR(relu_layer, node_def.name()); + SetLayerName(relu_layer, node_def, "activation"); // Large range of relu is problematic during quantization in INT8 precision // mode. Setting dynamic range of relu = [0.f, 6.0f] helps with quantization. 
@@ -3888,6 +3951,7 @@ Status ConvertRelu6(OpConverterParams* params) { *relu_layer->getOutput(0), *const6_tensor, nvinfer1::ElementWiseOperation::kMIN); TFTRT_RETURN_ERROR_IF_NULLPTR(relu6_layer, node_def.name()); + SetLayerName(relu6_layer, node_def, "min"); nvinfer1::ITensor* output_tensor = relu6_layer->getOutput(0); params->converter->ProvideQuantizationRange(output_tensor, 0.0f, 6.0f); @@ -3932,6 +3996,7 @@ Status ConvertBiasAddInt8WithoutCalibration(OpConverterParams* params) { nvinfer1::IShuffleLayer* shuffle_layer = params->converter->network()->addShuffle(*tensor); TFTRT_RETURN_ERROR_IF_NULLPTR(shuffle_layer, node_def.name()); + SetLayerName(shuffle_layer, node_def, "shuffle", /*op_instance=*/0); params->converter->MarkQuantizationRangesAsInferrable( tensor, shuffle_layer->getOutput(0)); @@ -3963,6 +4028,7 @@ Status ConvertBiasAddInt8WithoutCalibration(OpConverterParams* params) { *tensor, mode, weights.GetTrtWeights(), empty_weights.GetTrtWeights(), empty_weights.GetTrtWeights()); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def, "scale"); nvinfer1::ITensor* output_tensor = layer->getOutput(0); @@ -3971,6 +4037,7 @@ Status ConvertBiasAddInt8WithoutCalibration(OpConverterParams* params) { nvinfer1::IShuffleLayer* shuffle_layer = params->converter->network()->addShuffle(*output_tensor); TFTRT_RETURN_ERROR_IF_NULLPTR(shuffle_layer, node_def.name()); + SetLayerName(shuffle_layer, node_def, "shuffle", /*op_instance=*/1); // NOTE: for same reason as mentioned above we need to apply the reshape // unconditionally. nvinfer1::Dims reshape_dims = original_dims; @@ -4055,13 +4122,16 @@ Status ConvertBiasAdd(OpConverterParams* params) { // Convert input to a TRT tensor nvinfer1::ITensor* input_tensor{nullptr}; TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - inputs.at(0), input_shape, params->validation_only, &input_tensor)); + inputs.at(0), input_shape, params->validation_only, &input_tensor, + node_def, + /*op_instance=*/0)); // Finally, reshape bias. Since the bias is usually a constant, this will // normally happen at conversion-time. nvinfer1::ITensor* bias_tensor{nullptr}; TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - inputs.at(1), bias_shape, params->validation_only, &bias_tensor)); + inputs.at(1), bias_shape, params->validation_only, &bias_tensor, node_def, + /*op_instance=*/1)); VLOG(2) << "Bias shape adjusted to " << DebugString(bias_shape); if (params->validation_only) return Status::OK(); @@ -4070,6 +4140,7 @@ Status ConvertBiasAdd(OpConverterParams* params) { params->converter->network()->addElementWise( *input_tensor, *bias_tensor, nvinfer1::ElementWiseOperation::kSUM); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def, "sum"); nvinfer1::ITensor* output_tensor = layer->getOutput(0); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); @@ -4298,16 +4369,18 @@ Status ConvertBinary(OpConverterParams* params) { nvinfer1::ITensor* tensor_r = nullptr; // This will also convert constants to tensors, and set quantization ranges. 
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - operand_l, broadcasted_dims_l, params->validation_only, &tensor_l)); + operand_l, broadcasted_dims_l, params->validation_only, &tensor_l, + node_def, /*op_instance=*/0)); TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - operand_r, broadcasted_dims_r, params->validation_only, &tensor_r)); + operand_r, broadcasted_dims_r, params->validation_only, &tensor_r, + node_def, /*op_instance=*/1)); if (params->validation_only) return Status::OK(); // Add ElementWise layer. nvinfer1::ILayer* layer = params->converter->network()->addElementWise( *tensor_l, *tensor_r, op_pair->second); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); - layer->setName(node_def.name().c_str()); + SetLayerName(layer, node_def); nvinfer1::ITensor* trt_tensor = layer->getOutput(0); #if IS_TRT_VERSION_GE(5, 1, 0, 0) @@ -4315,6 +4388,7 @@ Status ConvertBinary(OpConverterParams* params) { layer = params->converter->network()->addUnary( *trt_tensor, nvinfer1::UnaryOperation::kFLOOR); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def, "floor"); trt_tensor = layer->getOutput(0); } #endif @@ -4353,10 +4427,12 @@ Status ConvertRsqrt(OpConverterParams* params) { nvinfer1::IUnaryLayer* sqrt_layer = params->converter->network()->addUnary( *tensor, nvinfer1::UnaryOperation::kSQRT); TFTRT_RETURN_ERROR_IF_NULLPTR(sqrt_layer, node_def.name()); + SetLayerName(sqrt_layer, node_def, "sqrt"); // Recip nvinfer1::IUnaryLayer* recip_layer = params->converter->network()->addUnary( *sqrt_layer->getOutput(0), nvinfer1::UnaryOperation::kRECIP); TFTRT_RETURN_ERROR_IF_NULLPTR(recip_layer, node_def.name()); + SetLayerName(recip_layer, node_def, "recip"); params->outputs->push_back(TRT_TensorOrWeights(recip_layer->getOutput(0))); return Status::OK(); } @@ -4408,7 +4484,7 @@ Status ConvertUnary(OpConverterParams* params) { nvinfer1::IUnaryLayer* layer = params->converter->network()->addUnary(*tensor, op_pair->second); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); - layer->setName(node_def.name().c_str()); + SetLayerName(layer, node_def); nvinfer1::ITensor* output_tensor = layer->getOutput(0); // Set quantization ranges. 
@@ -4453,6 +4529,7 @@ Status ConvertSquare(OpConverterParams* params) { *inputs.at(0).tensor(), *const2_tensor, nvinfer1::ElementWiseOperation::kPOW); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def); nvinfer1::ITensor* output_tensor = layer->getOutput(0); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); @@ -4511,6 +4588,7 @@ Status ConvertReduce(OpConverterParams* params) { nvinfer1::ILayer* layer = params->converter->network()->addReduce( *tensor, reduce_operation, axes, keep_dims); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def); params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); return Status::OK(); @@ -4585,21 +4663,25 @@ Status ConvertPack(OpConverterParams* params) { nvinfer1::Dims expanded_dims; TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(tensor_dims, &expanded_dims)); std::vector expanded_tensors; + int input_index = 0; for (const TRT_TensorOrWeights& input : inputs) { nvinfer1::ITensor* expanded_tensor = nullptr; if (input.is_tensor() && !params->use_implicit_batch && !HasStaticShape(dims)) { if (!params->validation_only) { TF_RETURN_IF_ERROR(params->converter->DynamicExpandDims( - input.tensor(), dims, trt_axis, params, &expanded_tensor)); + input.tensor(), dims, trt_axis, params, &expanded_tensor, + input_index)); } } else { TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - input, expanded_dims, params->validation_only, &expanded_tensor)); + input, expanded_dims, params->validation_only, &expanded_tensor, + node_def, input_index)); } if (!params->validation_only) { expanded_tensors.push_back(expanded_tensor); } + input_index++; } if (params->validation_only) return Status::OK(); @@ -4615,6 +4697,7 @@ Status ConvertPack(OpConverterParams* params) { const_cast(expanded_tensors.data()), expanded_tensors.size()); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def, "concat"); // Note that trt_axis stays the same even after expanding tensors at the axis. 
layer->setAxis(trt_axis); params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); @@ -4696,7 +4779,7 @@ Status ConvertPad(OpConverterParams* params) { if (pad_index[0] == 1) { legit_pad = false; TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - tensor, {0, 3, 2, 1}, StrCat(node_def.name(), "_to_pad"), &tensor)); + tensor, {0, 3, 2, 1}, &tensor, node_def, "to_pad")); permuted_pad_index[0] = 3; } @@ -4714,13 +4797,13 @@ Status ConvertPad(OpConverterParams* params) { nvinfer1::IPaddingLayer* layer = params->converter->network()->addPadding( *tensor, pre_padding, post_padding); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def); nvinfer1::ITensor* output_tensor = layer->getOutput(0); params->converter->MarkQuantizationRangesAsInferrable(tensor, output_tensor); if (!legit_pad) { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - output_tensor, {0, 3, 2, 1}, StrCat(node_def.name(), "_from_pad"), - &output_tensor)); + output_tensor, {0, 3, 2, 1}, &output_tensor, node_def, "from_pad")); } params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); @@ -4780,7 +4863,7 @@ Status ConvertSplitHelper(OpConverterParams* params, for (int i = 0; i < num_splits; ++i) { begin[trt_axis + 1] = i * split_size_on_axis; TF_RETURN_IF_ERROR(ConvertStridedSliceHelper( - params, input, begin, size, stride, final_shape_for_unpack_ptr)); + params, input, begin, size, stride, final_shape_for_unpack_ptr, i)); } return Status::OK(); } @@ -4854,6 +4937,7 @@ Status ConvertCast(OpConverterParams* params) { nvinfer1::ITensor* input = params->inputs.at(0).tensor(); nvinfer1::IIdentityLayer* layer = params->converter->network()->addIdentity(*input); + SetLayerName(layer, node_def); layer->setPrecision(nvinfer1::DataType::kFLOAT); if (layer->getOutput(0)->getType() != nvinfer1::DataType::kFLOAT) { @@ -4911,6 +4995,7 @@ Status ConvertConcat(OpConverterParams* params) { const_cast(input_tensors.data()), input_tensors.size()); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def); layer->setAxis(trt_axis); params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); return Status::OK(); @@ -5057,7 +5142,7 @@ Status ConvertFusedBatchNorm(OpConverterParams* params) { combined_scale_weights.GetTrtWeights(), dummy_power_weights.GetTrtWeights()); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); - layer->setName(node_def.name().c_str()); + SetLayerName(layer, node_def); nvinfer1::ITensor* output_tensor = layer->getOutput(0); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); @@ -5137,6 +5222,7 @@ Status ConvertGather(OpConverterParams* params) { nvinfer1::IGatherLayer* layer = params->converter->network()->addGather( *params_tensor, *indices_input.tensor(), trt_axis); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def); nvinfer1::ITensor* output_tensor = layer->getOutput(0); nvinfer1::Dims trt_gather_output_dims = output_tensor->getDimensions(); @@ -5163,7 +5249,7 @@ Status ConvertGather(OpConverterParams* params) { TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( TRT_TensorOrWeights(output_tensor), trt_gather_output_dims, - /*validation_only=*/false, &output_tensor)); + /*validation_only=*/false, &output_tensor, node_def)); } params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); @@ -5173,7 +5259,7 @@ Status ConvertGather(OpConverterParams* params) { Status ConvertFullyConnectedHelper(OpConverterParams* params, nvinfer1::ITensor* 
tensor_a, TRT_ShapedWeights weights_b, - bool transpose_b, const string& node_name) { + bool transpose_b, const NodeDef& node_def) { // Reshape input to 3D - this will be a no-op unless using int8 precision. auto input_dim = tensor_a->getDimensions(); while (input_dim.nbDims < 3) { @@ -5181,7 +5267,7 @@ Status ConvertFullyConnectedHelper(OpConverterParams* params, } TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( TRT_TensorOrWeights(tensor_a), input_dim, /*validation_only=*/false, - &tensor_a)); + &tensor_a, node_def, /*op_instance=*/0)); // FC layer will transpose weights, so we need to pre-transpose. TRT_ShapedWeights weights(weights_b.TrtDType()); @@ -5197,7 +5283,8 @@ Status ConvertFullyConnectedHelper(OpConverterParams* params, params->converter->network()->addFullyConnected( *tensor_a, noutput, weights.GetTrtWeights(), biases.GetTrtWeights()); - TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_name); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def); nvinfer1::ITensor* output_tensor = layer->getOutput(0); // Reshape output to 1D - this will be a no-op unless using int8 precision. @@ -5205,7 +5292,7 @@ Status ConvertFullyConnectedHelper(OpConverterParams* params, output_dim.nbDims = 1; TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( TRT_TensorOrWeights(output_tensor), output_dim, /*validation_only=*/false, - &output_tensor)); + &output_tensor, node_def, /*op_instance=*/1)); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); @@ -5214,7 +5301,7 @@ Status ConvertFullyConnectedHelper(OpConverterParams* params, Status ConvertMatMulHelper(OpConverterParams* params, TRT_TensorOrWeights input_a, TRT_TensorOrWeights input_b, bool transpose_a, - bool transpose_b, string node_name) { + bool transpose_b, const NodeDef& node_def) { // TODO: ReorderCKtoKC is currently not general enough to transpose weights // that are not 2D. 
if ((transpose_a && input_a.is_weights() && @@ -5252,7 +5339,7 @@ Status ConvertMatMulHelper(OpConverterParams* params, if (should_use_fc || (can_use_fc && params->converter->precision_mode() == TrtPrecisionMode::INT8)) { return ConvertFullyConnectedHelper( - params, input_a.tensor(), input_b.weights(), transpose_b, node_name); + params, input_a.tensor(), input_b.weights(), transpose_b, node_def); } const auto get_matrix_op = [](nvinfer1::ITensor* in, @@ -5293,7 +5380,8 @@ Status ConvertMatMulHelper(OpConverterParams* params, *tensor_a, get_matrix_op(tensor_a, transpose_a), *tensor_b, get_matrix_op(tensor_b, transpose_b)); - TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_name); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def); nvinfer1::ITensor* output_tensor = layer->getOutput(0); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); @@ -5316,7 +5404,7 @@ Status ConvertMatMul(OpConverterParams* params) { bool transpose_b = attrs.get("transpose_b"); return ConvertMatMulHelper(params, inputs.at(0), inputs.at(1), transpose_a, - transpose_b, node_def.name()); + transpose_b, node_def); } Status ConvertBatchMatMul(OpConverterParams* params) { @@ -5379,14 +5467,16 @@ Status ConvertBatchMatMul(OpConverterParams* params) { nvinfer1::ITensor* tensor_l = nullptr; nvinfer1::ITensor* tensor_r = nullptr; TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - inputs.at(0), broadcasted_dims_l, params->validation_only, &tensor_l)); + inputs.at(0), broadcasted_dims_l, params->validation_only, &tensor_l, + node_def)); TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - inputs.at(1), broadcasted_dims_r, params->validation_only, &tensor_r)); + inputs.at(1), broadcasted_dims_r, params->validation_only, &tensor_r, + node_def)); if (params->validation_only) return Status::OK(); return ConvertMatMulHelper(params, TRT_TensorOrWeights(tensor_l), TRT_TensorOrWeights(tensor_r), transpose_a, - transpose_b, node_def.name()); + transpose_b, node_def); } Status ConvertSoftmax(OpConverterParams* params) { @@ -5408,6 +5498,7 @@ Status ConvertSoftmax(OpConverterParams* params) { nvinfer1::ISoftMaxLayer* layer = params->converter->network()->addSoftMax(*tensor); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def); // Tensorflow SoftMax assumes applying softmax on the last dimension. layer->setAxes(1 << (num_trt_dims - 1)); @@ -5452,6 +5543,7 @@ Status ConvertArgMinMax(OpConverterParams* params) { nvinfer1::ITopKLayer* layer = params->converter->network()->addTopK( *inputs.at(0).tensor(), topk_op, 1, reduce_axes); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def, "topk"); nvinfer1::ITensor* output_indices_tensor = layer->getOutput(1); // Squeeze on axis. 
@@ -5462,7 +5554,7 @@ Status ConvertArgMinMax(OpConverterParams* params) { nvinfer1::ITensor* output_tensor = nullptr; TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( TRT_TensorOrWeights(output_indices_tensor), new_dims, - /*validation_only=*/false, &output_tensor)); + /*validation_only=*/false, &output_tensor, node_def)); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); @@ -5508,6 +5600,7 @@ Status ConvertTopK(OpConverterParams* params) { nvinfer1::ITopKLayer* layer = params->converter->network()->addTopK(*tensor, op, k, reduce_axes); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def); nvinfer1::ITensor* output_value_tensor = layer->getOutput(0); nvinfer1::ITensor* output_indices_tensor = layer->getOutput(1); @@ -5583,6 +5676,7 @@ Status ConvertDepthSpaceShuffle(OpConverterParams* params) { nvinfer1::IShuffleLayer* first_shuffle = params->converter->network()->addShuffle(*inputs.at(0).tensor()); TFTRT_RETURN_ERROR_IF_NULLPTR(first_shuffle, node_def.name()); + SetLayerName(first_shuffle, node_def, "shuffle", /*op_instance=*/0); if (data_format == "NHWC") { first_shuffle->setFirstTranspose({2, 0, 1}); } @@ -5592,6 +5686,7 @@ Status ConvertDepthSpaceShuffle(OpConverterParams* params) { nvinfer1::IShuffleLayer* second_shuffle = params->converter->network()->addShuffle(*first_shuffle->getOutput(0)); TFTRT_RETURN_ERROR_IF_NULLPTR(second_shuffle, node_def.name()); + SetLayerName(second_shuffle, node_def, "shuffle", /*op_instance=*/1); second_shuffle->setReshapeDimensions(second_shuffle_shape); if (data_format == "NHWC") { second_shuffle->setSecondTranspose({1, 2, 0}); @@ -5619,9 +5714,11 @@ Status ConvertSquaredDifference(OpConverterParams* params) { nvinfer1::ITensor* tensor_l = nullptr; nvinfer1::ITensor* tensor_r = nullptr; TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - inputs.at(0), broadcasted_dims_l, params->validation_only, &tensor_l)); + inputs.at(0), broadcasted_dims_l, params->validation_only, &tensor_l, + node_def)); TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - inputs.at(1), broadcasted_dims_r, params->validation_only, &tensor_r)); + inputs.at(1), broadcasted_dims_r, params->validation_only, &tensor_r, + node_def)); if (params->validation_only) return Status::OK(); // Subtract x - y. @@ -5629,12 +5726,15 @@ Status ConvertSquaredDifference(OpConverterParams* params) { params->converter->network()->addElementWise( *tensor_l, *tensor_r, nvinfer1::ElementWiseOperation::kSUB); TFTRT_RETURN_ERROR_IF_NULLPTR(sub, node_def.name()); + SetLayerName(sub, node_def, "sub"); + // Multiply (x - y) * (x - y). 
nvinfer1::IElementWiseLayer* mul = params->converter->network()->addElementWise( *sub->getOutput(0), *sub->getOutput(0), nvinfer1::ElementWiseOperation::kPROD); TFTRT_RETURN_ERROR_IF_NULLPTR(mul, node_def.name()); + SetLayerName(mul, node_def, "mul"); params->outputs->push_back(TRT_TensorOrWeights(mul->getOutput(0))); return Status::OK(); @@ -5772,6 +5872,7 @@ Status ConvertCombinedNMS(OpConverterParams* params) { nvinfer1::IPluginV2Layer* layer = params->converter->network()->addPluginV2( &plugin_inputs[0], static_cast(plugin_inputs.size()), *plugin); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def, "plugin"); // Set plugin outputs nvinfer1::ITensor* output_nmsed_boxes = layer->getOutput(1); @@ -5785,8 +5886,8 @@ Status ConvertCombinedNMS(OpConverterParams* params) { nvinfer1::ITensor* output_nmsed_scores = nullptr; nvinfer1::ITensor* output_nmsed_classes = nullptr; - auto shrink_last_dim = [params](nvinfer1::ITensor* in_tensor, - nvinfer1::ITensor** out_tensor) { + auto shrink_last_dim = [&](int output_index, nvinfer1::ITensor** out_tensor) { + nvinfer1::ITensor* in_tensor = layer->getOutput(output_index); nvinfer1::Dims dims = in_tensor->getDimensions(); if (dims.d[dims.nbDims - 1] != 1) { return errors::Internal("Expect last dims to be 1, for tensor ", @@ -5795,15 +5896,12 @@ Status ConvertCombinedNMS(OpConverterParams* params) { --dims.nbDims; TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( TRT_TensorOrWeights(in_tensor), dims, - /*validation_only=*/false, out_tensor)); + /*validation_only=*/false, out_tensor, node_def, output_index)); return Status::OK(); }; - TF_RETURN_IF_ERROR( - shrink_last_dim(layer->getOutput(2), &output_nmsed_scores)); - TF_RETURN_IF_ERROR( - shrink_last_dim(layer->getOutput(3), &output_nmsed_classes)); - TF_RETURN_IF_ERROR( - shrink_last_dim(layer->getOutput(0), &output_num_detections)); + TF_RETURN_IF_ERROR(shrink_last_dim(2, &output_nmsed_scores)); + TF_RETURN_IF_ERROR(shrink_last_dim(3, &output_nmsed_classes)); + TF_RETURN_IF_ERROR(shrink_last_dim(0, &output_num_detections)); #endif // IS_TRT_VERSION_GE(6, 0, 0, 0) params->outputs->push_back(TRT_TensorOrWeights(output_nmsed_boxes)); @@ -5845,6 +5943,12 @@ Status ConvertResize(OpConverterParams* params) { // Verify resize mode. Initialize resize mode if supported. nvinfer1::ResizeMode resize_mode; if (node_def.op() == "ResizeBilinear") { +#if IS_TRT_VERSION_GE(7, 1, 0, 0) + if (!align_corners) { + return errors::InvalidArgument( + "Cannot Convert Bilinear Resize when align_corners=False"); + } +#endif resize_mode = nvinfer1::ResizeMode::kLINEAR; } else if (node_def.op() == "ResizeNearestNeighbor") { resize_mode = nvinfer1::ResizeMode::kNEAREST; @@ -5858,7 +5962,7 @@ Status ConvertResize(OpConverterParams* params) { // Transpose tensor from NHWC to NCHW format. TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - tensor, {0, 3, 1, 2}, StrCat(node_def.name(), "_to_NCHW"), &tensor)); + tensor, {0, 3, 1, 2}, &tensor, node_def, "to_NCHW")); // Calculate output dimensions. // Given input dimensions [N, C, H, W] and output size [H_out, W_out], @@ -5875,6 +5979,7 @@ Status ConvertResize(OpConverterParams* params) { nvinfer1::IResizeLayer* layer = params->converter->network()->addResize(*tensor); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def); // Set layer parameters. 
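The reworked shrink_last_dim lambda in ConvertCombinedNMS above now takes the plugin output index, checks that the trailing dimension is 1, and drops it before reshaping, with the index doubling as the op_instance for the reshape's layer name. A reduced sketch of the dimension handling on a plain vector (ShrinkLastDim and the sample shape are illustrative; the real helper works on nvinfer1::Dims and delegates the reshape to PrepareTensorForShape):

    #include <cstdio>
    #include <vector>

    // Drops a trailing dimension of size 1, or reports failure when the last
    // dimension is not 1 (the case the real helper turns into an Internal error).
    bool ShrinkLastDim(std::vector<int>* dims) {
      if (dims->empty() || dims->back() != 1) return false;
      dims->pop_back();
      return true;
    }

    int main() {
      std::vector<int> dims = {8, 100, 1};  // e.g. [batch, max_detections, 1]
      if (ShrinkLastDim(&dims)) {
        std::printf("rank after shrink: %zu\n", dims.size());  // prints 2
      }
      return 0;
    }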
layer->setResizeMode(resize_mode); @@ -5885,7 +5990,7 @@ Status ConvertResize(OpConverterParams* params) { nvinfer1::ITensor* output = layer->getOutput(0); TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - output, {0, 2, 3, 1}, StrCat(node_def.name(), "_to_NHWC"), &output)); + output, {0, 2, 3, 1}, &output, node_def, "to_NHWC")); params->outputs->push_back(TRT_TensorOrWeights(output)); // Success return Status::OK(); @@ -5934,6 +6039,7 @@ Status ConvertAddN(OpConverterParams* params) { nvinfer1::ILayer* layer = params->converter->network()->addElementWise( *lhs, *rhs, nvinfer1::ElementWiseOperation::kSUM); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def, std::to_string(i)); lhs = layer->getOutput(0); } params->outputs->push_back(TRT_TensorOrWeights(lhs)); @@ -6056,6 +6162,8 @@ Status ConvertGraphDefToEngine( VLOG(1) << "Starting to convert TensorFlow ops to TensorRT layers"; std::vector output_tensors; + int num_layers = converter->network()->getNbLayers(); + absl::flat_hash_set layer_names; // Graph nodes are already topologically sorted during construction for (const auto& node_def : gdef.node()) { const string& node_name = node_def.name(); @@ -6134,6 +6242,25 @@ Status ConvertGraphDefToEngine( } else { TF_RETURN_IF_ERROR(converter->ConvertNode(node_def)); } + + // To support TF-TRT profiling, we ensure each ILayer has a non-empty name. + // BuildCudaEngine returns an error if there is any ILayer name collision. + // We want to report the error here, before BuildCudaEngine, in a more + // meaningful way. + int new_num_layers = converter->network()->getNbLayers(); + for (int i = num_layers; i < new_num_layers; i++) { + auto layer = converter->network()->getLayer(i); + if (layer->getName() == nullptr || + !layer_names.insert(layer->getName()).second) { + std::string error_message = + absl::StrCat("Converting node ", node_name, ", op=", node_def.op(), + layer->getName() ? " creates a layer with name collision" + : " creates a layer without a name"); + LOG_WARNING_WITH_PREFIX << error_message; + return errors::Internal(error_message); + } + } + num_layers = new_num_layers; } TF_RETURN_IF_ERROR(converter->RenameAndMarkOutputTensors(output_tensors)); if (convert_successfully) *convert_successfully = true; diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h index a621735fad1..4a84793e254 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/types/optional.h" #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h" @@ -515,14 +516,18 @@ class Converter { // Transpose 'input_tensor' with given permutation 'order_with_batch_dim' to // 'output_tensor'. The permutation 'order_with_batch_dim' contains the batch - // dimension which should always be 0. + // dimension which should always be 0. If this is for adding a transpose layer + // to support the conversion of 'node_def', callers need to provide a + // non-empty 'sub_op_name' to be appended to the name of 'node_def' to avoid + // layer name conflicts. 
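The per-node check added to ConvertGraphDefToEngine above only scans the layers created since the previous node and fails fast on an empty or duplicate name, so the error is attributed to the offending TF node instead of surfacing later from engine building. A reduced sketch of the same bookkeeping over plain strings, using std::unordered_set where the converter uses absl::flat_hash_set and returning a message string instead of a Status:

    #include <iostream>
    #include <string>
    #include <unordered_set>
    #include <vector>

    // Returns an error-style message when a name is empty or already registered,
    // mirroring the duplicate-name check above; the converter reports the same
    // conditions via errors::Internal.
    std::string CheckLayerNames(const std::vector<std::string>& new_layer_names,
                                std::unordered_set<std::string>* seen) {
      for (const std::string& name : new_layer_names) {
        if (name.empty()) return "created a layer without a name";
        if (!seen->insert(name).second) {
          return "created a layer with a name collision: " + name;
        }
      }
      return "ok";
    }

    int main() {
      std::unordered_set<std::string> seen;
      std::cout << CheckLayerNames({"conv1-conv", "conv1-pad"}, &seen) << "\n";  // ok
      std::cout << CheckLayerNames({"conv1-conv"}, &seen) << "\n";  // collision
      return 0;
    }

As in the converter, the set is never reset between calls, so a collision between layers created for two different nodes is caught as well.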
Status TransposeTensor(nvinfer1::ITensor* input_tensor, const std::vector& order_with_batch_dim, - absl::string_view name, - nvinfer1::ITensor** output_tensor); + nvinfer1::ITensor** output_tensor, + const NodeDef& node_def, + absl::string_view sub_op_name = ""); - // Converts 'input' into 'tensor' with shape specified by 'dims' (which - // doesn't contain the batch dimension). + // Converts 'input' of 'node_def' into 'tensor' with shape specified by 'dims' + // (which doesn't contain the batch dimension). // // If validation_only is true, it doesn't do the conversion but only do some // minimum validation for the eligibility of the conversion, and *tensor will @@ -530,7 +535,9 @@ class Converter { Status PrepareTensorForShape(const TRT_TensorOrWeights& input, const nvinfer1::Dims& dims, const bool validation_only, - nvinfer1::ITensor** tensor); + nvinfer1::ITensor** tensor, + const NodeDef& node_def, + absl::optional op_instance = absl::nullopt); // Reshapes a dynamic shape tensor by removing or adding dimensions of size 1, // and/or permuting the dimensions. The new shape is derived from the shape of @@ -575,12 +582,14 @@ class Converter { Status DynamicReshape(nvinfer1::ITensor* input, std::vector> slices, OpConverterParams* params, nvinfer1::ITensor** output, - std::vector size_for_added_dims = {}); + std::vector size_for_added_dims = {}, + absl::optional op_instance = absl::nullopt); // Inserts a singleton dimension at axis for a dynamic shape tensor. Status DynamicExpandDims(nvinfer1::ITensor* input, const nvinfer1::Dims& dims, int axis, OpConverterParams* params, - nvinfer1::ITensor** output); + nvinfer1::ITensor** output, + absl::optional op_instance = absl::nullopt); // Helper function to add a squeeze op to the network. // @@ -667,6 +676,10 @@ class Converter { // acceptable by TRT. int batch_size_ = -1; + // Assign a ID to each constant layer we create, so that we can assign a + // unique name to the layer. + int next_constant_layer_id_ = 0; + friend class ConverterTest; friend class OpConverterTest; }; diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index 72348c3cede..86e6f0dd345 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -18,6 +18,7 @@ limitations under the License. 
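The next_constant_layer_id_ member declared in the header above exists only to make constant-layer names unique: CreateConstantLayer has no node name to anchor the layer name on, so it stamps each constant with a running counter. A tiny illustrative stand-in (the ConstantNamer class is hypothetical; the real code routes the counter through SetLayerName as shown earlier):

    #include <iostream>
    #include <string>

    // Hypothetical stand-in for next_constant_layer_id_: every constant layer
    // gets a fresh numeric suffix, so constants have unique names even though
    // they are not tied to a single TF node name.
    class ConstantNamer {
     public:
      std::string NextName() {
        return "_tftrt_constant_-" + std::to_string(next_constant_layer_id_++);
      }

     private:
      int next_constant_layer_id_ = 0;
    };

    int main() {
      ConstantNamer namer;
      std::cout << namer.NextName() << "\n";  // _tftrt_constant_-0
      std::cout << namer.NextName() << "\n";  // _tftrt_constant_-1
      return 0;
    }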
#include #include #include +#include #include #include @@ -203,6 +204,23 @@ void ExpectTrtDimsEqualsArray(const std::vector& lhs, << " actual: " << DebugString(rhs); } +void ExpectTrtLayerNames(absl::Span names, + nvinfer1::INetworkDefinition* network) { + EXPECT_EQ(network->getNbLayers(), names.size()); + + for (int i = 0; i < network->getNbLayers(); i++) { + auto layer = network->getLayer(i); + EXPECT_EQ(layer->getName(), names[i]); + } +} + +void VerifyTrtLayerNameNotEmpty(nvinfer1::INetworkDefinition* network) { + for (int i = 0; i < network->getNbLayers(); i++) { + auto layer = network->getLayer(i); + EXPECT_NE(layer->getName(), nullptr); + } +} + Matcher> ArrayFloatNear(const std::vector& values, float max_abs_error = 1e-5, bool nan_sensitive = false) { @@ -803,6 +821,8 @@ TEST_F(ConverterTest, ConvertNode) { TF_EXPECT_OK(GetTensorOrWeights("my_op:1", &actual_output_2)); EXPECT_EQ(&output_tensors[1], actual_output_2.tensor()); EXPECT_EQ(125, actual_output_2.tensor()->getDimensions().d[0]); + + VerifyTrtLayerNameNotEmpty(converter_->network()); } TEST_F(ConverterTest, AddAndGetInputs) { @@ -832,6 +852,8 @@ TEST_F(ConverterTest, AddAndGetInputs) { ExpectTrtDimsEqualsArray({1}, inputs[0].tensor()->getDimensions()); ExpectTrtDimsEqualsArray({2, 3}, inputs[2].tensor()->getDimensions()); ExpectTrtDimsEqualsArray({5, 3}, inputs[3].tensor()->getDimensions()); + + VerifyTrtLayerNameNotEmpty(converter_->network()); } TEST_F(ConverterTest, RenameAndMarkOutputTensors) { @@ -880,30 +902,33 @@ TEST_F(ConverterTest, RenameAndMarkOutputTensors) { } EXPECT_EQ("my_output", string(output_tensors[0]->getName())); EXPECT_EQ("my_output_1", string(output_tensors[1]->getName())); + + VerifyTrtLayerNameNotEmpty(converter_->network()); } TEST_F(ConverterTest, TransposeTensor) { nvinfer1::ITensor* input_tensor = converter_->network()->addInput( "", nvinfer1::DataType::kFLOAT, GetTestDims({2, 3, 5})); nvinfer1::ITensor* output_tensor = nullptr; - + NodeDef dummy_node_def = MakeNodeDef("dummy_op", "DummyOp", {}); // Rank doesn't match. ExpectStatus( - converter_->TransposeTensor(input_tensor, {0, 1}, "Bad perm", - &output_tensor), + converter_->TransposeTensor(input_tensor, {0, 1}, &output_tensor, + dummy_node_def, "sub1"), error::INVALID_ARGUMENT, "Rank of perm for transpose does not match with that of the input"); // Transpose at batch dimension. - ExpectStatus(converter_->TransposeTensor(input_tensor, {1, 0, 2, 3}, - "Batch perm", &output_tensor), - error::UNIMPLEMENTED, - "Transpose at batch dimension is not supported."); + ExpectStatus( + converter_->TransposeTensor(input_tensor, {1, 0, 2, 3}, &output_tensor, + dummy_node_def, "sub2"), + error::UNIMPLEMENTED, "Transpose at batch dimension is not supported."); // OK. 
- TF_EXPECT_OK(converter_->TransposeTensor(input_tensor, {0, 3, 1, 2}, "OK", - &output_tensor)); + TF_EXPECT_OK(converter_->TransposeTensor( + input_tensor, {0, 3, 1, 2}, &output_tensor, dummy_node_def, "sub3")); ExpectTrtDimsEqualsArray({5, 2, 3}, output_tensor->getDimensions()); + ExpectTrtLayerNames({"dummy_op-sub3"}, converter_->network()); } void TestPrepareTensorForShape( @@ -922,9 +947,11 @@ void TestPrepareTensorForShape( } nvinfer1::ITensor* output_tensor = nullptr; + NodeDef dummy_node_def = MakeNodeDef("dummy_op", "DummyOp", {}); for (bool validation_only : {false, true}) { const Status status = converter->PrepareTensorForShape( - input, GetTestDims(reshape_dims), validation_only, &output_tensor); + input, GetTestDims(reshape_dims), validation_only, &output_tensor, + dummy_node_def); if (expected_code == error::OK) { TF_EXPECT_OK(status); if (validation_only) { @@ -978,6 +1005,8 @@ TEST_F(ConverterTest, PrepareTensorForShape) { /*input_is_tensor=*/false, converter_.get(), weight_store_, error::INVALID_ARGUMENT, "Shape is not fully defined"); + + VerifyTrtLayerNameNotEmpty(converter_->network()); } TEST_F(ConverterTest, MaybeUpdateBatchSize) { @@ -1051,6 +1080,8 @@ TEST_F(ConverterTest, ProvideQuantizationRange) { // Symmetric range converter_->ProvideQuantizationRange(&fake_tensor, -6.123f, 6.123f); EXPECT_EQ(6.123f, quantization_ranges()[&fake_tensor]); + + VerifyTrtLayerNameNotEmpty(converter_->network()); } TEST_F(ConverterTest, MaybeApplyQuantizationRanges) { @@ -1077,6 +1108,8 @@ TEST_F(ConverterTest, MaybeApplyQuantizationRanges) { EXPECT_EQ(infer_3.getDynamicRange(), 5.0f); EXPECT_EQ(not_infer.getDynamicRange(), 100.0f); #endif + + VerifyTrtLayerNameNotEmpty(int8_converter->network()); } TEST_F(ConverterTest, PropagateQuantizationRanges) { @@ -1099,6 +1132,8 @@ TEST_F(ConverterTest, PropagateQuantizationRanges) { EXPECT_EQ(5.0f, ranges[&infer[i]]); } EXPECT_EQ(ranges.count(¬_infer), 0); + + VerifyTrtLayerNameNotEmpty(converter_->network()); } TEST_F(ConverterTest, GetTrtBroadcastShape) { @@ -1202,6 +1237,8 @@ TEST_F(ConverterTest, GetTrtBroadcastShape) { "(tensor #dims 4 vs broadcast #dims 5)"); symmetric_test({2, 3}, {7, 5}, kIsTensor, kIsTensor, {}, {}, error::INVALID_ARGUMENT, "Infeasible broadcast scheme"); + + VerifyTrtLayerNameNotEmpty(converter_->network()); } TEST_F(ConverterTest, CreateConstantLayer) { @@ -1216,6 +1253,8 @@ TEST_F(ConverterTest, CreateConstantLayer) { << DebugString(tensor->getType()); ExpectTrtDimsEqualsArray({3, 10}, tensor->getDimensions()); } + + VerifyTrtLayerNameNotEmpty(converter_->network()); } class ConvertGraphDefToEngineTest : public ::testing::Test { @@ -1575,6 +1614,9 @@ class OpConverterTest : public ::testing::Test { const char* expected_msg_substr = nullptr) { ExpectStatus(converter_->ConvertNode(node->def()), expected_code, expected_msg_substr); + if (expected_code == error::OK) { + VerifyTrtLayerNameNotEmpty(converter_->network()); + } } // Helper method to run both validation and conversion, when the expected @@ -1709,12 +1751,12 @@ class ParameterizedOpConverterTestBase std::tuple> { public: ParameterizedOpConverterTestBase() - : trt_mode(std::get<0>(GetParam())), - tf_type(std::get<1>(GetParam())), - converter_precision(std::get<2>(GetParam())) {} + : trt_mode_(std::get<0>(GetParam())), + tf_type_(std::get<1>(GetParam())), + converter_precision_(std::get<2>(GetParam())) {} void Reset() { - OpConverterTest::Reset(converter_precision, trt_mode); + OpConverterTest::Reset(converter_precision_, trt_mode_); input_data_.clear(); } 
@@ -1750,7 +1792,7 @@ class ParameterizedOpConverterTestBase if (!partial_input_shape_dims.empty()) { partial_shape = partial_input_shape_dims; } else { - if (trt_mode == TrtTestMode::kDynamicShape) { + if (trt_mode_ == TrtTestMode::kDynamicShape) { // In dynamic shape mode we make all dims unknown. partial_shape = std::vector(dims.size(), -1); } else { @@ -1776,7 +1818,7 @@ class ParameterizedOpConverterTestBase void AddTestTensor(const string& name, const std::vector& dims, const std::vector& values = {}, const std::vector& partial_input_shape_dims = {}) { - AddTestTensor(name, dims, tf_type, values, partial_input_shape_dims); + AddTestTensor(name, dims, tf_type_, values, partial_input_shape_dims); } // Builds and runs the converted network. Checks output tensor shape. Tests @@ -1796,7 +1838,7 @@ class ParameterizedOpConverterTestBase TensorShapeUtils::MakeShape(expected_output_dims[i], &shape)); string out_name = (n_output == 1) ? name : StrCat(name, ":", i); DataType out_tf_type = - out_tf_types.size() > i ? out_tf_types[i] : tf_type; + out_tf_types.size() > i ? out_tf_types[i] : tf_type_; InputOutputData data{ out_name, ConstructTensor(shape.num_elements(), 0, out_tf_type)}; output_data.push_back(data); @@ -1840,9 +1882,9 @@ class ParameterizedOpConverterTestBase } protected: - const TrtTestMode trt_mode; - const DataType tf_type; - const TrtPrecisionMode converter_precision; + const TrtTestMode trt_mode_; + const DataType tf_type_; + const TrtPrecisionMode converter_precision_; DataVec input_data_; }; @@ -2075,7 +2117,7 @@ TEST_P(OpConverterTest1, ConvertFusedBatchNorm) { 37.342354, 41.013527, 30.9738, 34.469433, 45.018955, 48.59309, 59.369415, 63.04059}; for (auto get_node_def : get_node_def_vec) { - NodeDef tmp_node_def = get_node_def(tf_type, "NCHW", true, 0); + NodeDef tmp_node_def = get_node_def(tf_type_, "NCHW", true, 0); std::string op_name = tmp_node_def.op(); std::vector test_param{ {"NHWC", 0, false, 0, @@ -2097,7 +2139,7 @@ TEST_P(OpConverterTest1, ConvertFusedBatchNorm) { errors::Unimplemented(StrCat("The input \"variance\" for ", op_name, " must be a constant, at my_batchnorm"))}, {"NCHW", 0, false, 0.01}}; // The last one is the only test that runs. - if (trt_mode == TrtTestMode::kDynamicShape) { + if (trt_mode_ == TrtTestMode::kDynamicShape) { test_param.push_back( {"NCHW", 0, false, 0.01, errors::InvalidArgument( @@ -2107,7 +2149,7 @@ TEST_P(OpConverterTest1, ConvertFusedBatchNorm) { for (auto p : test_param) { Reset(); NodeDef node_def = - get_node_def(tf_type, p.data_format, p.is_training, p.epsilon); + get_node_def(tf_type_, p.data_format, p.is_training, p.epsilon); for (int i = 0; i < node_input.size(); i++) { if (i == 0 || i == p.tensor_input_idx) { // The first input (x) is always added as a tensor, and it hase shape @@ -2126,7 +2168,7 @@ TEST_P(OpConverterTest1, ConvertFusedBatchNorm) { // the first arg is a tensor. TODO(tfeher) Check if one can relax this // restriction. Status expected_status = - (i != 0 && trt_mode == TrtTestMode::kImplicitBatch) + (i != 0 && trt_mode_ == TrtTestMode::kImplicitBatch) ? 
errors::InvalidArgument( StrCat("Batch size doesn't match for tensor ", node_input[i].name, @@ -2134,19 +2176,19 @@ TEST_P(OpConverterTest1, ConvertFusedBatchNorm) { "converter batch size: 3 vs 2")) : Status::OK(); std::vector partial_input_shape; - if (i == 0 && trt_mode == TrtTestMode::kDynamicShape && + if (i == 0 && trt_mode_ == TrtTestMode::kDynamicShape && !p.keep_channel_unknown) { // keep channel dim static (known) partial_input_shape.resize(4, -1); partial_input_shape[1] = node_input[i].dims[1]; } - AddTestTensor(node_input[i].name, node_input[i].dims, tf_type, + AddTestTensor(node_input[i].name, node_input[i].dims, tf_type_, node_input[i].val, partial_input_shape, expected_status); } else { AddTestWeights(node_input[i].name, node_input[i].dims, - node_input[i].val, tf_type); + node_input[i].val, tf_type_); } } TestOpConverter("my_batchnorm", node_def, node_input[0].dims, @@ -2154,12 +2196,12 @@ TEST_P(OpConverterTest1, ConvertFusedBatchNorm) { ArrayFloatNear(expected_output)); } } -} // namespace convert +} TEST_P(OpConverterTest1, ConvertTranspose) { // Get the NodeDef for Transpose. Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), tf_type); + auto input = ops::Placeholder(s.WithOpName("input"), tf_type_); auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32); auto transpose = ops::Transpose(s.WithOpName("my_transpose"), input, weights); const NodeDef& node_def = transpose.operation.node()->def(); @@ -2187,13 +2229,13 @@ TEST_P(OpConverterTest1, ConvertTranspose) { {}, {3, 2, 1, 1}, {3, 2, 1, 0}, - (trt_mode == TrtTestMode::kImplicitBatch) + (trt_mode_ == TrtTestMode::kImplicitBatch) ? Status(error::UNIMPLEMENTED, "Transpose at batch dimension is not supported") : Status::OK()}, TestParamBase{{1, 1, 2, 3}, {}, {1, 3, 1, 2}, {0, 3, 1, 2}}, }; - if (trt_mode == TrtTestMode::kDynamicShape) { + if (trt_mode_ == TrtTestMode::kDynamicShape) { // Dynamic shape tests where some shapes are known test_params.push_back(TestParamBase{ {1, 1, 2, 3}, {-1, 1, 2, -1}, {1, 3, 1, 2}, {0, 3, 1, 2}}); @@ -2317,19 +2359,22 @@ TEST_F(OpConverterTest, ConvertReshape) { TEST_P(OpConverterTest1, ConvertShape) { // Get the NodeDef for Shape op. Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), tf_type); + auto input = ops::Placeholder(s.WithOpName("input"), tf_type_); auto shape = ops::Shape(s.WithOpName("my_shape"), input); const NodeDef& node_def = shape.operation.node()->def(); Status conversion_status = - (trt_mode == TrtTestMode::kImplicitBatch) + (trt_mode_ == TrtTestMode::kImplicitBatch) ? errors::Unimplemented( "Shape is only supported for explicit batch mode.") : Status::OK(); std::vector test_params = { - TestParamBase{{1, 2, 3}, {}, {3}, {}, conversion_status}, - // Add input as weight (we use non empty param ({1}) to trigger this). - TestParamBase{{1, 2, 3}, {}, {3}, {1}, conversion_status}, +// TODO(b/166274212): Enable the test parameter for TensorRT 7.1.3. +#if !IS_TRT_VERSION_GE(7, 1, 3, 0) + TestParamBase{{1, 2, 3}, {}, {3}, {}, conversion_status}, +#endif + // Add input as weight (we use non empty param ({1}) to trigger this). + TestParamBase{{1, 2, 3}, {}, {3}, {1}, conversion_status}, }; auto input_is_weight = [](const TestParamBase p) { return !p.param.empty(); }; @@ -2343,7 +2388,7 @@ TEST_P(OpConverterTest1, ConvertShape) { // we use for the unit test have no actual input tensor when it is converted // to a TensorRT network. 
int n_elements = 0; - if (input_is_weight(p) || trt_mode != TrtTestMode::kExplicitBatch) { + if (input_is_weight(p) || trt_mode_ != TrtTestMode::kExplicitBatch) { // Calculate the number of elements for adding input data. n_elements = std::accumulate(p.input_dims.begin(), p.input_dims.end(), 1, std::multiplies()); @@ -2352,7 +2397,7 @@ TEST_P(OpConverterTest1, ConvertShape) { if (!input_is_weight(p)) { AddTestTensor("input", p.input_dims, input_val); } else { - AddTestWeights("input", p.input_dims, input_val, tf_type); + AddTestWeights("input", p.input_dims, input_val, tf_type_); } TestOpConverter("my_shape", node_def, p.expected_output_dims, p.status, p.runtime_status, ElementsAreArray(p.input_dims), @@ -2617,7 +2662,7 @@ TEST_P(OpConverterTest2, ConvertBiasAdd) { for (const string& data_format : {"NHWC", "NCHW"}) { for (const int trt_input_rank : {1, 2, 3, 4}) { Reset(); - NodeDef node_def = get_biasadd_nodedef(data_format, tf_type); + NodeDef node_def = get_biasadd_nodedef(data_format, tf_type_); // Add input, dims_array will be like {2, 1, ..., 1, 3} std::vector dims_array(trt_input_rank + 1, 1); @@ -2639,7 +2684,7 @@ TEST_P(OpConverterTest2, ConvertBiasAdd) { for (int i = 0; i < channel_size; ++i) { bias[i] = i + 1; // bias will be {1, 2, 3, ...} } - AddTestWeights("weights", {channel_size}, bias, tf_type); + AddTestWeights("weights", {channel_size}, bias, tf_type_); // Build and run the engine. std::vector output_data; @@ -2675,7 +2720,7 @@ NodeDef GetBinaryOpNodeDef(DataType dtype) { TEST_P(OpConverterTest2, ConvertBinary) { { AttrValue dtype; - dtype.set_type(tf_type); + dtype.set_type(tf_type_); // Both inputs are weights. Reset(); NodeDef node_def = @@ -2720,19 +2765,19 @@ TEST_P(OpConverterTest2, ConvertBinary) { if (!op_test_info.count(op_name)) { FAIL() << "Binary op test map does not contain op " << op_name; } - NodeDef node_def = op_test_info[op_name].first(tf_type); + NodeDef node_def = op_test_info[op_name].first(tf_type_); std::vector input_names; std::vector> input_dims; std::vector> input_values; if (operand_1_is_tensor) { AddTestTensor("input1", {2, 1, 2}, {3, 6, 3, 6}); } else { - AddTestWeights("input1", {1, 2}, std::vector{3, 6}, tf_type); + AddTestWeights("input1", {1, 2}, std::vector{3, 6}, tf_type_); } if (operand_2_is_tensor) { AddTestTensor("input2", {2, 2, 1}, {2, 3, 2, 3}); } else { - AddTestWeights("input2", {2, 1}, std::vector{2, 3}, tf_type); + AddTestWeights("input2", {2, 1}, std::vector{2, 3}, tf_type_); } TestOpConverter("my_binary", node_def, {2, 2, 2}, Status::OK(), Status::OK(), @@ -2939,10 +2984,10 @@ TEST_P(OpConverterTest2, ConvertSquare) { // Input is weights, should fail. 
Reset(); Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), tf_type); + auto input = ops::Placeholder(s.WithOpName("input"), tf_type_); auto square = ops::Square(s.WithOpName("my_square"), input); NodeDef node_def = square.operation.node()->def(); - AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, -5, 6}, tf_type); + AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, -5, 6}, tf_type_); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, "The input \"x\" for Square must be a tensor, at my_square"); @@ -2951,7 +2996,7 @@ TEST_P(OpConverterTest2, ConvertSquare) { Reset(); Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), tf_type); + auto input = ops::Placeholder(s.WithOpName("input"), tf_type_); auto square = ops::Square(s.WithOpName("my_square"), input); NodeDef node_def = square.operation.node()->def(); @@ -2964,7 +3009,7 @@ TEST_P(OpConverterTest2, ConvertSquare) { inputs[i] = value; expected_outputs[i] = value * value; } - AddTestTensor("input", {1, 1, 20}, tf_type, inputs); + AddTestTensor("input", {1, 1, 20}, tf_type_, inputs); TestOpConverter("my_square", node_def, {1, 1, 20}, Status::OK(), Status::OK(), ArrayFloatNear(expected_outputs, 0)); @@ -3091,7 +3136,7 @@ TEST_P(OpConverterTest1, ConvertActivation) { { // Input is weights, should fail. Reset(); - const NodeDef& node_def = CreateUnaryOp(tf_type); + const NodeDef& node_def = CreateUnaryOp(tf_type_); AddTestWeights("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, @@ -3148,7 +3193,7 @@ TEST_P(OpConverterTest1, ConvertActivation) { FAIL() << "Activation op test map does not contain op " << op_name; } Reset(); - NodeDef node_def = op_map[op_name].first(tf_type); + NodeDef node_def = op_map[op_name].first(tf_type_); const std::vector input = {-100, -2, -1, 0, 1, 88}; AddTestTensor("input", p.input_dims, input); @@ -3176,7 +3221,7 @@ TEST_P(OpConverterTest1, ConvertActivation) { TEST_P(OpConverterTest1, ConvertExpandDims) { // Get the NodeDef for ExpandDims. Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), tf_type); + auto input = ops::Placeholder(s.WithOpName("input"), tf_type_); auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32); auto expanddims = ops::ExpandDims(s.WithOpName("my_expanddims"), input, weights); @@ -3204,7 +3249,7 @@ TEST_P(OpConverterTest1, ConvertExpandDims) { {}, {1, 1, 1, 2, 3}, {0}, - trt_mode == TrtTestMode::kImplicitBatch + trt_mode_ == TrtTestMode::kImplicitBatch ? Status(error::UNIMPLEMENTED, "TensorRT does not allow manipulation of the " "batch dimension, at my_expanddims") @@ -3213,7 +3258,7 @@ TEST_P(OpConverterTest1, ConvertExpandDims) { {}, {1, 1, 1, 2, 3}, {-5}, - trt_mode == TrtTestMode::kImplicitBatch + trt_mode_ == TrtTestMode::kImplicitBatch ? Status(error::UNIMPLEMENTED, "TensorRT does not allow manipulation of the " "batch dimension, at my_expanddims") @@ -3251,7 +3296,7 @@ TEST_P(OpConverterTest1, ConvertExpandDims) { } TEST_P(OpConverterTest1, ConvertSqueeze) { - const bool use_implicit_batch = (trt_mode == TrtTestMode::kImplicitBatch); + const bool use_implicit_batch = (trt_mode_ == TrtTestMode::kImplicitBatch); // Get the NodeDef for Squeeze. 
auto get_squeeze_nodedef = [](std::vector axes, DataType tf_type) -> NodeDef { @@ -3274,7 +3319,7 @@ TEST_P(OpConverterTest1, ConvertSqueeze) { {}, // input partial dims {2, 3}, // expected output dims {}, // axis - trt_mode == TrtTestMode::kExplicitBatch + trt_mode_ == TrtTestMode::kExplicitBatch ? Status::OK() : Status{error::UNIMPLEMENTED, "Squeeze is not implemented for empty squeeze_dims, at " @@ -3333,7 +3378,7 @@ TEST_P(OpConverterTest1, ConvertSqueeze) { "Dimension 2 with size 2 cannot be squeezed because it must be " "size 1, at my_squeeze"}}; - if (trt_mode == TrtTestMode::kDynamicShape) { + if (trt_mode_ == TrtTestMode::kDynamicShape) { // In this test we try to squeeze axis=2 which has size > 1. In dynamic // shape mode the converter sees only -1, so it cannot catch this error. squeeze_non_singleton.status = Status::OK(); // conversion status @@ -3348,7 +3393,7 @@ TEST_P(OpConverterTest1, ConvertSqueeze) { for (TestParamBase p : test_params) { SCOPED_TRACE(p); Reset(); - NodeDef node_def = get_squeeze_nodedef(p.param, tf_type); + NodeDef node_def = get_squeeze_nodedef(p.param, tf_type_); AddTestTensor("input", p.input_dims, {1, 2, 3, 4, 5, 6}, p.partial_input_dims); TestOpConverter("my_squeeze", node_def, p.expected_output_dims, p.status, @@ -4103,14 +4148,14 @@ TEST_F(OpConverterTest, ConvertSlice) { TEST_P(OpConverterTest1, ConvertConv2D) { // Get nodedef for Conv2D layer. - DataType tf_type_loc = tf_type; + DataType tf_type = tf_type_; auto get_conv2d_nodedef = - [tf_type_loc](std::vector strides = {1, 1, 1, 1}, - string padding = "SAME", string data_format = "NCHW", - std::vector dilations = {1, 1, 1, 1}) -> NodeDef { + [tf_type](std::vector strides = {1, 1, 1, 1}, + string padding = "SAME", string data_format = "NCHW", + std::vector dilations = {1, 1, 1, 1}) -> NodeDef { Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), tf_type_loc); - auto filter = ops::Placeholder(s.WithOpName("weights"), tf_type_loc); + auto input = ops::Placeholder(s.WithOpName("input"), tf_type); + auto filter = ops::Placeholder(s.WithOpName("weights"), tf_type); ops::Conv2D::Attrs attrs = ops::Conv2D::Attrs().DataFormat(data_format).Dilations(dilations); auto conv2d = ops::Conv2D(s.WithOpName("my_conv2d"), input, filter, strides, @@ -4203,12 +4248,12 @@ TEST_P(OpConverterTest1, ConvertConv2D) { node_def, error::UNIMPLEMENTED, "Stride must be 1 for batch and channel dimensions, at my_conv2d"); } - if (trt_mode == TrtTestMode::kDynamicShape) { + if (trt_mode_ == TrtTestMode::kDynamicShape) { Reset(); NodeDef node_def = get_conv2d_nodedef(); // Channel dim unknown, should fail. AddTestTensorWithTFDims("input", {-1, -1, -1, -1}, - TfDataTypeToTrt(tf_type)); + TfDataTypeToTrt(tf_type_)); AddTestWeights("weights", {1, 2, 1, 1}, {-1, 1}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, @@ -4230,8 +4275,6 @@ TEST_P(OpConverterTest1, ConvertConv2D) { // Ok. std::vector ok_params = { -// TODO(b/162447069): Enable the test parameters for TRT 7.1.3.x. -#if !IS_TRT_VERSION_GE(7, 1, 3, 0) // Basic TestParams{/*input_dims=*/{1, 1, 2, 3}, /*input=*/{0, 1, 2, 3, 3, 4}, @@ -4243,9 +4286,6 @@ TEST_P(OpConverterTest1, ConvertConv2D) { /*dilations=*/{1, 1, 1, 1}, /*expected_output_dims=*/{1, 1, 2, 2}, /*expected_output=*/{1, 1, 0, 1}}, -#endif -// TODO(b/162448349): Enable the test parameters for TRT 7.1.3.x. 
-#if !IS_TRT_VERSION_GE(7, 1, 3, 0) // SAME padding (Asymmetric) TestParams{/*input_dims=*/{1, 1, 2, 3}, /*input=*/{0, 1, 2, 3, 3, 4}, @@ -4268,9 +4308,6 @@ TEST_P(OpConverterTest1, ConvertConv2D) { /*dilations=*/{1, 1, 1, 1}, /*expected_output_dims=*/{1, 1, 2, 3}, /*expected_output=*/{1, 2, -1, 3, 1, -3}}, -#endif -// TODO(b/162447069): Enable the test parameters for TRT 7.1.3.x. -#if !IS_TRT_VERSION_GE(7, 1, 3, 0) // NHWC TestParams{/*input_dims=*/{1, 2, 3, 1}, /*input=*/{0, 1, 2, 3, 3, 4}, @@ -4304,7 +4341,6 @@ TEST_P(OpConverterTest1, ConvertConv2D) { /*dilations=*/{1, 1, 1, 1}, /*expected_output_dims=*/{1, 1, 2, 2}, /*expected_output=*/{1, 0, 1, 3}}, -#endif }; for (int i = 0; i < ok_params.size(); i++) { @@ -4313,15 +4349,15 @@ TEST_P(OpConverterTest1, ConvertConv2D) { get_conv2d_nodedef(ok_params[i].strides, ok_params[i].padding, ok_params[i].data_format, ok_params[i].dilations); std::vector partial_input_shape; - if (trt_mode == TrtTestMode::kDynamicShape) { + if (trt_mode_ == TrtTestMode::kDynamicShape) { // The channel dim cannot have unknown size, fix that. partial_input_shape.resize(ok_params[i].input_dims.size(), -1); int channel_id = (ok_params[i].data_format == "NCHW") ? 1 : 3; partial_input_shape[channel_id] = ok_params[i].input_dims[channel_id]; } - AddTestTensor("input", ok_params[i].input_dims, tf_type, ok_params[i].input, - partial_input_shape); + AddTestTensor("input", ok_params[i].input_dims, tf_type_, + ok_params[i].input, partial_input_shape); AddTestWeights("weights", ok_params[i].filter_dims, ok_params[i].filter); @@ -4848,7 +4884,7 @@ TEST_P(OpConverterTest1, ConvertPool) { for (int nDim : test_nDims) { // Input is weights, should fail. Reset(); - NodeDef node_def = get_pool_nodedef(tf_type, nDim); + NodeDef node_def = get_pool_nodedef(tf_type_, nDim); AddTestWeights("input", {1, 1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}); RunValidationAndConversion(node_def, error::UNIMPLEMENTED, @@ -4957,7 +4993,7 @@ TEST_P(OpConverterTest1, ConvertPool) { for (bool is_max_pooling : {true, false}) { Reset(); NodeDef node_def = - get_pool_nodedef(tf_type, nDim, ksize, strides, p.padding, + get_pool_nodedef(tf_type_, nDim, ksize, strides, p.padding, data_format, is_max_pooling); AddTestTensor("input", input_dims, input); TestOpConverter("my_pool", node_def, expected_output_dims, Status::OK(), @@ -5019,7 +5055,7 @@ TEST_F(OpConverterTest, ConvertTopK) { TEST_P(OpConverterTest3, ConvertGather) { // Get the NodeDef for GatherV2. Scope s = Scope::NewRootScope(); - auto params = ops::Placeholder(s.WithOpName("params"), tf_type); + auto params = ops::Placeholder(s.WithOpName("params"), tf_type_); auto indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32); auto axis = ops::Placeholder(s.WithOpName("axis"), DT_INT32); auto gather = ops::GatherV2(s.WithOpName("my_gather"), params, indices, axis); @@ -5027,7 +5063,7 @@ TEST_P(OpConverterTest3, ConvertGather) { { // Axis is a tensor, should fail. Reset(); - AddTestTensor("params", {1, 1, 2, 3}, tf_type, {}); + AddTestTensor("params", {1, 1, 2, 3}, tf_type_, {}); AddTestTensor("indices", {1, 2}, DT_INT32, {}); AddTestTensor("axis", {1}, DT_INT32, {}); RunValidationAndConversion( @@ -5072,7 +5108,7 @@ TEST_P(OpConverterTest3, ConvertGather) { /*expected_output_shape=*/{2, 1, 1, 3}, /*expected_output=*/{4, 5, 6, 1, 2, 3}, /*params_is_tensor=*/true, - trt_mode == TrtTestMode::kImplicitBatch + trt_mode_ == TrtTestMode::kImplicitBatch ? 
Status{error::UNIMPLEMENTED, "TensorRT does not allow manipulation of the" " batch dimension, at my_gather"} @@ -5085,7 +5121,7 @@ TEST_P(OpConverterTest3, ConvertGather) { /*expected_output_shape=*/{2, 1, 2, 1}, /*expected_output=*/{3, 1, 6, 4}, /*params_is_tensor=*/true, - trt_mode == TrtTestMode::kImplicitBatch + trt_mode_ == TrtTestMode::kImplicitBatch ? Status{error::UNIMPLEMENTED, "Indices must have a batch size of 1 when params" " is a tensor."} @@ -5099,7 +5135,7 @@ TEST_P(OpConverterTest3, ConvertGather) { /*expected_output_shape=*/{2, 1, 2}, /*expected_output=*/{2, 3, 5, 6}, /*params_is_tensor=*/false, - trt_mode == TrtTestMode::kImplicitBatch + trt_mode_ == TrtTestMode::kImplicitBatch ? Status{error::UNIMPLEMENTED, "The input axis must be zero when params is a" " weight."} @@ -5112,13 +5148,13 @@ TEST_P(OpConverterTest3, ConvertGather) { /*expected_output_shape=*/{2}, /*expected_output=*/{2, 4}, /*params_is_tensor=*/true, - trt_mode == TrtTestMode::kImplicitBatch // conversion_status + trt_mode_ == TrtTestMode::kImplicitBatch // conversion_status ? Status{error::UNIMPLEMENTED, "TensorRT does not allow manipulation of the " "batch dimension, at my_gather"} : Status::OK(), - Status::OK(), // runtime_status - trt_mode == TrtTestMode::kImplicitBatch // add_index_status + Status::OK(), // runtime_status + trt_mode_ == TrtTestMode::kImplicitBatch // add_index_status ? Status{error::INVALID_ARGUMENT, "Batch size doesn't match for tensor indices: " "Provided batch size does not match converter " @@ -5233,7 +5269,7 @@ TEST_P(OpConverterTest3, ConvertGather) { if (p.params_is_tensor) { AddTestTensor("params", p.params_shape, params_input); } else { - AddTestWeights("params", p.params_shape, params_input, tf_type); + AddTestWeights("params", p.params_shape, params_input, tf_type_); } AddTestTensor("indices", p.indices_shape, DT_INT32, p.indices, {}, p.add_index_status); @@ -5273,7 +5309,7 @@ TEST_P(OpConverterTest1, ConvertReduce) { { // Input is weights, should fail. Reset(); - const NodeDef node_def = CreateReduceOp(tf_type, false); + const NodeDef node_def = CreateReduceOp(tf_type_, false); AddTestWeights("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2}); AddTestWeights("axis", {1}, {1}); RunValidationAndConversion( @@ -5283,7 +5319,7 @@ TEST_P(OpConverterTest1, ConvertReduce) { { // Axis is weights, should fail. Reset(); - const NodeDef node_def = CreateReduceOp(tf_type, false); + const NodeDef node_def = CreateReduceOp(tf_type_, false); AddTestTensor("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2}); AddTestTensor("axis", {1}, DT_INT32, {1}); RunValidationAndConversion( @@ -5343,7 +5379,7 @@ TEST_P(OpConverterTest1, ConvertReduce) { for (auto p : params) { SCOPED_TRACE(StrCat(op.name, keep_dims ? "keep_dims" : "")); Reset(); - NodeDef node_def = op.get_node(tf_type, keep_dims); + NodeDef node_def = op.get_node(tf_type_, keep_dims); AddTestTensor("input", p.input_dims, p.input_values); AddTestWeights("axis", {static_cast(p.axis.size())}, @@ -5363,7 +5399,7 @@ TEST_P(OpConverterTest1, ConvertReduce) { int ax_positive = ax >= 0 ? ax : ax + rank; // Zero marks elements that we will remove later. expected_output_dims[ax_positive] = keep_dims ? 1 : 0; - if (trt_mode == TrtTestMode::kImplicitBatch && + if (trt_mode_ == TrtTestMode::kImplicitBatch && (ax == 0 || ax == -rank)) { p.conversion_status = errors::Unimplemented( "TensorRT does not allow manipulation of the batch " @@ -5399,7 +5435,7 @@ TEST_P(OpConverterTest1, ConvertUnary) { { // Input is weights, should fail. 
Reset(); - const NodeDef node_def = CreateUnaryOp(tf_type); + const NodeDef node_def = CreateUnaryOp(tf_type_); AddTestWeights("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, @@ -5455,7 +5491,7 @@ TEST_P(OpConverterTest1, ConvertUnary) { if (!op_map.count(op_name)) { FAIL() << "Unary op test map does not contain op " << op_name; } - NodeDef node_def = op_map[op_name].first(tf_type); + NodeDef node_def = op_map[op_name].first(tf_type_); // TODO(bixia): we assume this test is only instantiated for DT_FLOAT for // now. Need to find a better way to express input and output types. @@ -5463,7 +5499,7 @@ TEST_P(OpConverterTest1, ConvertUnary) { // TODO(tfeher): improve tests by defining an expected output data type and // check that. Currently only the shape and values of the output are // checked. - DataType input_tf_type = op_name == "Cast" ? DT_HALF : tf_type; + DataType input_tf_type = op_name == "Cast" ? DT_HALF : tf_type_; std::vector input_values{-0.9f, 0.6f, 0.0f, -3.5f, 100.0f, 2.9f}; AddTestTensor("input", p.input_dims, input_tf_type, input_values); @@ -6030,7 +6066,7 @@ TEST_P(OpConverterTest2, ConvertPack) { /*axis=*/1, /*expected_output_dims=*/{1, 2, 2, 3}, /*expected_output=*/InitTestVector(12), - trt_mode == TrtTestMode::kImplicitBatch + trt_mode_ == TrtTestMode::kImplicitBatch ? Status{error::UNIMPLEMENTED, "The input \"values_1\" for Pack must be a tensor, at " "my_pack"} @@ -6056,7 +6092,7 @@ TEST_P(OpConverterTest2, ConvertPack) { /*axis=*/-4, /*expected_output_dims=*/{2, 1, 2, 3}, /*expected_output=*/InitTestVector(12), - trt_mode == TrtTestMode::kImplicitBatch + trt_mode_ == TrtTestMode::kImplicitBatch ? Status{error::UNIMPLEMENTED, "TensorRT does not allow manipulation of the batch " "dimension, at my_pack"} @@ -6116,7 +6152,7 @@ TEST_P(OpConverterTest2, ConvertPack) { }, }; // Inputs have inconsistent shapes, should fail. - if (trt_mode != TrtTestMode::kDynamicShape) { + if (trt_mode_ != TrtTestMode::kDynamicShape) { params.push_back(TestParams{ /*input_shapes=*/{{1, 2, 3}, {1, 3, 2}}, /*partial_input_shapes=*/{{}, {}}, @@ -6136,7 +6172,7 @@ TEST_P(OpConverterTest2, ConvertPack) { // TODO(tfeher) Add dynamic shapes test once TRT handles shape error // decently } - if (trt_mode == TrtTestMode::kDynamicShape) { + if (trt_mode_ == TrtTestMode::kDynamicShape) { // Test with mixed dynamic / static shape input tensors params.push_back( TestParams{/*input_shapes=*/{{1, 2, 3}, {1, 2, 3}}, @@ -6152,14 +6188,14 @@ TEST_P(OpConverterTest2, ConvertPack) { const int num_inputs = p.input_shapes.size(); EXPECT_EQ(num_inputs, p.input_values.size()); - NodeDef node_def = GetPackNodeDef(tf_type, num_inputs, p.axis); + NodeDef node_def = GetPackNodeDef(tf_type_, num_inputs, p.axis); // Create inputs. for (int j = 0; j < num_inputs; ++j) { if (j == 1 && p.input_1_is_weight) { AddTestWeights(StrCat("values_", j), p.input_shapes[j], - p.input_values[j], tf_type); + p.input_values[j], tf_type_); } else { - AddTestTensor(StrCat("values_", j), p.input_shapes[j], tf_type, + AddTestTensor(StrCat("values_", j), p.input_shapes[j], tf_type_, p.input_values[j], p.partial_input_shapes[j]); } } @@ -6687,7 +6723,7 @@ TEST_P(OpConverterTest2, ConvertSquaredDifference) { { // Input is a weight, should fail. 
Reset(); - NodeDef node_def = GetSquaredDifferenceNodeDef(tf_type); + NodeDef node_def = GetSquaredDifferenceNodeDef(tf_type_); AddTestWeights("x", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); AddTestTensor("y", {1, 1, 2, 3}); RunValidationAndConversion(node_def, error::UNIMPLEMENTED, @@ -6714,7 +6750,7 @@ TEST_P(OpConverterTest2, ConvertSquaredDifference) { /*value_y=*/std::vector(7 * 5, 0), /*expected_output_dims=*/{1, 1, 2, 3}, /*expected_output=*/common_input, - trt_mode == TrtTestMode::kDynamicShape + trt_mode_ == TrtTestMode::kDynamicShape ? Status::OK() : errors::InvalidArgument("Infeasible broadcast scheme"), errors::Internal( @@ -6740,7 +6776,7 @@ TEST_P(OpConverterTest2, ConvertSquaredDifference) { for (auto p : params) { Reset(); - NodeDef node_def = GetSquaredDifferenceNodeDef(tf_type); + NodeDef node_def = GetSquaredDifferenceNodeDef(tf_type_); AddTestTensor("x", p.dims_x, p.value_x); AddTestTensor("y", p.dims_y, p.value_y); TestOpConverter("my_squared_diff", node_def, p.expected_output_dims, @@ -6776,9 +6812,7 @@ template void TestConvertResize(OpConverterTest* test) { typedef typename EnumToDataType::Type CType; - std::vector> params{ -// TODO(b/162442839): Enable the test parameters for TRT 7.1.3.x. -#if !IS_TRT_VERSION_GE(7, 1, 3, 0) + std::vector> params { { /*input_dims=*/{1, 2, 1}, // H, W, C /*output_resize_dims=*/{2, 3}, // H_out, W_out @@ -6790,7 +6824,6 @@ void TestConvertResize(OpConverterTest* test) { /*expected_bilinear_output_values=*/ CastTestVector({2.0f, 0.f, -1.0f, 2.0f, 0.f, -1.0f}), }, -#endif { /*input_dims=*/{1, 2, 1}, // H, W, C /*output_resize_dims=*/{2, 3}, // H_out, W_out @@ -6804,6 +6837,13 @@ void TestConvertResize(OpConverterTest* test) { } }; +// This use case is not supported as of TRT version 7.1 +#if IS_TRT_VERSION_GE(7, 1, 0, 0) + if (std::is_same::value) { + params.erase(params.begin()); + } +#endif + for (int i = 0; i < params.size(); ++i) { test->Reset(); // Create resize node. @@ -6846,7 +6886,7 @@ TEST_F(OpConverterTest, ConvertResize) { // First input is weight, should fail. Reset(); NodeDef node_def = - MakeResizeNodeDef("my_resize", DT_FLOAT, false); + MakeResizeNodeDef("my_resize", DT_FLOAT, true); AddTestWeights("input", {1, 2}, {1, 2}); AddTestWeights("size", {1, 2}, {1, 2}); RunValidationAndConversion( @@ -6858,7 +6898,7 @@ TEST_F(OpConverterTest, ConvertResize) { // output dimension is a tensor, should fail. Reset(); NodeDef node_def = - MakeResizeNodeDef("my_resize", DT_FLOAT, false); + MakeResizeNodeDef("my_resize", DT_FLOAT, true); AddTestTensor("input", {1, 2}); AddTestTensor("size", {1, 2}); RunValidationAndConversion( diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment.cc b/tensorflow/compiler/tf2tensorrt/segment/segment.cc index 1337a733f91..84f25d355ae 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.cc @@ -478,7 +478,7 @@ absl::Span GetInputsToDeterminateBatchSize( "Add", "AddV2", "Mul", - "Sub" + "Sub", "Div", "FloorDiv", "RealDiv", @@ -646,10 +646,12 @@ ClusterBatchSize GetClusterBatchSizeForNode( return cluster_batch_size; } + // As shape inference cannot provide any useful information about the batch + // size, we keep it as missing. 
if (!graph_properties || !graph_properties->HasInputProperties(node->name())) { VLOG(3) << "doesn't have input property"; - return cluster_batch_size.SetBatchSizeValue(-1); + return cluster_batch_size; } const std::vector& input_properties = @@ -660,7 +662,8 @@ ClusterBatchSize GetClusterBatchSizeForNode( const TensorShapeProto* leading_shape = optional_leading_shape.value(); DCHECK(!leading_shape->unknown_rank() && leading_shape->dim_size() >= 2); - return cluster_batch_size.SetBatchSizeValue(leading_shape->dim(0).size()); + VLOG(3) << "has batch size " << leading_shape->dim(0).size(); + return cluster_batch_size.SetBatchSize(leading_shape->dim(0).size()); } void AddSegmentForNode(const grappler::GraphProperties* graph_properties, @@ -668,12 +671,28 @@ void AddSegmentForNode(const grappler::GraphProperties* graph_properties, SimpleNode* node, const DeviceNameUtils::ParsedName& device_name, bool use_implicit_batch) { - segments->emplace_back( - node, + ClusterProperty property( GetClusterBatchSizeForNode(graph_properties, node == nullptr ? nullptr : node->tf_node(), use_implicit_batch), device_name); + segments->emplace_back(node, std::move(property)); +} + +bool OpBatchSizeExceedMaximumBatchSize( + const grappler::GraphProperties* graph_properties, const Node* node, + bool use_implicit_batch, absl::optional maximum_batch_size) { + ClusterBatchSize cluster_batch_size = + GetClusterBatchSizeForNode(graph_properties, node, use_implicit_batch); + // If the batch size is dynamic, then the negative dynamic batch size + // identifier shall never be larger than the positive max batch size. + if (cluster_batch_size.HasBatchSize() && maximum_batch_size.has_value() && + cluster_batch_size.GetBatchSize() > maximum_batch_size.value()) { + VLOG(2) << "OP batch size " << cluster_batch_size.GetBatchSize() + << " max_batch_size " << maximum_batch_size.value(); + return true; + } + return false; } } // namespace @@ -690,6 +709,10 @@ Status SegmentGraph(const Graph* tf_graph, "Explicit batch mode should allow dynamic non-batch dimensions"); } + if (options.use_implicit_batch && !options.maximum_batch_size.has_value()) { + return errors::Internal("Implicit batch mode requires maximum_batch_size"); + } + if (!options.allow_dynamic_non_batch_dim && !graph_properties) { return errors::Internal( "Need graph propertities to disallow dynamic non-batch dimensions"); @@ -768,6 +791,14 @@ Status SegmentGraph(const Graph* tf_graph, << "(Op type: " << node->tf_node()->type_string() << "), " << "(Op name: " << node->name() << ")"; exclude_node("Denylisted with the env var TF_TRT_OP_DENYLIST"); + } else if (OpBatchSizeExceedMaximumBatchSize( + graph_properties, node->tf_node(), + options.use_implicit_batch, options.maximum_batch_size)) { + LOG_WARNING_WITH_PREFIX + << "Implicit batch mode requires OP batch size not larger than " + << "the converter maximum batch size: " + << "(Op name: " << node->name() << ")"; + exclude_node("OP batch size too large"); } else { VLOG(2) << "Accepted as a TF-TRT candidate, " << "(Op type: " << node->tf_node()->type_string() << "), " @@ -819,9 +850,9 @@ Status SegmentGraph(const Graph* tf_graph, // step until no output edges can be further contracted. This is because // contracting an output edge may unblock new edges for contracting. 
    ClusterBatchSize expected_batch_size =
-        node_segments[node->id()].BatchSize();
+        node_segments[node->id()].Property().BatchSize();
     DeviceNameUtils::ParsedName expected_device_name =
-        node_segments[node->id()].DeviceName();
+        node_segments[node->id()].Property().DeviceName();
     VLOG(3) << "batch size " << expected_batch_size;
     while (true) {
       std::set contract_edges;
@@ -842,7 +873,7 @@ Status SegmentGraph(const Graph* tf_graph,
           continue;
         }
         // Out node must have compatible batch size.
-        ClusterBatchSize out_batch_size = out_cluster->BatchSize();
+        ClusterBatchSize out_batch_size = out_cluster->Property().BatchSize();
         ClusterBatchSize merged_batch_size = expected_batch_size;
         if (!merged_batch_size.MergeIfCompatible(out_batch_size)) {
           VLOG(3) << "... ... incompatible batch sizes "
@@ -852,7 +883,7 @@ Status SegmentGraph(const Graph* tf_graph,
         }
         const DeviceNameUtils::ParsedName& out_device_name =
-            out_cluster->DeviceName();
+            out_cluster->Property().DeviceName();
         absl::optional merged_device_name =
             MergeIfCompatible(expected_device_name, out_device_name);
         if (!merged_device_name.has_value()) {
@@ -898,11 +929,13 @@ Status SegmentGraph(const Graph* tf_graph,
       graph->RemoveEdge(r);
     }
   }
-  if (expected_batch_size != node_segments[node->id()].BatchSize()) {
+  if (expected_batch_size !=
+      node_segments[node->id()].Property().BatchSize()) {
     return errors::Internal(
         "expected batch size is not the same as the actual batch size");
   }
-  if (expected_device_name != node_segments[node->id()].DeviceName()) {
+  if (expected_device_name !=
+      node_segments[node->id()].Property().DeviceName()) {
     return errors::Internal(
         "expected device name is not the same as the actual device name");
   }
diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment.h b/tensorflow/compiler/tf2tensorrt/segment/segment.h
index 3f79983cfd2..bab6e089fa4 100644
--- a/tensorflow/compiler/tf2tensorrt/segment/segment.h
+++ b/tensorflow/compiler/tf2tensorrt/segment/segment.h
@@ -19,6 +19,7 @@ limitations under the License.
 #include
 #include
+#include "absl/types/optional.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/grappler/costs/graph_properties.h"
@@ -38,6 +39,9 @@ struct SegmentOptions {
   // Segment must contain at least this many nodes.
   int minimum_segment_size = 2;
   bool use_implicit_batch = true;
+  // The maximum batch size used to build the engines in the graph, when
+  // use_implicit_batch is true.
+  absl::optional maximum_batch_size = absl::nullopt;
   // When use_implicit_batch is false or when we are building dynamic engines,
   // we allow dynamic non-batch dimensions.
bool allow_dynamic_non_batch_dim = false; diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc b/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc index bf277328fe7..ee406c9743f 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc @@ -108,8 +108,9 @@ class SegmentTest : public ::testing::Test { segment_options_.allow_dynamic_non_batch_dim = true; } - void EnableImplicitBatchModeForStaticEngine() { + void EnableImplicitBatchModeForStaticEngine(int maximum_batch_size = 1000) { segment_options_.use_implicit_batch = true; + segment_options_.maximum_batch_size = maximum_batch_size; segment_options_.allow_dynamic_non_batch_dim = false; } @@ -487,7 +488,11 @@ TEST_F(SegmentTest, TwoChainsDiffBatchSizes) { const std::set all_nodes = {"const-scalar", "output-0", "output-1"}; EnableImplicitBatchModeForStaticEngine(); RunTest(&g, &static_graph_properties, all_nodes, all_nodes, all_nodes, - {{"output-0", "const-scalar"}}); + /*expected_segments=*/{{"output-0", "const-scalar"}}); + + EnableImplicitBatchModeForStaticEngine(1); + RunTest(&g, &static_graph_properties, all_nodes, all_nodes, all_nodes, + /*expected_segments=*/{}); } TEST_F(SegmentTest, SameRankImplicitBroadcastingStaticBatchSize) { diff --git a/tensorflow/compiler/tf2tensorrt/segment/union_find.cc b/tensorflow/compiler/tf2tensorrt/segment/union_find.cc new file mode 100644 index 00000000000..9aa7783b637 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/segment/union_find.cc @@ -0,0 +1,128 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2tensorrt/segment/union_find.h" + +#include "absl/strings/str_format.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/core/lib/core/errors.h" + +#if GOOGLE_CUDA && GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { +namespace segment { + +namespace { +template +inline bool CheckIfCompatible(const absl::optional& a, + const absl::optional& b) { + if (a.has_value() && b.has_value()) { + return *a == *b; + } + return true; +} + +template +inline bool UnifyValues(absl::optional& a, absl::optional& b) { + if (a.has_value()) { + b = a; + } else { + a = b; + } + return true; +} + +template +inline absl::optional MergeCompatible(const absl::optional& a, + const absl::optional& b) { + DCHECK(CheckIfCompatible(a, b)); + return a.has_value() ? 
a : b; +} + +} // namespace + +ClusterBatchSize::ClusterBatchSize() : batch_size_(absl::nullopt) {} + +bool ClusterBatchSize::operator==(const ClusterBatchSize& other) { + return batch_size_ == other.batch_size_; +} + +ClusterBatchSize& ClusterBatchSize::SetBatchSize(int batch_size) { + SetBatchSize(static_cast>(batch_size)); + return *this; +} + +ClusterBatchSize& ClusterBatchSize::SetBatchSize( + const absl::optional& batch_size) { + batch_size_ = MergeCompatible(batch_size_, batch_size); + return *this; +} + +bool ClusterBatchSize::HasBatchSize() const { return batch_size_.has_value(); } + +int ClusterBatchSize::GetBatchSize() const { + DCHECK(HasBatchSize()); + return batch_size_.value(); +} + +bool ClusterBatchSize::MergeIfCompatible(const ClusterBatchSize& other) { + if (!CheckIfCompatible(batch_size_, other.batch_size_)) { + return false; + } + SetBatchSize(other.batch_size_); + return true; +} + +string ClusterBatchSize::ToString() const { + string s; + absl::StrAppendFormat(&s, "batch_size=("); + if (HasBatchSize()) { + absl::StrAppendFormat(&s, "%d", GetBatchSize()); + } else { + absl::StrAppendFormat(&s, "?"); + } + absl::StrAppend(&s, ")"); + return s; +} + +ClusterProperty::ClusterProperty(const ClusterBatchSize& batch_size, + const DeviceNameUtils::ParsedName& device_name) + : batch_size_(batch_size), device_name_(device_name) {} + +Status ClusterProperty::Merge(const ClusterProperty& other) { + ClusterBatchSize merged_batch_size(batch_size_); + if (!merged_batch_size.MergeIfCompatible(other.batch_size_)) { + return errors::Internal( + "trying to merge clusters with incompatible batch sizes."); + } + + absl::optional merged_device_name = + MergeIfCompatible(device_name_, other.device_name_); + if (!merged_device_name.has_value()) { + return errors::Internal( + "trying to merge clusters with incompatible device assignment."); + } + + batch_size_ = std::move(merged_batch_size); + device_name_ = std::move(merged_device_name.value()); + return Status::OK(); +} + +} // namespace segment +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT diff --git a/tensorflow/compiler/tf2tensorrt/segment/union_find.h b/tensorflow/compiler/tf2tensorrt/segment/union_find.h index b91f5771ce5..c72ea1f7553 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/union_find.h +++ b/tensorflow/compiler/tf2tensorrt/segment/union_find.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_UNION_FIND_H_ #define TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_UNION_FIND_H_ -#include "absl/strings/str_format.h" #include "absl/types/optional.h" +#include "tensorflow/core/util/device_name_utils.h" #if GOOGLE_CUDA && GOOGLE_TENSORRT @@ -28,110 +28,67 @@ namespace segment { // ClusterBatchSize is a data structure to record the batch size we have seen // for a cluster during segmentation. // -// When constructing clusters for implicit batch mode, we support the -// with both dynamic batch size and static batch size. We restrict nodes inside -// a cluster to either have dynamic batch size or have the same value for static -// batch size. For this reason, we use a field has_dynamic_batch_value_ to keep -// track of whether the cluster has any node with dynamic batch size. We use -// field static_batch_value_ to keep track of whether the cluster has any node -// with static batch size and what the value of the static batch size, if any. 
-// Examples: +// With the help of shape inference, all the dynamic batch sizes are converted +// to a negative integer number. +// If the number is -1, then nothing is known about the dynamic batch size. +// Ideally, we should not put nodes with -1 batch size into the same cluster, +// as they will likely have different batch sizes at runtime. However, we +// currently treat -1 as an equivalent class for simple implementation. We may +// need to revise this if it causes performance issues. +// If the number is strictly less than -1, then it represents a equivalent +// class. It is infered that all the nodes with the same equivalent class +// (strictly less than -1) shall have the same batch size at runtime. +// +// When constructing clusters for implicit batch mode, we support both +// dynamic batch sizes and static batch sizes. As all the nodes inside the same +// cluster shall have the same batch size at runtime, we restrict nodes inside a +// cluster to either have the same dynamic batch size equivalent class or the +// same static batch size value. +// // cluster: a = a1[1,3] + a1[1,3] -// ClusterBatchSize: has_dynamic_batch_size_ = false -// static_batch_value_ = {has value, 1} +// ClusterBatchSize: batch_size_ = 1 // // cluster: b = b1[-1,3] + b2[-1, 3] -// ClusterBatchSize: has_dynamic_batch_size_ = true -// static_batch_value_ = {has no value} +// ClusterBatchSize: batch_size_ = -1 // -// cluster: a = a1[1,3] + a1[1,3]; b = b1[-1,3] + b2[-1, 3] -// ClusterBatchSize: has_dynamic_batch_size_ = true -// static_batch_value_ = {has value, 1} +// cluster: c = c1[-2,3] + c2[-2, 3] +// ClusterBatchSize: batch_size_ = -2 // // When constructing cluster for explicit batch mode, all ClusterBatchSize is // irrelevant. // -// -absl::optional static_batch_value_; + class ClusterBatchSize { public: - ClusterBatchSize() - : has_dynamic_batch_value_(false), static_batch_value_(absl::nullopt) {} + ClusterBatchSize(); - bool operator==(const ClusterBatchSize& b) { - return HasDynamicBatchValue() == b.HasDynamicBatchValue() && - static_batch_value_ == b.static_batch_value_; - } + bool operator==(const ClusterBatchSize& other); + bool operator!=(const ClusterBatchSize& other) { return !(*this == other); } - bool operator!=(const ClusterBatchSize& b) { return !(*this == b); } + // Sets the batch size assuming that the object doesn't have a batch size yet: + // A non-negative input representing a static batch size value. + // A negative input representing a dynamic batch size equivalent class. + ClusterBatchSize& SetBatchSize(int batch_size); + bool HasBatchSize() const; + int GetBatchSize() const; - int GetStaticBatchValue() const { - DCHECK(HasStaticBatchValue()); - return static_batch_value_.value(); - } + // Merge `other` into the current ClusterBatchSize if the two are not + // conflicting. Two ClusterBatchSizes are conflicting iff they both have a + // value and their values are different. + bool MergeIfCompatible(const ClusterBatchSize& other); - // Sets the batch size value assuming that the object doesn't have a batch - // size value yet: - // a non-negative input value representing a known batch size. - // a negative input value representing a dynamic batch size. 
- ClusterBatchSize SetBatchSizeValue(int value) { - if (value < 0) { - has_dynamic_batch_value_ = true; - return *this; - } - static_batch_value_ = value; - return *this; - } - - bool MergeIfCompatible(const ClusterBatchSize& b) { - bool is_compatible = MergeIfCompatible(b.static_batch_value_); - if (!is_compatible) return false; - - if (!HasDynamicBatchValue() && b.HasDynamicBatchValue()) { - has_dynamic_batch_value_ = true; - } - - return true; - } - - // Returns a string for the batch size value. If the object has a static - // batch size value, return a string for the value. If the object has a - // dynamic size value, return -1. Otherwise, returns -2 to represent that - // a batch size hasn't been set yet. - string ToString() const { - string s; - absl::StrAppendFormat(&s, "batch_size=(%d,%d,", HasDynamicBatchValue(), - HasStaticBatchValue()); - if (HasStaticBatchValue()) { - absl::StrAppendFormat(&s, "%d", GetStaticBatchValue()); - } - absl::StrAppend(&s, ")"); - return s; - } + // Returns a string for the batch size. + // If the object has a static batch size, return a string representing a + // non-negative integer. + // If the object has a dynamic batch size, return a string representing a + // negative integer as an equivalent class. + // If the object doesn't have a batch size yet, return a "?" symbol string. + std::string ToString() const; private: - bool HasStaticBatchValue() const { return static_batch_value_.has_value(); } - bool HasDynamicBatchValue() const { return has_dynamic_batch_value_; } + ClusterBatchSize& SetBatchSize(const absl::optional& batch_size); - private: - bool MergeIfCompatible(const absl::optional& b) { - bool is_compatible = !HasStaticBatchValue() || !b.has_value() || - GetStaticBatchValue() == b.value(); - if (!is_compatible) { - return false; - } - if (!HasStaticBatchValue() && b.has_value()) { - static_batch_value_ = b; - } - return true; - } - - private: - // To track whether the cluster has any node with dynamic batch size. - bool has_dynamic_batch_value_; - // To track whether the cluster has any node with static batch size, and the - // unique value for static batch size. - absl::optional static_batch_value_; + absl::optional batch_size_; }; inline std::ostream& operator<<(std::ostream& os, @@ -139,89 +96,89 @@ inline std::ostream& operator<<(std::ostream& os, return os << batch_size.ToString(); } -// Represents a disjoint set of copyable values with type T. We use this data -// structure to construct clusters for TRTEngineOp. As such, this data structure -// has a field to record the batch size for the current cluster and merges the -// corresponding batch sizes when merging two clusters. Most of the methods in -// this class are side-effecting as they also compress the path from the object -// to the parent of its containing set. -template -class UnionFind { +// Represents the accumulated properties of a cluster during segmentation, +// including information about batch size and device assignment. Clusters shall +// have compatible properties in order to be merged together. +class ClusterProperty { public: - UnionFind() : size_(1), parent_(nullptr) {} - UnionFind(const T& v, ClusterBatchSize batch_size, - const DeviceNameUtils::ParsedName& device_name) - : size_(1), - cluster_batch_size_(batch_size), - cluster_device_name_(device_name), - parent_(nullptr), - value_(v) {} - - // Returns the number of elements in the cluster and compresses the path from - // this object to the root of the cluster. 
- int Size() { return FindRoot()->size_; } + ClusterProperty() {} + ClusterProperty(const ClusterBatchSize& batch_size, + const DeviceNameUtils::ParsedName& device_name); // Returns the batch size of the cluster and compresses the path from this // object to the root object. - ClusterBatchSize BatchSize() { return FindRoot()->cluster_batch_size_; } + const ClusterBatchSize& BatchSize() const { return batch_size_; } // Returns the device name of the cluster and compresses the path from this // object to the root object. - const DeviceNameUtils::ParsedName& DeviceName() { - return FindRoot()->cluster_device_name_; - } + const DeviceNameUtils::ParsedName& DeviceName() const { return device_name_; } - // Merges this cluster with 'other'. This cluster's size_ is updated to - // the size of the merged cluster; the size_ of 'other' becomes inaccessible - // as only the size_ of the root object is accessible. - Status Merge(UnionFind* other); - - // Retrieves the value for the root of the cluster. - T& ParentValue() { return FindRoot()->value_; } - - // Returns the value for the object. - T& Value() { return value_; } + Status Merge(const ClusterProperty& other); private: - // Returns the root object for the cluster and compresses the path from this + ClusterBatchSize batch_size_; + DeviceNameUtils::ParsedName device_name_; +}; + +// Represents a disjoint set of copyable value with type T and accumulated +// property of the values with type P. Most of the methods in this class are +// side-effecting as they also compress the path from the object to the parent +// of its containing set. +template +class UnionFind { + public: + UnionFind() : size_(1), parent_(nullptr) {} + UnionFind(const T& v, const P& p) + : size_(1), parent_(nullptr), value_(v), property_(p) {} + UnionFind(const T& v, P&& p) + : size_(1), parent_(nullptr), value_(v), property_(p) {} + + // Returns the number of elements in the set and compresses the path from + // this object to the root of the set. + int Size() { return FindRoot()->size_; } + + // Returns the accumulated property of all the elements in the set and + // compresses the path from this object to the root of the set. + const P& Property() { return FindRoot()->property_; } + + // Merges this set with 'other'. This updates the size_ and property_ of the + // set. The size_ and property_ of 'other' becomes inaccessible as only the + // size_ and property_ of the root of the set is accessible. + Status Merge(UnionFind* other); + + // Retrieves the value for the root of the set. + const T& ParentValue() { return FindRoot()->value_; } + + // Returns the value for the object. + const T& Value() const { return value_; } + + private: + // Returns the root object for the set and compresses the path from this // object to the root object. 
UnionFind* FindRoot(); int size_; - ClusterBatchSize cluster_batch_size_; - DeviceNameUtils::ParsedName cluster_device_name_; UnionFind* parent_; T value_; + P property_; }; -template -Status UnionFind::Merge(UnionFind* other) { +template +Status UnionFind::Merge(UnionFind* other) { UnionFind* a = FindRoot(); UnionFind* b = other->FindRoot(); if (a == b) return Status::OK(); - ClusterBatchSize batch_size = a->cluster_batch_size_; - if (!batch_size.MergeIfCompatible(other->cluster_batch_size_)) { - return errors::Internal( - "trying to merge clusters with incompatible batch sizes."); - } - - absl::optional device_name = - MergeIfCompatible(a->cluster_device_name_, other->cluster_device_name_); - if (!device_name.has_value()) { - return errors::Internal( - "trying to merge clusters with incompatible device assignment."); - } - - a->cluster_batch_size_ = batch_size; - a->cluster_device_name_ = *device_name; + P merged_property(a->property_); + TF_RETURN_IF_ERROR(merged_property.Merge(b->property_)); b->parent_ = a; a->size_ += b->size_; + a->property_ = std::move(merged_property); return Status::OK(); } -template -UnionFind* UnionFind::FindRoot() { +template +UnionFind* UnionFind::FindRoot() { if (!parent_) return this; // Path compression: update intermediate nodes to point to the root of the // equivalence class. diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index e9bcbcc6d83..5641339e7ef 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -1,4 +1,5 @@ -load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_cuda_cc_test", "tf_openmp_copts") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") +load("//tensorflow:tensorflow.bzl", "if_libtpu", "tf_cc_binary", "tf_cc_test", "tf_copts", "tf_cuda_cc_test", "tf_openmp_copts") load( "//tensorflow/core/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", @@ -7,10 +8,11 @@ load( "//tensorflow/core/platform:build_config.bzl", "tf_additional_tensor_coding_deps", "tf_proto_library", - "tf_proto_library_cc", ) load("//tensorflow/compiler/xla:xla.bzl", "xla_py_proto_library") -load("//tensorflow:tensorflow.bzl", "tf_portable_proto_library") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "filegroup") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") load("//tensorflow/compiler/xla/service/cpu:build_defs.bzl", "runtime_copts") @@ -40,6 +42,7 @@ package_group( "//tensorflow/...", "//tensorflow_models/...", "//third_party/mlperf/submissions/training/v0_7/models/...", + "//third_party/py/keras/...", ], ) @@ -78,19 +81,6 @@ tf_proto_library( visibility = ["//visibility:public"], ) -# A proto library that is minimal in size and dependencies for platforms like Android. 
-tf_portable_proto_library( - name = "portable_tf2xla_proto", - config_string = "allow_all:true", - header_outs = ["//tensorflow/compiler/tf2xla/tf2xla.proto.h"], - portable_deps = ["//tensorflow/core:portable_proto_lib"], - proto_deps = [ - ":tf2xla_proto", - "//tensorflow/core:protos_all", - ], - visibility = ["//visibility:public"], -) - xla_py_proto_library( name = "tf2xla_py", has_services = False, @@ -99,7 +89,7 @@ xla_py_proto_library( deps = [":tf2xla_proto"], ) -tf_proto_library_cc( +tf_proto_library( name = "host_compute_metadata_proto", srcs = ["host_compute_metadata.proto"], cc_api_version = 2, @@ -258,6 +248,7 @@ cc_library( "@com_google_absl//absl/synchronization", "//third_party/eigen3", "//tensorflow/core/framework:numeric_types", + "//tensorflow/core/platform:bfloat16", ] + tf_additional_tensor_coding_deps(), alwayslink = 1, ) @@ -303,14 +294,18 @@ cc_library( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/compiler/xla/service:platform_util", - "//tensorflow/compiler/xla/service/cpu:buffer_info_util", - "//tensorflow/compiler/xla/service/cpu:cpu_executable", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/stream_executor:platform", - ], + ] + if_libtpu( + if_false = [ + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/service/cpu:buffer_info_util", + "//tensorflow/compiler/xla/service/cpu:cpu_executable", + ], + if_true = [], + ), ) cc_library( @@ -334,6 +329,7 @@ cc_library( "xla_op_kernel.h", "xla_op_registry.h", ], + copts = tf_copts(), visibility = [":friends"], deps = [ ":common", @@ -349,10 +345,13 @@ cc_library( ":xla_helpers", ":xla_op_registry", ":xla_resource", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", "//tensorflow/compiler/jit:common", "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/jit:shape_inference", - "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -370,11 +369,13 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/types:span", - "@com_google_absl//absl/types:variant", - ], + ] + if_libtpu( + if_false = [ + "//tensorflow/compiler/mlir:array_container_utils", + "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes", + ], + if_true = [], + ), alwayslink = 1, ) @@ -448,8 +449,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:session_options", - "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/common_runtime:core_cpu_internal", + "//tensorflow/core/platform:stream_executor_no_cuda", ], alwayslink = 1, ) @@ -741,10 +742,10 @@ tf_cc_test( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:tensor_testutil", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "//tensorflow/core/framework:tensor_testutil", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], @@ -818,9 +819,9 @@ cc_library( 
":frontend_attributes_util", ":functionalize_control_flow_util", ":tf2xla_util", - "//tensorflow/compiler/jit:union_find", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:union_find", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -846,9 +847,9 @@ cc_library( ":functionalize_control_flow_util", ":functionalize_while", ":tf2xla_util", - "//tensorflow/compiler/jit:union_find", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:union_find", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -876,14 +877,20 @@ cc_library( cc_library( name = "mlir_bridge_pass_registration", - srcs = [ - "mlir_bridge_pass_registration.cc", - ], - deps = [ - ":mlir_bridge_pass", - "//tensorflow/compiler/mlir:mlir_graph_optimization_pass_registration", - "//tensorflow/core:core_cpu", - ], + srcs = if_libtpu( + if_false = [ + "mlir_bridge_pass_registration.cc", + ], + if_true = [], + ), + deps = if_libtpu( + if_false = [ + ":mlir_bridge_pass", + "//tensorflow/compiler/mlir:mlir_graph_optimization_pass_registration", + "//tensorflow/core:core_cpu", + ], + if_true = [], + ), alwayslink = 1, ) @@ -934,9 +941,9 @@ cc_library( ":functionalize_cond", ":functionalize_control_flow_util", ":tf2xla_util", - "//tensorflow/compiler/jit:union_find", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:union_find", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", diff --git a/tensorflow/compiler/tf2xla/cc/BUILD b/tensorflow/compiler/tf2xla/cc/BUILD index 973aafe1ad8..45a099baabe 100644 --- a/tensorflow/compiler/tf2xla/cc/BUILD +++ b/tensorflow/compiler/tf2xla/cc/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_cc") package( diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index 54abccb4cfc..452b102fade 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -25,9 +25,10 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/str_join.h" #include "absl/types/optional.h" -#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/frontend_attributes_util.h" +#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/graph_to_functiondef.h" diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 596fa8e8e38..2a3e35e0ffd 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -23,12 +23,12 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "absl/types/optional.h" -#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/functionalize_cond.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" #include "tensorflow/compiler/tf2xla/functionalize_while.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc index dce5efe5557..79412c4abc8 100644 --- a/tensorflow/compiler/tf2xla/functionalize_while.cc +++ b/tensorflow/compiler/tf2xla/functionalize_while.cc @@ -24,11 +24,11 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/match.h" #include "absl/types/optional.h" -#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/frontend_attributes_util.h" #include "tensorflow/compiler/tf2xla/functionalize_cond.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 26051c98cb7..7e1878682f2 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_kernel_library") package( @@ -108,6 +109,7 @@ tf_kernel_library( "stack_ops.cc", "stateful_random_ops.cc", "stateless_random_ops.cc", + "stateless_random_ops_v2.cc", "strided_slice_op.cc", "tensor_array_ops.cc", "tensor_list_ops.cc", @@ -187,6 +189,7 @@ tf_kernel_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:stateful_random_ops_header", + "//tensorflow/core/kernels:stateless_random_ops_v2_header", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 88d7525e5d5..39f4beed0f4 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -290,6 +290,21 @@ xla::XlaOp IgammacImpl(xla::XlaOp x, xla::XlaOp y, XLA_MAKE_BINARY(Igammac, IgammacImpl(lhs, rhs, broadcast_helper)); +xla::XlaOp PolygammaImpl(xla::XlaOp n, xla::XlaOp x, + const BCast& broadcast_helper) { + std::tie(n, x) = XlaBinaryOp::Broadcast(n, x, broadcast_helper); + return xla::Polygamma(n, x); +} + +XLA_MAKE_BINARY(Polygamma, PolygammaImpl(lhs, rhs, broadcast_helper)); + +xla::XlaOp ZetaImpl(xla::XlaOp x, xla::XlaOp q, const BCast& broadcast_helper) { + std::tie(x, q) = XlaBinaryOp::Broadcast(x, q, broadcast_helper); + return xla::Zeta(x, q); +} + +XLA_MAKE_BINARY(Zeta, ZetaImpl(lhs, rhs, broadcast_helper)); + #undef XLA_MAKE_BINARY class ApproximateEqualOp : public XlaOpKernel { diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc 
b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index 7e8d3d7002a..b461aa43153 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -186,7 +186,7 @@ class StatelessCategoricalOp : public CategoricalOp { REGISTER_XLA_OP(Name("StatelessMultinomial") .CompileTimeConstantInput("num_samples") - .TypeConstraint("T", {DT_FLOAT, DT_BFLOAT16}) + .TypeConstraint("T", {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}) .TypeConstraint("Tseed", DT_INT32), StatelessCategoricalOp); diff --git a/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc b/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc index c1f60abc0d6..a62d15f7904 100644 --- a/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc @@ -35,15 +35,19 @@ class DataFormatDimMapOp : public XlaOpKernel { OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format)); string dst_format; OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format)); - OP_REQUIRES(context, src_format.size() == 4, - errors::InvalidArgument(absl::StrCat( - "Source format must of length 4, received src_format = ", - src_format))); + OP_REQUIRES(context, src_format.size() == 4 or src_format.size() == 5, + errors::InvalidArgument( + absl::StrCat("Source format must of length 4 or 5, " + "received src_format = ", + src_format))); OP_REQUIRES( - context, dst_format.size() == 4, + context, dst_format.size() == 4 or dst_format.size() == 5, errors::InvalidArgument(absl::StrCat( - "Destination format must of length 4, received dst_format = ", + "Destination format must of length 4 or 5, received dst_format = ", dst_format))); + for (int i = 0; i < src_format.size(); ++i) { + dst_idx_.push_back(-1); + } for (int i = 0; i < src_format.size(); ++i) { for (int j = 0; j < dst_format.size(); ++j) { if (dst_format[j] == src_format[i]) { @@ -61,9 +65,10 @@ class DataFormatDimMapOp : public XlaOpKernel { auto builder = context->builder(); xla::XlaOp dst_indices = xla::ConstantR1(builder, absl::Span(dst_idx_)); - xla::XlaOp four = xla::ConstantR0(builder, 4); + const int dims = dst_idx_.size(); + xla::XlaOp rank = xla::ConstantR0(builder, dims); xla::XlaOp src_indices = - (xla::ConvertElementType(context->Input(0), xla::S32) + four) % four; + (xla::ConvertElementType(context->Input(0), xla::S32) + rank) % rank; xla::XlaOp output = xla::TorchIndexSelect(dst_indices, src_indices, /*dim=*/0); context->SetOutput( @@ -71,7 +76,7 @@ class DataFormatDimMapOp : public XlaOpKernel { } private: - std::array dst_idx_ = {{-1, -1, -1, -1}}; + std::vector dst_idx_; TF_DISALLOW_COPY_AND_ASSIGN(DataFormatDimMapOp); }; diff --git a/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc index 7ac38369eb4..ad94c1383f8 100644 --- a/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc @@ -63,36 +63,27 @@ class DequantizeOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { DataType input_type = ctx->input_type(0); - double minrange, maxrange; - - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsFloatScalar(1, &minrange)); - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsFloatScalar(2, &maxrange)); - - float min_range = static_cast(minrange); - float max_range = static_cast(maxrange); - float full_range, half_range; + xla::XlaOp input = ctx->Input(0); + xla::XlaOp output = xla::ConvertElementType(input, xla::F32); + xla::XlaOp min_range = 
xla::ConvertElementType(ctx->Input(1), xla::F32); + xla::XlaOp max_range = xla::ConvertElementType(ctx->Input(2), xla::F32); + xla::XlaOp full_range; + xla::XlaOp half_range; if (input_type == DT_QINT8) { - full_range = get_fullrange(); - half_range = (full_range + 1.0f) / 2.0f; + full_range = ScalarLike(output, get_fullrange()); + half_range = + (full_range + ScalarLike(output, 1.0f)) / ScalarLike(output, 2.0f); } else { OP_REQUIRES(ctx, input_type == DT_QUINT8, errors::InvalidArgument( "Only support DT_QINT8 or DT_QUINT8, got ", input_type)); - full_range = get_fullrange(); - half_range = 0.0f; + full_range = ScalarLike(output, get_fullrange()); + half_range = ScalarLike(output, 0.0f); } - float scale_factor = (max_range - min_range) / full_range; + xla::XlaOp scale = (max_range - min_range) / full_range; - xla::XlaOp input = ctx->Input(0); - xla::XlaOp output; - - output = xla::ConvertElementType(input, xla::F32); - - auto scale = ScalarLike(output, scale_factor); - auto halfrange = ScalarLike(output, half_range); - output = xla::Add(xla::Mul(xla::Add(output, halfrange), scale), - ScalarLike(output, min_range)); + output = xla::Add(xla::Mul(xla::Add(output, half_range), scale), min_range); if (dtype_ == DT_BFLOAT16) { output = xla::ConvertElementType(output, xla::BF16); diff --git a/tensorflow/compiler/tf2xla/kernels/qr_op.cc b/tensorflow/compiler/tf2xla/kernels/qr_op.cc index 66ec40a946b..7aebb76071f 100644 --- a/tensorflow/compiler/tf2xla/kernels/qr_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/qr_op.cc @@ -41,7 +41,7 @@ class QROp : public XlaOpKernel { bool full_matrices_; }; -REGISTER_XLA_OP(Name("Qr").TypeConstraint("T", kFloatTypes), QROp); +REGISTER_XLA_OP(Name("Qr"), QROp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc index ce4a46b45c8..1b470bf58df 100644 --- a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc @@ -182,6 +182,32 @@ class TensorScatterAddOp : public XlaOpKernel { } }; +class TensorScatterMaxOp : public XlaOpKernel { + public: + explicit TensorScatterMaxOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + CompileTensorScatter(context, + [](xla::XlaOp x, xla::XlaOp y, xla::XlaBuilder*) { + return xla::Max(x, y); + }); + } +}; + +class TensorScatterMinOp : public XlaOpKernel { + public: + explicit TensorScatterMinOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + CompileTensorScatter(context, + [](xla::XlaOp x, xla::XlaOp y, xla::XlaBuilder*) { + return xla::Min(x, y); + }); + } +}; + class TensorScatterSubOp : public XlaOpKernel { public: explicit TensorScatterSubOp(OpKernelConstruction* context) @@ -207,6 +233,8 @@ class TensorScatterUpdateOp : public XlaOpKernel { }; REGISTER_XLA_OP(Name("TensorScatterAdd"), TensorScatterAddOp); +REGISTER_XLA_OP(Name("TensorScatterMax"), TensorScatterMaxOp); +REGISTER_XLA_OP(Name("TensorScatterMin"), TensorScatterMinOp); REGISTER_XLA_OP(Name("TensorScatterSub"), TensorScatterSubOp); REGISTER_XLA_OP(Name("TensorScatterUpdate"), TensorScatterUpdateOp); diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 85917af6a65..75faa2eac81 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -15,6 +15,7 @@ 
limitations under the License. // XLA-specific Shape Ops. +#include "absl/strings/str_format.h" #include "tensorflow/compiler/tf2xla/kernels/shape_util.h" #include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h" #include "tensorflow/compiler/tf2xla/shape_util.h" @@ -24,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -65,6 +67,47 @@ class ShapeOp : public XlaOpKernel { REGISTER_XLA_OP(Name("Shape").CompilationOnly().IsMetadataOp(), ShapeOp); +class XlaSetBoundOp : public XlaOpKernel { + public: + explicit XlaSetBoundOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape("input"); + const TensorShape bound_shape = ctx->InputShape("bound"); + + OP_REQUIRES( + ctx, + ctx->InputType("bound") == DT_INT32 && + ctx->InputType("input") == DT_INT32, + errors::InvalidArgument( + "XlaSetBound can only set bound for int32 scalar value: got", + input_shape.DebugString())); + + OP_REQUIRES( + ctx, input_shape.dims() == 0, + errors::InvalidArgument("XlaSetBound should only be used to set a " + "bound to the an int32 scalar value: got", + input_shape.DebugString())); + + OP_REQUIRES( + ctx, bound_shape.dims() == 0, + errors::InvalidArgument("XlaSetBound should only be used to set a " + "bound to the an int32 scalar value: got", + bound_shape.DebugString())); + int64 bound; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar("bound", &bound)); + + xla::XlaOp result = xla::CustomCall( + ctx->builder(), "SetBound", {ctx->Input("input")}, + ctx->InputXlaShape("input").ValueOrDie(), absl::StrFormat("%d", bound)); + ctx->SetOutput(0, result); + } +}; + +REGISTER_XLA_OP(Name("XlaSetBound").CompileTimeConstantInput("bound"), + XlaSetBoundOp); + class ShapeNOp : public XlaOpKernel { public: explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { diff --git a/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc index 46d4b70606e..a46cceddced 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc @@ -30,6 +30,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/rng_alg.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/math/math_util.h" @@ -180,7 +181,7 @@ Status CompileImpl( } xla::Literal alg_literal; TF_RETURN_IF_ERROR(ctx->ConstantInput(alg_input_idx, &alg_literal)); - auto alg = alg_literal.Get({}); + Algorithm alg = Algorithm(alg_literal.Get({})); if (!(alg == RNG_ALG_THREEFRY || alg == RNG_ALG_PHILOX)) { return errors::InvalidArgument("Unsupported algorithm id: ", alg); } @@ -407,5 +408,80 @@ REGISTER_XLA_OP(Name("StatefulUniformFullInt") {DT_INT32, DT_UINT32, DT_INT64, DT_UINT64}), StatefulUniformFullIntOp); +xla::XlaOp IncreaseCounter(Algorithm const& alg, xla::XlaOp counter, + xla::XlaOp delta) { + // Multiplying 256 to be consistent with the CPU/GPU kernels + delta = delta * ConstantR0WithType(delta.builder(), xla::U64, 256); + if (alg == RNG_ALG_PHILOX) { + return xla::PhiloxIncreaseCounter(counter, delta); + } else { + return counter + delta; + } +} + +xla::XlaOp PadRight(xla::XlaOp a, int n) { + return xla::Pad(a, xla::ScalarLike(a, 0), + xla::MakeEdgePaddingConfig({{0, n}})); +} + +template +class RngSkipOp : public XlaOpKernel { + public: + explicit RngSkipOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const int state_input_idx = 0; + const int alg_input_idx = 1; + const int delta_input_idx = 2; + xla::XlaOp var; + TensorShape var_shape; + OP_REQUIRES_OK(ctx, + ctx->ReadVariableInput(state_input_idx, STATE_ELEMENT_DTYPE, + &var_shape, &var)); + xla::Literal alg_literal; + OP_REQUIRES_OK(ctx, ctx->ConstantInput(alg_input_idx, &alg_literal)); + Algorithm alg = Algorithm(alg_literal.Get({})); + OP_REQUIRES(ctx, alg == RNG_ALG_THREEFRY || alg == RNG_ALG_PHILOX, + errors::InvalidArgument("Unsupported algorithm id: ", alg)); + OP_REQUIRES_OK(ctx, CheckStateShape(alg, var_shape)); + if (read_old_value) { + auto counter_size = GetCounterSize(alg); + xla::XlaOp output = var; + if (RNG_MAX_COUNTER_SIZE > counter_size) { + // Because the size of `var` depends on the algorithm while we want the + // output to have a fixed size (to help shape inference), we fix the + // output size to be the maximal state size among algorithms, and right- + // pad it with zeros if var's size is smaller than that. 
+ output = PadRight(output, RNG_MAX_COUNTER_SIZE - counter_size); + } + ctx->SetOutput(0, output); + } + xla::XlaOp counter; + xla::XlaOp key; + std::tie(counter, key) = StateAndKeyFromVariable(alg, var); + xla::XlaOp delta = ctx->Input(delta_input_idx); + delta = BitcastConvertType(delta, xla::U64); + auto new_counter = IncreaseCounter(alg, counter, delta); + var = StateAndKeyToVariable(alg, new_counter, key); + xla::PrimitiveType state_element_type; + OP_REQUIRES_OK( + ctx, DataTypeToPrimitiveType(STATE_ELEMENT_DTYPE, &state_element_type)); + var = BitcastConvertType(var, state_element_type); + OP_REQUIRES_OK( + ctx, ctx->AssignVariable(state_input_idx, STATE_ELEMENT_DTYPE, var)); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(RngSkipOp); +}; + +REGISTER_XLA_OP(Name("RngSkip").CompileTimeConstantInput("algorithm"), + RngSkipOp<>); + +using RngReadAndSkipOp = RngSkipOp; + +REGISTER_XLA_OP(Name("RngReadAndSkip").CompileTimeConstantInput("alg"), + RngReadAndSkipOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index 13c3dbe489e..e606812bc4e 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -111,6 +111,8 @@ xla::XlaOp StatelessRngUniform(absl::string_view device_type_string, } } +namespace { + xla::XlaOp StatelessRngUniformFullInt(absl::string_view device_type_string, xla::XlaOp seeds, const xla::Shape& shape) { @@ -140,8 +142,6 @@ xla::XlaOp StatelessRngUniformFullInt(absl::string_view device_type_string, } } -namespace { - class StatelessRandomUniformOp : public XlaOpKernel { public: explicit StatelessRandomUniformOp(OpKernelConstruction* ctx) diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc new file mode 100644 index 00000000000..e46fec3c576 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc @@ -0,0 +1,485 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/kernels/stateless_random_ops_v2.h" + +#include + +#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h" +#include "tensorflow/compiler/tf2xla/lib/random.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/prng.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/rng_alg.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/math/math_util.h" + +namespace tensorflow { + +namespace { + +inline xla::RandomAlgorithm AlgorithmToRandomAlgorithm(Algorithm const& alg) { + if (alg == RNG_ALG_PHILOX) { + return xla::RandomAlgorithm::RNG_PHILOX; + } + return xla::RandomAlgorithm::RNG_THREE_FRY; +} + +inline Algorithm RandomAlgorithmToAlgorithm(xla::RandomAlgorithm const& alg) { + if (alg == xla::RandomAlgorithm::RNG_PHILOX) { + return RNG_ALG_PHILOX; + } + return RNG_ALG_THREEFRY; +} + +xla::XlaOp GetCounter(xla::RandomAlgorithm const& alg, xla::XlaOp state) { + Algorithm alg_ = RandomAlgorithmToAlgorithm(alg); + return xla::Slice(state, {RNG_KEY_SIZE}, + {RNG_KEY_SIZE + GetCounterSize(alg_)}, {1}); +} + +xla::RngOutput BitGenerator(xla::RandomAlgorithm const& alg, xla::XlaOp key, + xla::XlaOp counter, const xla::Shape& shape) { + key = BitcastConvertType(key, xla::U64); + counter = BitcastConvertType(counter, xla::U64); + xla::XlaOp state = xla::ConcatInDim(key.builder(), {key, counter}, 0); + xla::XlaOp result = xla::RngBitGenerator(alg, state, shape); + auto new_counter = GetCounter(alg, xla::GetTupleElement(result, 0)); + new_counter = BitcastConvertType(new_counter, xla::S64); + return xla::RngOutput{/*value=*/xla::GetTupleElement(result, 1), + /*state=*/new_counter}; +} + +std::tuple GetKeyCounterAlg( + absl::string_view device_type_string, xla::XlaOp key) { + // The Philox algorithm may cause performance regression on other devices. + // Turn on the Philox algorithm for the CPU and GPU backends only. 
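(Not part of the patch.) A plain-C++ sketch of the device-based selection described by the comment above: Philox on the CPU/GPU backends, ThreeFry elsewhere. The device strings and enum values are assumptions for illustration; the real helper also derives and returns the key and counter as XLA ops.

```cpp
// Illustrative only: choose the RNG algorithm from the backend, mirroring the
// comment above (Philox on CPU/GPU, ThreeFry elsewhere, e.g. TPU).
#include <iostream>
#include <string>

enum Algorithm { kPhilox = 1, kThreeFry = 2 };  // Ids here are illustrative.

Algorithm ChooseAlgorithm(const std::string& device_type) {
  if (device_type == "XLA_CPU_JIT" || device_type == "XLA_GPU_JIT") {
    return kPhilox;   // Fast path: scrambled Philox key.
  }
  return kThreeFry;   // Conservative default with a fixed-size zero counter.
}

int main() {
  std::cout << ChooseAlgorithm("XLA_GPU_JIT") << "\n";  // 1
  std::cout << ChooseAlgorithm("XLA_TPU_JIT") << "\n";  // 2
}
```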
+ if (device_type_string == DEVICE_GPU_XLA_JIT || + device_type_string == DEVICE_CPU_XLA_JIT) { + auto counter_key = xla::ScramblePhiloxKey(key); + return std::make_tuple(counter_key.second, counter_key.first, + RNG_ALG_PHILOX); + } else { + auto counter_shape = + xla::ShapeUtil::MakeShape(xla::U64, {RNG_MAX_COUNTER_SIZE}); + auto counter = xla::Zeros(key.builder(), counter_shape); + return std::make_tuple(key, counter, RNG_ALG_THREEFRY); + } +} + +} // namespace + +xla::RngOutput StatelessRngUniformV2(xla::RandomAlgorithm const& alg, + xla::XlaOp key, xla::XlaOp counter, + const xla::Shape& shape, xla::XlaOp minval, + xla::XlaOp maxval) { + xla::XlaBuilder* builder = key.builder(); + xla::PrimitiveType type = shape.element_type(); + using std::placeholders::_1; + using std::placeholders::_2; + using std::placeholders::_3; + auto generator = std::bind(BitGenerator, alg, _1, _2, _3); + switch (type) { + case xla::F32: + case xla::F64: + return xla::UniformFloatingPointDistribution(key, counter, generator, + minval, maxval, shape); + case xla::S32: + case xla::S64: + case xla::U32: + case xla::U64: + return UniformIntDistribution(key, counter, generator, minval, maxval, + shape); + break; + default: + return {builder->ReportError(xla::Unimplemented( + "Types other than F32, S32, S64, U32 and U64 are not " + "implemented by " + "StatelessRngUniformV2; got %s", + xla::primitive_util::LowercasePrimitiveTypeName(type))), + counter}; + } +} + +namespace { + +xla::RngOutput StatelessRngUniformFullInt(xla::RandomAlgorithm const& alg, + xla::XlaOp key, xla::XlaOp counter, + const xla::Shape& shape) { + xla::XlaBuilder* builder = key.builder(); + + xla::PrimitiveType type = shape.element_type(); + xla::RngOutput output = BitGenerator(alg, key, counter, shape); + switch (type) { + case xla::U32: + case xla::U64: + return output; + case xla::S32: + case xla::S64: + return xla::RngOutput{BitcastConvertType(output.value, type), + output.state}; + default: + return { + builder->ReportError(xla::Unimplemented( + "Types other than U32, S32, U64 and S64 are not implemented by " + "StatelessRngUniformFullInt; got: %s", + xla::primitive_util::LowercasePrimitiveTypeName(type))), + output.state}; + } +} + +Status GetAlgorithm(XlaOpKernelContext* ctx, int alg_input_idx, + xla::RandomAlgorithm* alg) { + auto alg_shape = ctx->InputShape(alg_input_idx); + if (alg_shape.dims() != 0) { + return errors::InvalidArgument("algorithm must be of shape [], not ", + alg_shape.DebugString()); + } + xla::Literal alg_literal; + TF_RETURN_IF_ERROR(ctx->ConstantInput(alg_input_idx, &alg_literal)); + auto alg_ = Algorithm(alg_literal.Get({})); + *alg = AlgorithmToRandomAlgorithm(alg_); + return Status::OK(); +} + +xla::XlaOp MaybeSliceCounter(xla::RandomAlgorithm const& alg, + TensorShape const& counter_shape, + xla::XlaOp counter) { + auto input_counter_size = counter_shape.dim_size(0); + auto real_counter_size = GetCounterSize(RandomAlgorithmToAlgorithm(alg)); + if (input_counter_size > real_counter_size) { + counter = xla::Slice(counter, {0}, {real_counter_size}, {1}); + } + return counter; +} + +class StatelessRandomUniformOp : public XlaOpKernel { + public: + explicit StatelessRandomUniformOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* builder = ctx->builder(); + + TensorShape shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); + + const int key_input_idx = 1; + const int 
counter_input_idx = 2; + const int alg_input_idx = 3; + xla::XlaOp key = ctx->Input(key_input_idx); + xla::XlaOp counter = ctx->Input(counter_input_idx); + + xla::RandomAlgorithm alg; + OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg)); + + auto counter_shape = ctx->InputShape(counter_input_idx); + OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg), + ctx->InputShape(key_input_idx), + counter_shape)); + + DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT; + xla::Shape xla_shape; + OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape)); + xla::PrimitiveType rng_primitive_type = xla_shape.element_type(); + + counter = MaybeSliceCounter(alg, counter_shape, counter); + + auto result = StatelessRngUniformV2( + alg, key, counter, xla_shape, + xla::ConstantR0WithType(builder, rng_primitive_type, 0.0), + xla::ConstantR0WithType(builder, rng_primitive_type, 1.0)); + auto uniform = MaybeConvertF32ToBF16(result.value, dtype_); + ctx->SetOutput(0, uniform); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformOp); +}; + +REGISTER_XLA_OP(Name("StatelessRandomUniformV2") + .CompileTimeConstantInput("shape") + .CompileTimeConstantInput("alg") + .TypeConstraint("dtype", + {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}), + StatelessRandomUniformOp); + +class StatelessRandomUniformIntOp : public XlaOpKernel { + public: + explicit StatelessRandomUniformIntOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); + + const int key_input_idx = 1; + const int counter_input_idx = 2; + const int alg_input_idx = 3; + xla::XlaOp key = ctx->Input(key_input_idx); + xla::XlaOp counter = ctx->Input(counter_input_idx); + + xla::RandomAlgorithm alg; + OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg)); + + auto counter_shape = ctx->InputShape(counter_input_idx); + OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg), + ctx->InputShape(key_input_idx), + counter_shape)); + + const int minval_input_idx = 4; + const int maxval_input_idx = 5; + TensorShape minval_shape = ctx->InputShape(minval_input_idx); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval_shape), + errors::InvalidArgument("minval must be scalar, got shape ", + minval_shape.DebugString())); + TensorShape maxval_shape = ctx->InputShape(maxval_input_idx); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval_shape), + errors::InvalidArgument("maxval must be scalar, got shape ", + maxval_shape.DebugString())); + + xla::XlaOp minval = ctx->Input(minval_input_idx); + xla::XlaOp maxval = ctx->Input(maxval_input_idx); + + xla::Shape xla_shape; + OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape)); + + counter = MaybeSliceCounter(alg, counter_shape, counter); + auto result = + StatelessRngUniformV2(alg, key, counter, xla_shape, minval, maxval); + ctx->SetOutput(0, result.value); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformIntOp); +}; + +REGISTER_XLA_OP(Name("StatelessRandomUniformIntV2") + .CompileTimeConstantInput("shape") + .CompileTimeConstantInput("alg") + .TypeConstraint("dtype", + {DT_INT32, DT_INT64, DT_UINT32, DT_UINT64}), + StatelessRandomUniformIntOp); + +class StatelessRandomUniformFullIntOp : public XlaOpKernel { + public: + explicit StatelessRandomUniformFullIntOp(OpKernelConstruction* ctx) 
+ : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); + + const int key_input_idx = 1; + const int counter_input_idx = 2; + const int alg_input_idx = 3; + xla::XlaOp key = ctx->Input(key_input_idx); + xla::XlaOp counter = ctx->Input(counter_input_idx); + + xla::RandomAlgorithm alg; + OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg)); + + auto counter_shape = ctx->InputShape(counter_input_idx); + OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg), + ctx->InputShape(key_input_idx), + counter_shape)); + + xla::Shape xla_shape; + OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape)); + + counter = MaybeSliceCounter(alg, counter_shape, counter); + auto result = StatelessRngUniformFullInt(alg, key, counter, xla_shape); + ctx->SetOutput(0, result.value); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformFullIntOp); +}; + +REGISTER_XLA_OP(Name("StatelessRandomUniformFullIntV2") + .CompileTimeConstantInput("shape") + .CompileTimeConstantInput("alg") + .TypeConstraint("dtype", + {DT_INT32, DT_INT64, DT_UINT32, DT_UINT64}), + StatelessRandomUniformFullIntOp); + +class StatelessRandomNormalOp : public XlaOpKernel { + public: + explicit StatelessRandomNormalOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); + + const int key_input_idx = 1; + const int counter_input_idx = 2; + const int alg_input_idx = 3; + xla::XlaOp key = ctx->Input(key_input_idx); + xla::XlaOp counter = ctx->Input(counter_input_idx); + + xla::RandomAlgorithm alg; + OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg)); + + auto counter_shape = ctx->InputShape(counter_input_idx); + OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg), + ctx->InputShape(key_input_idx), + counter_shape)); + + DataType rng_dtype = dtype_ == DT_DOUBLE ? 
DT_DOUBLE : DT_FLOAT; + + xla::Shape xla_shape; + OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape)); + + using std::placeholders::_1; + using std::placeholders::_2; + using std::placeholders::_3; + auto generator = std::bind(BitGenerator, alg, _1, _2, _3); + counter = MaybeSliceCounter(alg, counter_shape, counter); + auto result = xla::NormalFloatingPointDistribution(key, counter, generator, + xla_shape); + auto normal = MaybeConvertF32ToBF16(result.value, dtype_); + ctx->SetOutput(0, normal); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomNormalOp); +}; + +REGISTER_XLA_OP(Name("StatelessRandomNormalV2") + .CompileTimeConstantInput("shape") + .CompileTimeConstantInput("alg") + .TypeConstraint("dtype", + {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}), + StatelessRandomNormalOp); + +class StatelessTruncatedNormalOp : public XlaOpKernel { + public: + explicit StatelessTruncatedNormalOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); + + const int key_input_idx = 1; + const int counter_input_idx = 2; + const int alg_input_idx = 3; + xla::XlaOp key = ctx->Input(key_input_idx); + xla::XlaOp counter = ctx->Input(counter_input_idx); + + xla::RandomAlgorithm alg; + OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg)); + + auto counter_shape = ctx->InputShape(counter_input_idx); + OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg), + ctx->InputShape(key_input_idx), + counter_shape)); + + xla::XlaBuilder* builder = ctx->builder(); + + DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT; + xla::Shape xla_shape; + OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape)); + + counter = MaybeSliceCounter(alg, counter_shape, counter); + auto result = StatelessRngUniformV2( + alg, key, counter, xla_shape, + xla::MinPositiveNormalValue(builder, xla_shape.element_type()), + xla::One(builder, xla_shape.element_type())); + xla::XlaOp truncated_normal = TruncatedNormal(result.value); + truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_); + ctx->SetOutput(0, truncated_normal); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(StatelessTruncatedNormalOp); +}; + +REGISTER_XLA_OP(Name("StatelessTruncatedNormalV2") + .CompileTimeConstantInput("shape") + .CompileTimeConstantInput("alg") + .TypeConstraint("dtype", + {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}), + StatelessTruncatedNormalOp); + +class GetKeyCounterAlgOp : public XlaOpKernel { + public: + explicit GetKeyCounterAlgOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx), + device_type_string_(ctx->device_type().type_string()) {} + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape seed_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, seed_shape == TensorShape({2}), + errors::InvalidArgument("seed must have shape [2], not ", + seed_shape.DebugString())); + xla::XlaOp seed = ctx->Input(0); + + xla::XlaBuilder* builder = seed.builder(); + xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); + xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); + xla::XlaOp key = ConvertElementType(seed0, xla::U64) | + ShiftLeft(ConvertElementType(seed1, xla::U64), + ConstantR0WithType(builder, xla::U64, 32)); + auto key_counter_alg = GetKeyCounterAlg(device_type_string_, key); + key = std::get<0>(key_counter_alg); 
+ auto counter = std::get<1>(key_counter_alg); + auto alg = std::get<2>(key_counter_alg); + key = xla::Reshape(key, {RNG_KEY_SIZE}); + ctx->SetOutput(0, key); + ctx->SetOutput(1, counter); + ctx->SetOutput(2, ConstantR0(builder, static_cast(alg))); + } + + private: + string device_type_string_; + + TF_DISALLOW_COPY_AND_ASSIGN(GetKeyCounterAlgOp); +}; + +REGISTER_XLA_OP(Name("StatelessRandomGetKeyCounterAlg"), GetKeyCounterAlgOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 268317d84fc..943d92982cb 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -26,11 +26,13 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/ops_util.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/mem.h" @@ -290,6 +292,83 @@ class StridedSliceGradOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_)); } + // When the begin / end is unknown, compile the gradient into dynamic update + // slice into a broadcasted 0s. + // + // Broadcasted 0 + // +----------------------+ + // | +----+ | + // |<-begin->|grad|<-end->| <== Dynamic update grad into 0s. + // | +----+ | + // +----------------------+ + void CompileAsDynamicUpdateSlice(XlaOpKernelContext* ctx, + const TensorShape& input_shape, + const xla::Literal& strides_literal) { + bool dummy = false; + Tensor strides_tensor; + PartialTensorShape processing_shape, final_shape; + absl::InlinedVector begin; + absl::InlinedVector end; + absl::InlinedVector strides; + + absl::InlinedVector output_to_sparse_mapping; + absl::InlinedVector output_to_processing_mapping; + + OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_, + &strides_tensor)); + OP_REQUIRES_OK( + ctx, ValidateStridedSliceOp( + nullptr, nullptr, strides_tensor, input_shape, begin_mask_, + end_mask_, ellipsis_mask_, new_axis_mask_, shrink_axis_mask_, + &processing_shape, &final_shape, &dummy, &dummy, &dummy, + &begin, &end, &strides, &output_to_sparse_mapping, + &output_to_processing_mapping)); + for (int64 i = 0; i < processing_shape.dims(); ++i) { + OP_REQUIRES( + ctx, strides[i] == 1, + errors::InvalidArgument("Strides in strided slice grad have to be " + "one when inputs are not constant.")); + } + + auto zero = XlaHelpers::Zero(ctx->builder(), ctx->expected_output_dtype(0)); + zero = xla::Broadcast(zero, input_shape.dim_sizes()); + xla::XlaOp grad = ctx->Input(4); + xla::Shape grad_shape = ctx->InputXlaShape(4).ValueOrDie(); + // Undo any new/shrink axes. + VLOG(1) << "xla grad shape" << grad_shape; + VLOG(1) << "input_shape" << input_shape.DebugString(); + std::vector begins(processing_shape.dims(), + xla::Zero(ctx->builder(), xla::S32)); + for (int64 i = 0; i < grad_shape.rank(); ++i) { + // Use grad shape, which is known, to update unknown processing shape. 
+ // Grad shape is the output of the ValidateStridedSliceOp function in + // forward pass, thus we use output_to_processing_mapping. + if (output_to_processing_mapping[i] != -1) { + processing_shape.set_dim(output_to_processing_mapping[i], + grad_shape.dimensions(i)); + } + + // Similarly, use output_to_sparse_mapping to find out corresponding + // begin dim of the output, as indices for dynamic update slice. + int64 begin_dim = output_to_sparse_mapping[i]; + if (begin_dim != -1) { + auto begin_index = + xla::Slice(ctx->Input(1), {begin_dim}, {begin_dim + 1}, {1}); + auto begin_index_scalar = xla::Reshape( + xla::ShapeUtil::MakeScalarShape(xla::S32), begin_index); + begins[output_to_sparse_mapping[i]] = begin_index_scalar; + } + } + VLOG(1) << "processing_shape" << processing_shape.DebugString(); + TensorShape full_processing_shape; + OP_REQUIRES(ctx, processing_shape.AsTensorShape(&full_processing_shape), + errors::InvalidArgument( + "Processing shape ", processing_shape.DebugString(), + " can't be fully inferred from grad shape")); + grad = xla::Reshape(grad, full_processing_shape.dim_sizes()); + grad = xla::DynamicUpdateSlice(zero, grad, begins); + ctx->SetOutput(0, grad); + } void Compile(XlaOpKernelContext* ctx) override { TensorShape processing_shape, final_shape; absl::InlinedVector begin; @@ -298,12 +377,15 @@ class StridedSliceGradOp : public XlaOpKernel { TensorShape input_shape; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape)); - xla::Literal begin_literal, end_literal, strides_literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal)); - OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &end_literal)); - OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal)); + bool begin_is_constant = ctx->ConstantInput(1, &begin_literal).ok(); + bool end_is_constant = ctx->ConstantInput(2, &end_literal).ok(); + OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal)); + if (!(begin_is_constant && end_is_constant)) { + CompileAsDynamicUpdateSlice(ctx, input_shape, strides_literal); + return; + } Tensor begin_tensor, end_tensor, strides_tensor; OP_REQUIRES_OK( ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor)); @@ -446,7 +528,12 @@ class StridedSliceAssignOp : public XlaOpKernel { TensorShape lhs_shape; xla::XlaOp lhs; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs)); + if (ctx->input_type(0) == DT_RESOURCE) { + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs)); + } else { + lhs_shape = ctx->InputShape(0); + lhs = ctx->Input(0); + } const TensorShape rhs_shape = ctx->InputShape(4); @@ -504,7 +591,11 @@ class StridedSliceAssignOp : public XlaOpKernel { lhs = xla::DynamicUpdateSlice(lhs, rhs, slice_begin); - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs)); + if (ctx->input_type(0) == DT_RESOURCE) { + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs)); + } else { + ctx->SetOutput(0, lhs); + } } private: @@ -520,5 +611,11 @@ REGISTER_XLA_OP(Name("ResourceStridedSliceAssign") .CompileTimeConstantInput("strides"), StridedSliceAssignOp); +REGISTER_XLA_OP(Name("TensorStridedSliceUpdate") + .CompileTimeConstantInput("begin") + .CompileTimeConstantInput("end") + .CompileTimeConstantInput("strides"), + StridedSliceAssignOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc index 8b481d55a80..555905ebe6b 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc +++ 
b/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/errors.h" namespace tensorflow { namespace { @@ -38,16 +39,28 @@ class XlaReduceOp : public XlaOpKernel { context, dims_set.size() == dimensions_to_reduce_.size(), errors::InvalidArgument("Duplicate dimension in dimensions_to_reduce " "argument to XlaReduce")); + if (context->HasAttr("N")) { // variadic reduce + use_tuples_ = true; + OP_REQUIRES_OK(context, context->GetAttr("N", &n_)); + } else { + use_tuples_ = false; + n_ = 1; + } } void Compile(XlaOpKernelContext* context) override { - const TensorShape input_shape = context->InputShape("input"); - const TensorShape init_value_shape = context->InputShape("init_value"); + OP_REQUIRES(context, n_ * 2 == context->num_inputs(), + errors::InvalidArgument("Expected ", n_ * 2, " inputs but got ", + context->num_inputs())); + + const TensorShape input_shape = context->InputShape(0); + const TensorShape init_value_shape = context->InputShape(n_); const DataType dtype = context->input_type(0); const int rank = input_shape.dims(); OP_REQUIRES(context, TensorShapeUtils::IsScalar(init_value_shape), - errors::InvalidArgument("init_value must be a scalar")); + errors::InvalidArgument("init_value must be a scalar but got ", + init_value_shape.DebugString())); auto dim_in_range = [rank](int64 dim) { return dim >= 0 && dim < rank; }; OP_REQUIRES(context, @@ -67,35 +80,58 @@ class XlaReduceOp : public XlaOpKernel { compile_options.always_return_tuple = false; compile_options.is_entry_computation = false; XlaCompiler::CompilationResult reducer; - OP_REQUIRES_OK(context, context->compiler()->CompileFunction( - compile_options, *reducer_, - {reducer_arg, reducer_arg}, &reducer)); + OP_REQUIRES_OK( + context, + context->compiler()->CompileFunction( + compile_options, *reducer_, + std::vector(n_ * 2, reducer_arg), &reducer)); - xla::Shape scalar_shape; - OP_REQUIRES_OK(context, - TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape)); + xla::Shape expected_shape; + OP_REQUIRES_OK( + context, TensorShapeToXLAShape(dtype, TensorShape(), &expected_shape)); + if (use_tuples_) { + expected_shape = xla::ShapeUtil::MakeTupleShape( + std::vector(n_, expected_shape)); + } OP_REQUIRES( context, - xla::ShapeUtil::Compatible(reducer.xla_output_shape, scalar_shape), + xla::ShapeUtil::Compatible(reducer.xla_output_shape, expected_shape), errors::InvalidArgument( "Invalid output shape of XlaReduce reducer. 
Expected ", - xla::ShapeUtil::HumanString(scalar_shape), " got ", + xla::ShapeUtil::HumanString(expected_shape), " got ", xla::ShapeUtil::HumanString(reducer.xla_output_shape))); + std::vector inputs; + std::vector inits; + inputs.reserve(n_); + inits.reserve(n_); + for (int i = 0; i < n_; i++) { + inputs.emplace_back(context->Input(i)); + inits.emplace_back(context->Input(n_ + i)); + } xla::XlaOp output = - xla::Reduce(context->Input("input"), context->Input("init_value"), - *reducer.computation, dimensions_to_reduce_); - context->SetOutput(0, output); + xla::Reduce(context->builder(), inputs, inits, *reducer.computation, + dimensions_to_reduce_); + if (use_tuples_) { + for (int i = 0; i < n_; i++) { + context->SetOutput(i, xla::GetTupleElement(output, i)); + } + } else { + context->SetOutput(0, output); + } } private: const NameAttrList* reducer_; std::vector dimensions_to_reduce_; + bool use_tuples_; + int n_; TF_DISALLOW_COPY_AND_ASSIGN(XlaReduceOp); }; REGISTER_XLA_OP(Name("XlaReduce"), XlaReduceOp); +REGISTER_XLA_OP(Name("XlaVariadicReduce"), XlaReduceOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index 531679d3905..703f6c2eb72 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -1,5 +1,8 @@ # Utilities for building XLA computations. +load("//tensorflow:tensorflow.bzl", "filegroup") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + package( default_visibility = ["//tensorflow/compiler/tf2xla:friends"], licenses = ["notice"], # Apache 2.0 diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc index eefef26dc24..b46429ef0d1 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc @@ -38,7 +38,7 @@ auto* mlir_bridge_gauge_v2 = monitoring::Gauge::New( // encapsulated graph to a particular device. 
Status MlirBridgePass::Run(const ConfigProto& config_proto, mlir::ModuleOp module) { - if (!config_proto.experimental().enable_mlir_bridge()) { + if (!IsEnabled(config_proto)) { VLOG(0) << "Skipping MLIR TPU Bridge, session flag not enabled"; mlir_bridge_gauge_v2->GetCell()->Set(false); return Status::OK(); diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h index f7541e634d4..bbddeb6a967 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h @@ -30,7 +30,10 @@ class MlirBridgePass : public MlirOptimizationPass { llvm::StringRef name() const override { return "bridge"; } bool IsEnabled(const ConfigProto& config_proto) const override { - return config_proto.experimental().enable_mlir_bridge(); + return config_proto.experimental().enable_mlir_bridge() || + tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge == + tensorflow::ConfigProto::Experimental:: + MLIR_BRIDGE_ROLLOUT_ENABLED; } // This should be used as a thin mapper around mlir::ModulePass::runOnModule @@ -47,7 +50,9 @@ class MlirBridgeV1CompatPass : public MlirV1CompatOptimizationPass { bool IsEnabled(const ConfigProto& config_proto) const override { return config_proto.experimental().enable_mlir_bridge() || - tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge; + GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge == + tensorflow::ConfigProto::Experimental:: + MLIR_BRIDGE_ROLLOUT_ENABLED; } // This should be used as a thin mapper around mlir::ModulePass::runOnModule diff --git a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc index db1a6929934..ac4d1f28803 100644 --- a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc +++ b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc @@ -90,18 +90,6 @@ Status ConvertOutputInfo(const tf2xla::Config& config, return ParseOutputArrayInfo(array_names, &specs->outputs); } -static void RegisterDialects() { - static bool init_once = []() { - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerDialect(); - return true; - }(); - (void)init_once; -} - } // namespace Status ConvertGraphDefToXlaViaMlir( @@ -150,9 +138,7 @@ Status ConvertGraphDefToXlaViaMlir( } } - RegisterDialects(); mlir::MLIRContext context; - context.loadAllGloballyRegisteredDialects(); TF_ASSIGN_OR_RETURN( mlir::OwningModuleRef module, ConvertGraphdefToMlir(pruned_graph_def, debug_info, specs, &context)); @@ -175,7 +161,7 @@ Status ConvertGraphDefToXlaViaMlir( return ConvertMLIRToXlaComputation(*module, /*device_type=*/"XLA_CPU_JIT", computation, /*use_tuple_args=*/false, - /*always_return_tuple=*/true); + /*return_tuple=*/true); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/BUILD b/tensorflow/compiler/tf2xla/ops/BUILD index b116a09dd02..50ff1ea5d16 100644 --- a/tensorflow/compiler/tf2xla/ops/BUILD +++ b/tensorflow/compiler/tf2xla/ops/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", "tf_custom_op_library", diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index f4b9e9654d2..471cc029a59 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -291,6 +291,16 @@ dimension_numbers: a serialized xla::DotDimensionNumbers proto. precision_config: a serialized xla::PrecisionConfig proto. 
)doc"); +REGISTER_OP("XlaSetBound") + .Input("input: int32") + .Input("bound: int32") + .Output("output: int32") + .SetShapeFn(shape_inference::UnknownShape) + .Doc( + R"doc(Set a bound for the given input value as a hint to Xla compiler, + returns the same value. +)doc"); + REGISTER_OP("XlaDynamicSlice") .Input("input: T") .Input("start_indices: Tindices") @@ -465,6 +475,60 @@ reducer: a reducer function to apply dimensions_to_reduce: dimension numbers over which to reduce )doc"); +REGISTER_OP("XlaVariadicReduce") + .Input("input: N * T") + .Input("init_value: N * T") + .Attr("N: int >= 1") + .Attr("T: numbertype") + .Attr("dimensions_to_reduce: list(int)") + .Attr("reducer: func") + .Output("output: N * T") + .SetShapeFn([](shape_inference::InferenceContext* c) { + int n; + TF_RETURN_IF_ERROR(c->GetAttr("N", &n)); + for (int i = 0; i < n; i++) { + for (int j = 0; j < n; j++) { + c->MergeInput(i, c->input(j)); + } + } + if (c->RankKnown(c->input(0))) { + int rank = c->Rank(c->input(0)); + std::vector dimensions_to_reduce; + TF_RETURN_IF_ERROR( + c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce)); + std::set dims_set(dimensions_to_reduce.begin(), + dimensions_to_reduce.end()); + auto dim_in_range = [rank](int64 dim) { + return dim >= 0 && dim < rank; + }; + const int dimensions_to_reduce_size = dimensions_to_reduce.size(); + if (rank < dimensions_to_reduce_size || + dims_set.size() != dimensions_to_reduce.size() || + !absl::c_all_of(dimensions_to_reduce, dim_in_range)) { + return errors::InvalidArgument( + "Invalid dimensions_to_reduce argument to XlaVariadicReduce"); + } + for (int i = 0; i < n; i++) { + c->set_output( + i, c->UnknownShapeOfRank(rank - dimensions_to_reduce.size())); + } + } else { + for (int i = 0; i < n; i++) { + c->set_output(i, c->input(i)); + } + } + return Status::OK(); + }) + .Doc(R"doc( +Wraps the variadic XLA Reduce operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#variadic_reduce. + +input: the input tensor(s) +init_value: scalar initial value(s) for the reduction +reducer: a reducer function to apply +dimensions_to_reduce: dimension numbers over which to reduce +)doc"); + REGISTER_OP("XlaReduceWindow") .Input("input: T") .Input("init_value: T") @@ -728,7 +792,7 @@ REGISTER_OP("XlaGather") .Input("slice_sizes: Tindices") .Attr("dimension_numbers: string") .Attr("indices_are_sorted: bool") - .Attr("T: numbertype") + .Attr("T: {numbertype, bool}") .Attr("Tindices: {int32, int64}") .Output("output: T") .SetShapeFn(shape_inference::UnknownShape) @@ -749,10 +813,10 @@ REGISTER_OP("XlaScatter") .Attr("update_computation: func") .Attr("dimension_numbers: string") .Attr("indices_are_sorted: bool") - .Attr("T: numbertype") + .Attr("T: {numbertype, bool}") .Attr("Tindices: {int32, int64}") .Output("output: T") - .SetShapeFn(UnchangedRank) + .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( Wraps the XLA Scatter operator documented at https://www.tensorflow.org/xla/operation_semantics#scatter. 
diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 846dafa2570..2e5667bc02f 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -206,6 +206,8 @@ igamma = _broadcasting_binary_op(math_ops.igamma) igamma_grad_a = _broadcasting_binary_op(gen_math_ops.igamma_grad_a) random_gamma_grad = _broadcasting_binary_op(gen_random_ops.random_gamma_grad) igammac = _broadcasting_binary_op(math_ops.igammac) +polygamma = _broadcasting_binary_op(math_ops.polygamma) +zeta = _broadcasting_binary_op(math_ops.zeta) def _binary_op(fn): @@ -338,6 +340,7 @@ def random_uniform(minval, maxval, dims, name=None): recv = gen_xla_ops.xla_recv reduce = gen_xla_ops.xla_reduce +variadic_reduce = gen_xla_ops.xla_variadic_reduce def reduce_window(operand, @@ -387,6 +390,14 @@ def reduce_window(operand, replica_id = gen_xla_ops.xla_replica_id +# Set a static bound for the given input value as a hint to Xla compiler, +# returns the same value. +# Usage: +# def f(t, p): +# p = xla.set_bound(p, 3) # Tells xla the constraint that p <= 3. +# return t[:p] # xla knows the bound of the slice is 3. +set_bound = gen_xla_ops.xla_set_bound + def reshape(x, new_sizes, dimensions=None, name=None): if dimensions is not None: diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc index 2db431c0413..860c3a40424 100644 --- a/tensorflow/compiler/tf2xla/resource_operation_table.cc +++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc @@ -83,6 +83,8 @@ CreateResourceOpInfoMap() { add("ResourceScatterSub" , kReadWrite, kVariable); add("ResourceScatterUpdate" , kReadWrite, kVariable); add("ResourceStridedSliceAssign" , kReadWrite, kVariable); + add("RngReadAndSkip" , kReadWrite, kVariable); + add("RngSkip" , kReadWrite, kVariable); add("StatefulStandardNormalV2" , kReadWrite, kVariable); add("StatefulTruncatedNormal" , kReadWrite, kVariable); add("StatefulUniform" , kReadWrite, kVariable); diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index f8319cd446a..5c8cfdde9e4 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/shape_inference.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h" #include "tensorflow/compiler/tf2xla/rearrange_function_argument.h" #include "tensorflow/compiler/tf2xla/shape_util.h" @@ -57,6 +56,11 @@ limitations under the License. 
#include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/core/util/dump_graph.h" +#ifndef LIBTPU_ON_GCE +#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" +#include "tensorflow/compiler/mlir/utils/array_container_utils.h" +#endif + namespace tensorflow { namespace { @@ -623,8 +627,28 @@ std::unique_ptr XlaCompiler::GetGraph(const FunctionBody* fbody) { graph_optimizer_options.inline_with_single_device_body_placer = true; graph_optimizer_options.ignore_noinline = is_inside_mustcompile; - optimizer.Optimize(flib_runtime_, flib_runtime_->env(), - /*device=*/nullptr, &graph, graph_optimizer_options); + { + GraphShapeInfo shape_info; + InferShapes(graph.get(), /*arg_shapes=*/{}, + flib_runtime_->GetFunctionLibraryDefinition(), &shape_info) + .IgnoreError(); + auto node_name_index = graph->BuildNodeNameIndex(); + std::unordered_map> shape_map; + for (const auto& node_shape_info : shape_info) { + const string& node_name = node_shape_info.first; + const std::vector& output_shapes = node_shape_info.second; + const auto& node_iter = node_name_index.find(node_name); + if (node_iter != node_name_index.end()) { + auto& partial_shapes = shape_map[node_name]; + for (const auto& inferred_shape : output_shapes) { + partial_shapes.push_back(inferred_shape.shape); + } + } + } + graph_optimizer_options.shape_map = &shape_map; + optimizer.Optimize(flib_runtime_, flib_runtime_->env(), + /*device=*/nullptr, &graph, graph_optimizer_options); + } // Run shape inference on the graph and optimize the graph again. GraphShapeInfo shape_info; @@ -729,18 +753,32 @@ Status XlaCompiler::CompileFunction( } VLOG(1) << "===================================================="; - if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) { +#ifdef LIBTPU_ON_GCE + if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge == + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) { + VLOG(1) << "MLIR is not supported in this environment."; + } + TF_RETURN_IF_ERROR( + CompileGraph(options, function_id, std::move(graph), args, result)); +#else + if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge == + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) { VLOG(1) << "Using MLIR bridge"; GraphDebugInfo debug_info; + std::vector control_rets; + for (const auto* control_ret_node : fbody->control_ret_nodes) { + control_rets.push_back(control_ret_node->name()); + } TF_RETURN_IF_ERROR(CompileGraphToXlaHlo( - std::move(*graph), {args.data(), args.size()}, - options_.device_type.type_string(), options.use_tuple_arg, + std::move(*graph), mlir::SpanToArrayRef(args), + control_rets, options_.device_type.type_string(), options.use_tuple_arg, *options_.flib_def, debug_info, options_.shape_representation_fn, result)); } else { TF_RETURN_IF_ERROR( CompileGraph(options, function_id, std::move(graph), args, result)); } +#endif VLOG(1) << "===================================================="; cache_[{function_id, arg_vector}] = *result; @@ -1143,7 +1181,11 @@ Status ValidateGraph(const Graph* graph, return errors::InvalidArgument(absl::StrCat( "Detected unsupported operations when trying to compile graph ", name, " on ", device_type.type_string(), ": ", node->def().op(), " (", - s.error_message(), ")", FormatNodeForError(*node))); + s.error_message(), ")", FormatNodeForError(*node), + "One approach is to outside compile the unsupported ops to run on " + "CPUs by enabling soft placement " + "`tf.config.set_soft_device_placement(True)`." 
+ " This has a potential performance penalty.")); } return Status::OK(); }; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index b0d93cde846..762700eaea8 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -129,8 +129,6 @@ class XlaCompiler { // Resource updates are converted into input / output of xla. The two // buffers are aliased with other if this option is true. - // - // Currently only supports TPU. bool alias_resource_update = false; }; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index e37f4659185..ac6d065e775 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -134,6 +134,13 @@ XlaOpRegistry::~XlaOpRegistry() = default; result.first->second.op_filter = op_filter; } +/* static */ bool XlaOpRegistry::IsCompilationDevice( + const string& device_name) { + XlaOpRegistry& registry = Instance(); + mutex_lock lock(registry.mutex_); + return registry.backends_.find(device_name) != registry.backends_.end(); +} + /* static */ bool XlaOpRegistry::GetCompilationDevice( const string& device_name, const DeviceRegistration** registration) { XlaOpRegistry& registry = Instance(); @@ -365,6 +372,19 @@ std::vector XlaOpRegistry::DeviceKernels( return ops; } +/*static*/ const std::unordered_set* +XlaOpRegistry::CompileTimeConstantInputArgNames(const string& op) { + XlaOpRegistry& registry = Instance(); + mutex_lock lock(registry.mutex_); + auto it = registry.ops_.find(op); + static auto empty_set = new std::unordered_set; + if (it == registry.ops_.end() || it->second.empty()) { + return empty_set; + } else { + return &it->second.front()->compile_time_constant_inputs; + } +} + /* static */ Status XlaOpRegistry::CompileTimeConstantInputs( const NodeDef& node_def, const OpKernel* op_kernel, const OpDef* op_def, std::vector* result) { @@ -385,21 +405,10 @@ std::vector XlaOpRegistry::DeviceKernels( compile_time_constant_inputs_from_attr.end())); compile_time_constant_inputs = &compile_time_constant_inputs_from_attr; } else { - const string& op = node_def.op(); - - XlaOpRegistry& registry = Instance(); - mutex_lock lock(registry.mutex_); - auto it = registry.ops_.find(op); - if (it == registry.ops_.end() || it->second.empty()) { + compile_time_constant_inputs = + CompileTimeConstantInputArgNames(node_def.op()); + if (compile_time_constant_inputs->empty()) { return Status::OK(); - } else { - // The test in IsCompatible ensures that if there are multiple matching - // registrations for this op name, they all have the same value of - // compile_time_constant_inputs, so only the first match is returned. - // - // TODO(sanjoy): This can probably be a std::vector. - compile_time_constant_inputs = - &it->second.front()->compile_time_constant_inputs; } } diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index af720fb4bb9..36657208a28 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -153,6 +153,10 @@ class XlaOpRegistry { static void RegisterCompilationDevice(const string& device_name, const DeviceRegistration& registration); + // Returns whether the device name is for the JIT device used exclusively for + // TF2XLA conversion. 
+ static bool IsCompilationDevice(const string& device_name); + // Returns the JIT device name associated with 'device_name', setting // 'jit_device_name', 'requires_jit', and 'enabled_jit_by_default', if they // are not null. Returns false and leaves the outputs unchanged if no matching @@ -198,6 +202,11 @@ class XlaOpRegistry { /*op_def=*/nullptr, result); } + // Return names of arguments for a given op which are supposed to be + // constants. + static const std::unordered_set* + CompileTimeConstantInputArgNames(const string& op); + // Returns true if `op` is a "metadata" op, one that only looks at the shapes // of its operands and not their values. static bool IsMetadataOp(const string& op); diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 35fa6a617f0..831da22e033 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -1,8 +1,9 @@ +load("//tensorflow:tensorflow.bzl", "filegroup") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "tf_cc_test") load( "//tensorflow/core/platform:build_config.bzl", - "tf_proto_library_cc", - "tf_proto_library_py", + "tf_proto_library", ) package( @@ -55,20 +56,14 @@ filegroup( visibility = [":friends"], ) -tf_proto_library_cc( +tf_proto_library( name = "xla_data_proto", srcs = ["xla_data.proto"], cc_api_version = 2, visibility = ["//visibility:public"], ) -tf_proto_library_py( - name = "xla_data_proto", # bzl adds a _py suffix - srcs = ["xla_data.proto"], - visibility = ["//visibility:public"], -) - -tf_proto_library_cc( +tf_proto_library( name = "xla_proto", srcs = ["xla.proto"], cc_api_version = 2, @@ -79,16 +74,6 @@ tf_proto_library_cc( visibility = ["//visibility:public"], ) -tf_proto_library_py( - name = "xla_proto", # bzl adds a _py suffix - srcs = ["xla.proto"], - visibility = ["//visibility:public"], - deps = [ - ":xla_data_proto_py", - "//tensorflow/compiler/xla/service:hlo_proto_py", - ], -) - cc_library( name = "bit_cast", hdrs = ["bit_cast.h"], @@ -292,6 +277,7 @@ tf_cc_test( ":types", ":util", "//tensorflow/core:test_main", + "//tensorflow/core/platform:bfloat16", ], ) @@ -335,7 +321,7 @@ cc_library( ":xla_data_proto_cc", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:regexp_internal", + "//tensorflow/core/platform:regexp", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", @@ -541,7 +527,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":types", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "@com_google_absl//absl/strings", ], ) @@ -678,8 +664,8 @@ cc_library( ":statusor", ":types", "//tensorflow/core:lib", - "//tensorflow/core:regexp_internal", "//tensorflow/core:test", + "//tensorflow/core/platform:regexp", "@com_google_absl//absl/strings", ], ) @@ -969,6 +955,11 @@ tf_cc_test( ], ) +cc_library( + name = "union_find", + hdrs = ["union_find.h"], +) + # ----------------------------------------------------------------------------- # This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code. 
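// Editor's illustrative sketch (not part of the patch): querying the two
// XlaOpRegistry helpers added above. The device name "XLA_CPU_JIT" and the
// claim that "Fill" marks its "dims" input as compile-time constant are my
// assumptions about existing registrations, not something this patch adds.
#include <string>
#include <unordered_set>

#include "tensorflow/compiler/tf2xla/xla_op_registry.h"

void InspectXlaOpRegistry() {
  // True only for JIT compilation devices registered with the op registry.
  const bool is_jit_device =
      tensorflow::XlaOpRegistry::IsCompilationDevice("XLA_CPU_JIT");

  // Names of input arguments that must be compile-time constants for an op;
  // an empty set is returned for unknown ops.
  const std::unordered_set<std::string>* const_inputs =
      tensorflow::XlaOpRegistry::CompileTimeConstantInputArgNames("Fill");

  (void)is_jit_device;
  (void)const_inputs;
}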
diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index a51970bb168..409cf37762b 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -1,6 +1,8 @@ # Description: # XLA client libraries. +load("//tensorflow:tensorflow.bzl", "filegroup") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") package( @@ -128,7 +130,7 @@ cc_library( "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:source_map_util", "//tensorflow/compiler/xla/service:stream_pool", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor:device_memory_allocator", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", @@ -148,7 +150,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/service:compile_only_service", "//tensorflow/compiler/xla/service:compiler", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "@com_google_absl//absl/memory", "@llvm-project//llvm:Support", ], @@ -172,7 +174,7 @@ cc_library( "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor:device_memory_allocator", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:optional", @@ -255,6 +257,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index eb09e9c8867..92d222f32b2 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -1,5 +1,7 @@ # Common computation builders for XLA. 
+load("//tensorflow:tensorflow.bzl", "filegroup") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites", "xla_test") package( @@ -305,6 +307,20 @@ xla_test( "//tensorflow/compiler/xla/tests:test_macros_header", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", + "//tensorflow/core/platform:tensor_float_32_utils", + ], +) + +cc_library( + name = "lu_decomposition", + srcs = ["lu_decomposition.cc"], + hdrs = ["lu_decomposition.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/client:xla_builder", ], ) @@ -345,6 +361,9 @@ cc_library( hdrs = ["sorting.h"], deps = [ ":comparators", + ":constants", + ":loops", + ":slicing", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -576,6 +595,7 @@ cc_library( ":loops", ":math", ":matrix", + ":qr", ":slicing", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", diff --git a/tensorflow/compiler/xla/client/lib/logdet.cc b/tensorflow/compiler/xla/client/lib/logdet.cc index 8f37c393922..18cd0870f2a 100644 --- a/tensorflow/compiler/xla/client/lib/logdet.cc +++ b/tensorflow/compiler/xla/client/lib/logdet.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/lib/qr.h" #include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -33,13 +34,46 @@ limitations under the License. namespace xla { -// let G = root(A) be the Cholesky root of the matrix A -// log(det(A)) = 2*sum(log(vecdiag(G))) +// log(det(A)) = sum(log(vecdiag(QR(A).r))), since R is triangular and Q is +// orthonormal XlaOp LogDet(XlaOp a) { - XlaOp cholesky = Cholesky(a, /*bool lower=*/true); + return a.builder()->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape a_shape, a.builder()->GetShape(a)); + // Compute the number of Householder transformations required on 'a' by + // determining the number of rows in 'a' that are already triangular. The + // determinant of Q is -1 ^ (number of Householder transfomations) + auto rows = Iota(a.builder(), ShapeUtil::ChangeElementType(a_shape, S32), + a_shape.rank() - 2); + auto cols = Iota(a.builder(), ShapeUtil::ChangeElementType(a_shape, S32), + a_shape.rank() - 1); + auto in_lower_triangle = Lt(cols, rows); + auto is_zero = Eq(a, ScalarLike(a, 0)); + auto num_zeros_in_triangle_per_row = Einsum( + ConvertElementType(And(in_lower_triangle, is_zero), S32), "...a->..."); + TF_ASSIGN_OR_RETURN(auto row_shape, + a.builder()->GetShape(num_zeros_in_triangle_per_row)); + rows = Iota(a.builder(), row_shape, row_shape.rank() - 1); + auto num_triangle_rows = + Einsum(ConvertElementType(Eq(rows, num_zeros_in_triangle_per_row), S32), + "...a->..."); + auto num_rows = + ScalarLike(num_triangle_rows, a_shape.dimensions(a_shape.rank() - 2)); - return ScalarLike(a, 2) * - Einsum(Log(cholesky), "...aa->...", xla::PrecisionConfig::HIGHEST); + TF_ASSIGN_OR_RETURN(auto qr, QRDecomposition(a, true)); + // Get the and log of the determinant based on the values along the diagonal + // of R. 
+ auto log_abs_det = Einsum(Log(Abs(qr.r)), "...aa->..."); + auto sign_diag = Reduce( + Sign(Einsum(qr.r, "...aa->...a")), + One(a.builder(), a_shape.element_type()), + CreateScalarMultiplyComputation(a_shape.element_type(), a.builder()), + {a_shape.rank() - 2}); + return sign_diag * log_abs_det * + Select(ConvertElementType(Rem(num_rows - num_triangle_rows, + ScalarLike(num_triangle_rows, 2)), + PRED), + ScalarLike(sign_diag, -1.0), ScalarLike(sign_diag, 1.0)); + }); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/logdet_test.cc b/tensorflow/compiler/xla/client/lib/logdet_test.cc index 54af41f77f6..319d819ed98 100644 --- a/tensorflow/compiler/xla/client/lib/logdet_test.cc +++ b/tensorflow/compiler/xla/client/lib/logdet_test.cc @@ -51,6 +51,26 @@ XLA_TEST_F(LogDetTest, Simple) { xla::ErrorSpec(1e-4)); } +XLA_TEST_F(LogDetTest, SimpleTriangle) { + xla::XlaBuilder builder(TestName()); + + xla::Array2D a_vals({ + {4, 6, 8, 10}, + {4, -39, 62, 73}, + {0, 0, -146, 166}, + {4, 6, 8, 320}, + }); + + float expected = -15.9131355f; + + xla::XlaOp a; + auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); + xla::LogDet(a); + + ComputeAndCompareR0(&builder, expected, {a_data.get()}, + xla::ErrorSpec(1e-4)); +} + XLA_TEST_F(LogDetTest, SimpleBatched) { xla::XlaBuilder builder(TestName()); diff --git a/tensorflow/compiler/xla/client/lib/lu_decomposition.cc b/tensorflow/compiler/xla/client/lib/lu_decomposition.cc new file mode 100644 index 00000000000..2920b6f56b5 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/lu_decomposition.cc @@ -0,0 +1,57 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/lu_decomposition.h" + +#include + +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +LuDecompositionResult LuDecomposition(XlaOp a) { + XlaBuilder* builder = a.builder(); + XlaOp result = builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + const int ndims = a_shape.rank(); + TF_RET_CHECK(ndims >= 2); + const int64 m = ShapeUtil::GetDimension(a_shape, -2); + const int64 n = ShapeUtil::GetDimension(a_shape, -1); + const int num_batch_dims = a_shape.dimensions().size() - 2; + const std::vector batch_dims( + a_shape.dimensions().begin(), + a_shape.dimensions().begin() + num_batch_dims); + + std::vector pivot_dims = batch_dims; + pivot_dims.push_back(std::min(m, n)); + std::vector perm_dims = batch_dims; + perm_dims.push_back(m); + Shape lu_shape = ShapeUtil::MakeTupleShape( + {a_shape, ShapeUtil::MakeShape(S32, pivot_dims), + ShapeUtil::MakeShape(S32, perm_dims)}); + // The TPU compiler has a rewrite pass that lowers an LuDecomposition + // CustomCall. 
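// Editor's note (illustrative, not part of the patch): the identity behind the
// QR-based LogDet above. With A = Q R from a Householder QR,
//   det(A) = det(Q) * det(R),  where det(R) = prod(diag(R)) since R is
//   triangular, and det(Q) = (-1)^h with h the number of non-trivial
//   Householder reflections.
// Hence log|det(A)| = sum(log|diag(R)|) and
//   sign(det(A)) = (-1)^h * prod(sign(diag(R))).
// As I read the code, num_triangle_rows counts rows that are already in
// triangular form (which need no reflection), and the Select over
// Rem(num_rows - num_triangle_rows, 2) supplies the (-1)^h factor, so the
// returned value combines the determinant's sign with log|det(A)|.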
+ // TODO(phawkins): upgrade LU decomposition to a first-class HLO operator + // and implement it on other backends. + return CustomCall(a.builder(), "LuDecomposition", {a}, lu_shape); + }); + return LuDecompositionResult{GetTupleElement(result, 0), + GetTupleElement(result, 1), + GetTupleElement(result, 2)}; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/lu_decomposition.h b/tensorflow/compiler/xla/client/lib/lu_decomposition.h new file mode 100644 index 00000000000..3f5703510a3 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/lu_decomposition.h @@ -0,0 +1,61 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LU_DECOMPOSITION_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LU_DECOMPOSITION_H_ + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// Computes the LU decomposition with partial pivoting of a batch of matrices. +// +// Given a (batched) matrix a with shape [..., m, n], computes the matrix +// decomposition A = P @ L @ U where P is a permutation matrix, L is a +// lower-triangular matrix with unit diagonal entries, and U is an +// upper-triangular matrix. +// +// L and U are returned as a single matrix [..., m, n] containing both L and U +// packed in the same array. The unit diagonal of L is not represented +// explicitly. +// +// The permutation matrix P is returned in two forms, both as `pivots`, which is +// an s32[..., min(m, n)] array that describes a sequence of row-swaps in the +// style of LAPACK's xGETRF API, and `permutation`, which is a s32[..., m] array +// which gives the permutation to apply to the rows. We return both +// representations because they are each useful for different purposes; `pivots` +// is useful for computing the sign of a determinant, whereas `permutation` can +// be used via a Gather operation to permute the rows of a matrix. +// +// This method is only implemented on TPU at the moment. +// TODO(b/168208200): the implementation only supports F32 arrays. Handle the +// complex case. +struct LuDecompositionResult { + // The LU decomposition, with both L and U packed into an array with shape + // [..., m, n]. + XlaOp lu; + // An array of shape s32[..., min(m, n)] containing the pivot rows. + XlaOp pivots; + // An array of shape s32[..., m], containing an another representation of the + // pivots as a permutation. 
+ XlaOp permutation; +}; + +LuDecompositionResult LuDecomposition(XlaOp a); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LU_DECOMPOSITION_H_ diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index cd9f88a74ce..76cc6f0159b 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -203,7 +203,7 @@ static XlaOp ErfcImpl32(XlaOp x) { // Precondition: abs(x) <= 1. Otherwise, use ErfcImpl. // // This follows Cephes's f32 implementation of erf. -static XlaOp ErfImpl32(XlaOp x) { +static XlaOp ErfImpl32Cephes(XlaOp x) { // Coefficients for by erf(f32), from Cephes. // // erf(x) = x P(x^2), 0 < x < 1 @@ -291,11 +291,31 @@ XlaOp Erfc(XlaOp x) { // (not surprising!), so upcast to f32 in this case. return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) { return Select(Gt(Abs(x), ScalarLike(x, 1)), ErfcImpl32(x), - ScalarLike(x, 1) - ErfImpl32(x)); + ScalarLike(x, 1) - ErfImpl32Cephes(x)); }); }); } +// Compute a polynomial approximation of the error function. +// This is the same approximation used by Eigen. +static XlaOp ErfImpl32(XlaOp x) { + static const std::array kAlpha{ + -2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f, + -5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f, + -1.60960333262415e-02f, + }; + + static const std::array kBeta{ + -1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f, + -7.37332916720468e-03f, -1.42647390514189e-02f, + }; + + x = Clamp(ScalarLike(x, -4.f), x, ScalarLike(x, 4.f)); + auto x2 = x * x; + return x * EvaluatePolynomial(x2, kAlpha) / + EvaluatePolynomial(x2, kBeta); +} + XlaOp Erf(XlaOp x) { auto& b = *x.builder(); return b.ReportErrorOrReturn([&]() -> StatusOr { @@ -310,10 +330,8 @@ XlaOp Erf(XlaOp x) { } // Erf(c)Impl don't have enough precision when run with bf16 intermediates // (not surprising!), so upcast to f32 in this case. - return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) { - return Select(Lt(Abs(x), ScalarLike(x, 1)), ErfImpl32(x), - ScalarLike(x, 1) - ErfcImpl32(x)); - }); + return DoWithUpcastToF32(x, {BF16, F16}, + [](XlaOp x) { return ErfImpl32(x); }); }); } @@ -1832,4 +1850,139 @@ XlaOp RegularizedIncompleteBeta(XlaOp a, XlaOp b, XlaOp x) { }); } +XlaOp Polygamma(XlaOp n, XlaOp x) { + auto& builder = *x.builder(); + auto doit = [](XlaOp n, XlaOp x, PrimitiveType type) -> XlaOp { + XlaOp n_plus_one = n + ScalarLike(n, 1.); + XlaOp sign = + (ScalarLike(n, 2.) * Rem(n, ScalarLike(n, 2.)) - ScalarLike(n, 1.)); + + const double nan = std::numeric_limits::quiet_NaN(); + + XlaOp output = Select(Eq(n, ScalarLike(n, 0.)), Digamma(x), + sign * Exp(Lgamma(n_plus_one)) * Zeta(n_plus_one, x)); + // Check that n is a natural number. 
+ output = Select(Or(Ne(n, Floor(n)), Lt(n, ScalarLike(n, 0.))), + ScalarLike(n, nan), output); + return output; + }; + return builder.ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto n_shape, builder.GetShape(n)); + TF_ASSIGN_OR_RETURN(auto x_shape, builder.GetShape(x)); + if (n_shape != x_shape) { + return InvalidArgument( + "Arguments to Polygamma must have equal shapes and types; " + "got %s and %s", + n_shape.ToString(), x_shape.ToString()); + } + TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Zeta", x)); + bool needs_upcast = + n_shape.element_type() == F16 || x_shape.element_type() == BF16; + + if (needs_upcast) { + n = ConvertElementType(n, F32); + x = ConvertElementType(x, F32); + } + XlaOp result = doit(n, x, n_shape.element_type()); + if (needs_upcast) { + result = ConvertElementType(result, n_shape.element_type()); + } + return result; + }); +} + +XlaOp Zeta(XlaOp x, XlaOp q) { + auto& builder = *x.builder(); + auto doit = [&builder](XlaOp x, XlaOp q, PrimitiveType type) -> XlaOp { + // (2k) ! / B_{2k}, where B_{2k} are the Bernoulli numbers. + // These are ordered in reverse. + static const std::array kZetaCoeffs{ + -7.1661652561756670113e18, + 1.8152105401943546773e17, + -4.5979787224074726105e15, + 1.1646782814350067249e14, + -2.950130727918164224e12, + 7.47242496e10, + -1.8924375803183791606e9, + 47900160.0, + -1209600.0, + 30240.0, + -720.0, + 12.0, + }; + + // For speed we'll always use 9 iterations for the initial series estimate, + // and a 12 term expansion for the Euler-Maclaurin formula. + + XlaOp a = q; + XlaOp neg_power = ScalarLike(a, 0.); + XlaOp initial_sum = Pow(q, Neg(x)); + for (int i = 0; i < 9; ++i) { + a = a + ScalarLike(a, 1.); + neg_power = Pow(a, Neg(x)); + initial_sum = initial_sum + neg_power; + } + a = a + ScalarLike(a, 1.); + neg_power = Pow(a, Neg(x)); + XlaOp s = initial_sum + neg_power * a / (x - ScalarLike(a, 1.)); + XlaOp a_inverse_square = Reciprocal(Square(a)); + XlaOp horner_sum = ScalarLike(a, 0.); + XlaOp factor = ScalarLike(a, 1.); + // Use Horner's rule for this. + // Note this differs from Cephes which does a 'naive' polynomial evaluation. + // Using Horner's rule allows to avoid some NaN's and Infs from happening, + // resulting in more numerically stable code. + for (int i = 0; i < 11; ++i) { + factor = + (x - ScalarLike(x, 22 - 2 * i)) * (x - ScalarLike(x, 21 - 2 * i)); + horner_sum = factor * a_inverse_square * + (horner_sum + ScalarLike(a, 1. / kZetaCoeffs[i])); + } + s = s + neg_power * + (ScalarLike(neg_power, 0.5) + + x / a * (ScalarLike(a, 1. / kZetaCoeffs[11]) + horner_sum)); + + const double nan = std::numeric_limits::quiet_NaN(); + const double inf = std::numeric_limits::infinity(); + // Use the initial zeta sum without the correction term coming + // from Euler-Maclaurin if it is accurate enough. + XlaOp output = + Select(Lt(Abs(neg_power), Abs(initial_sum) * Epsilon(&builder, type)), + initial_sum, s); + // This is the harmonic series. + output = Select(Eq(x, ScalarLike(x, 1.)), ScalarLike(x, inf), output); + // Function is not defined for x < 1. + output = Select(Lt(x, ScalarLike(x, 1.)), ScalarLike(x, nan), output); + // If q <= 0, then when q is an integer or x is not an integer, this is + // NaN. 
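// Editor's note (illustrative, not part of the patch): the identities the new
// helpers above rely on, as I read the code. Polygamma uses
//   psi^(n)(x) = (-1)^(n+1) * n! * zeta(n + 1, x)   for integer n >= 1,
// dispatching n == 0 to Digamma and mapping non-natural n to NaN. Zeta(x, q)
// is the Hurwitz zeta function
//   zeta(x, q) = sum_{k >= 0} 1 / (q + k)^x,
// evaluated as a truncated direct sum plus an Euler-Maclaurin tail correction
// whose coefficients are the (2k)!/B_{2k} values in kZetaCoeffs.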
+ XlaOp domain_error = And(Le(q, ScalarLike(x, 0.)), Ne(x, Floor(x))); + XlaOp negative_integer_q = And(Le(q, ScalarLike(x, 0.)), Eq(q, Floor(q))); + output = Select(negative_integer_q, ScalarLike(x, inf), output); + output = Select(domain_error, ScalarLike(x, nan), output); + return output; + }; + return builder.ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto x_shape, builder.GetShape(x)); + TF_ASSIGN_OR_RETURN(auto q_shape, builder.GetShape(q)); + if (x_shape != q_shape) { + return InvalidArgument( + "Arguments to Zeta must have equal shapes and types; got %s and %s", + x_shape.ToString(), q_shape.ToString()); + } + TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Zeta", x)); + bool needs_upcast = + x_shape.element_type() == F16 || x_shape.element_type() == BF16; + + if (needs_upcast) { + x = ConvertElementType(x, F32); + q = ConvertElementType(q, F32); + } + XlaOp result = doit(x, q, x_shape.element_type()); + if (needs_upcast) { + result = ConvertElementType(result, x_shape.element_type()); + } + return result; + }); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h index f03348c0a57..e6b5ac992cc 100644 --- a/tensorflow/compiler/xla/client/lib/math.h +++ b/tensorflow/compiler/xla/client/lib/math.h @@ -72,6 +72,12 @@ XlaOp RandomGammaGrad(XlaOp a, XlaOp x); // Computes an approximation of the complementary incomplete gamma function. XlaOp Igammac(XlaOp a, XlaOp x); +// Computes the Polygamma of two arguments. +XlaOp Polygamma(XlaOp n, XlaOp x); + +// Computes the Riemann zeta function of two arguments. +XlaOp Zeta(XlaOp x, XlaOp q); + // Rounds the given number to even when the number is equidistant between two // integers. XlaOp RoundToEven(XlaOp x); diff --git a/tensorflow/compiler/xla/client/lib/matrix.cc b/tensorflow/compiler/xla/client/lib/matrix.cc index ec1cc7e0487..dbb73602801 100644 --- a/tensorflow/compiler/xla/client/lib/matrix.cc +++ b/tensorflow/compiler/xla/client/lib/matrix.cc @@ -395,7 +395,6 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span x_config, xla::XlaOp y, } DotDimensionNumbers dnums; - std::vector lhs_outer_dims; auto is_batch_dim = [&](int64 d) { return x_map.contains(d) && y_map.contains(d) && output_map.contains(d); }; @@ -408,11 +407,13 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span x_config, xla::XlaOp y, }; absl::InlinedVector rhs_outer_dims; + absl::InlinedVector lhs_outer_dims; absl::InlinedVector rhs_delete_dims; absl::InlinedVector lhs_delete_dims; for (int64 i = 0; i < x_rank; ++i) { auto dim_name = x_config[i]; const int64 rhs_dim = rhs_dimension_number(dim_name); + if (is_batch_dim(dim_name)) { if (x_shape.dimensions(i) == y_shape.dimensions(rhs_dim)) { dnums.add_lhs_batch_dimensions(i); @@ -448,30 +449,34 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span x_config, xla::XlaOp y, } absl::c_sort(rhs_outer_dims); - absl::InlinedVector output_transpose_dims; - absl::InlinedVector output_reduce_dims; - auto output_dimension_number = [&](int64 d) { + + auto output_dimension_number = [&](int64 d) -> absl::optional { auto pos = absl::c_find(output_config, d); if (pos == output_config.end()) { - const int64 dim = - output_transpose_dims.size() + output_reduce_dims.size(); - output_reduce_dims.push_back(dim); - } else { - output_transpose_dims.push_back(pos - output_config.begin()); + return absl::nullopt; } + return pos - output_config.begin(); }; for (auto d : dnums.lhs_batch_dimensions()) { - output_dimension_number(x_config[d]); + 
output_transpose_dims.push_back(*output_dimension_number(x_config[d])); } for (auto d : lhs_outer_dims) { - output_dimension_number(x_config[d]); + if (auto output_dim = output_dimension_number(x_config[d])) { + output_transpose_dims.push_back(*output_dim); + continue; + } + lhs_delete_dims.push_back(d); } for (auto d : rhs_outer_dims) { - output_dimension_number(y_config[d]); + if (auto output_dim = output_dimension_number(y_config[d])) { + output_transpose_dims.push_back(*output_dim); + continue; + } + rhs_delete_dims.push_back(d); } const int64 transpose_rank = output_transpose_dims.size(); @@ -482,29 +487,31 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span x_config, xla::XlaOp y, // Remove ones that where broadcasted from the x and the y shape and adjust // the dimension numbers that are more minor than those dimensions. + absl::c_sort(lhs_delete_dims); DeleteDimsFromContainer(lhs_delete_dims, &x_shape, dnums.mutable_lhs_batch_dimensions(), dnums.mutable_lhs_contracting_dimensions()); + + absl::c_sort(rhs_delete_dims); DeleteDimsFromContainer(rhs_delete_dims, &y_shape, dnums.mutable_rhs_batch_dimensions(), dnums.mutable_rhs_contracting_dimensions()); if (!lhs_delete_dims.empty()) { - x = Reshape(x, x_shape.dimensions()); + x = Reduce(x, ScalarLike(x, 0), + CreateScalarAddComputation(x_shape.element_type(), builder), + lhs_delete_dims); } if (!rhs_delete_dims.empty()) { - y = Reshape(y, y_shape.dimensions()); + y = Reduce(y, ScalarLike(y, 0), + CreateScalarAddComputation(y_shape.element_type(), builder), + rhs_delete_dims); } PrecisionConfig precision_proto; precision_proto.add_operand_precision(precision); precision_proto.add_operand_precision(precision); auto dot = DotGeneral(x, y, dnums, &precision_proto); - if (!output_reduce_dims.empty()) { - dot = Reduce(dot, ScalarLike(dot, 0), - CreateScalarAddComputation(x_shape.element_type(), builder), - output_reduce_dims); - } dot = Transpose(dot, transpose_dims); if (transpose_rank == output_rank) { return dot; diff --git a/tensorflow/compiler/xla/client/lib/prng.cc b/tensorflow/compiler/xla/client/lib/prng.cc index cc5639f1be1..60086773d18 100644 --- a/tensorflow/compiler/xla/client/lib/prng.cc +++ b/tensorflow/compiler/xla/client/lib/prng.cc @@ -487,6 +487,10 @@ std::pair BoxMullerTransform(XlaOp x0, XlaOp x1) { } // namespace +XlaOp PhiloxIncreaseCounter(XlaOp counter, XlaOp delta) { + return Uint128ToOp(Uint128AddUint64(Uint128FromOp(counter), delta)); +} + RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state, const Shape& shape) { PrimitiveType type = shape.element_type(); diff --git a/tensorflow/compiler/xla/client/lib/prng.h b/tensorflow/compiler/xla/client/lib/prng.h index 107fd884de3..20ad223403d 100644 --- a/tensorflow/compiler/xla/client/lib/prng.h +++ b/tensorflow/compiler/xla/client/lib/prng.h @@ -89,6 +89,9 @@ RngOutput NormalFloatingPointDistribution(XlaOp key, XlaOp initial_state, xla::XlaOp ConcatScalars(xla::XlaBuilder* builder, absl::Span scalars); +// Increases Philox counter (an uint128) by a delta (an uint64). 
+xla::XlaOp PhiloxIncreaseCounter(xla::XlaOp counter, xla::XlaOp delta); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_PRNG_H_ diff --git a/tensorflow/compiler/xla/client/lib/qr.cc b/tensorflow/compiler/xla/client/lib/qr.cc index b2eecbac309..88a17f94a24 100644 --- a/tensorflow/compiler/xla/client/lib/qr.cc +++ b/tensorflow/compiler/xla/client/lib/qr.cc @@ -35,301 +35,7 @@ namespace xla { namespace { -std::vector ConcatVectors(absl::Span xs, - absl::Span ys) { - std::vector output(xs.size() + ys.size()); - std::copy(xs.begin(), xs.end(), output.begin()); - std::copy(ys.begin(), ys.end(), output.begin() + xs.size()); - return output; -} -// Computes a Householder reflection of the form: -// H = I - tau v v.T. -// such that -// H . ( x1 ) = ( x1 ) -// ( x2 ) = ( x2 ) -// ( ... ) = ( ... ) -// ( xk ) = ( beta ) -// ( ... ) ( 0 ) -// ( ... ) ( 0 ) -// Unlike the usual formulation, we allow the caller to supply 'k' rather than -// only providing the relevant part of 'x' to maintain XLA's static shape -// invariant. In addition, the implementation supports batching. -// Pseudo-code, without batching: -// alpha = x[k] -// x_copy = np.copy(x) -// x_copy[:k+1] = 0 -// xnorm = norm2(x_copy) -// if xnorm == 0: -// beta = alpha -// tau = 0 -// v = np.zeros_like(x) -// else: -// beta = - np.sign(alpha) * dlapy2(alpha, xnorm) -// tau = (beta - alpha) / beta -// v = x / (alpha - beta) -// v[k] = 1 -// return (v, tau, beta) -// TODO(phawkins): LAPACK's xLARFG implementation has code for handling -// overflows in the norm/beta calculations. Perhaps do the same here. -Status House(XlaOp x, XlaOp k, absl::Span batch_dims, - const int64 m, XlaOp* v, XlaOp* tau, XlaOp* beta) { - XlaBuilder* const builder = x.builder(); - TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x)); - const PrimitiveType type = x_shape.element_type(); - - std::vector batch_dim_ids(batch_dims.size()); - std::iota(batch_dim_ids.begin(), batch_dim_ids.end(), 0); - const int64 minor_dim = batch_dims.size(); - - XlaOp zero = ScalarLike(x, 0.0); - XlaOp one = ScalarLike(x, 1.0); - - // alpha = x[k] - XlaOp alpha = Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims); - - // Compute x[k+1:] (padded with zeros in elements 0..k) - XlaOp iota = Iota(builder, S32, m); - XlaOp x_after_k = Mul(x, ConvertElementType(Gt(iota, k), type), - /*broadcast_dimensions=*/{minor_dim}); - - // sigma = np.dot(x[k+1:], x[k+1:]) - auto sigma = Reduce(x_after_k * x_after_k, zero, - CreateScalarAddComputation(type, builder), {minor_dim}); - // mu = np.sqrt(x[k]*x[k] + sigma) - auto mu = Sqrt(Square(alpha) + sigma); - - auto sigma_is_zero = Eq(sigma, zero); - - *beta = Select(sigma_is_zero, alpha, Select(Lt(alpha, zero), one, -one) * mu); - *tau = Select(sigma_is_zero, Broadcast(zero, batch_dims), - (*beta - alpha) / *beta); - auto divisor = - Select(sigma_is_zero, Broadcast(one, batch_dims), alpha - *beta); - - auto e_k = Broadcast(ConvertElementType(Eq(iota, k), type), - std::vector(batch_dims.size(), 1)); - - // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor - // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor. - *v = e_k + Div(x_after_k, divisor, /*broadcast_dimensions=*/batch_dim_ids); - return Status::OK(); -} - -// Householder QR decomposition. Algorithm 5.2.1 from Golub and Van -// Loan "Matrix Computations", 4th Edition. This is an unblocked implementation -// used as an inner routine of the blocked implementation. 
-// Algorithm is adapted slightly so the shapes inside the loop are static, at -// the cost of some redundant computation. Since this is used as an inner block -// kernel, accumulates the Householder transformations (vs, taus) rather than -// the matrix q. -// Equivalent Python code, without batching: -// def qr(a): -// m = a.shape[0] -// n = a.shape[1] -// vs = np.zeros([m, n]) -// taus = np.zeros([n]) -// for j in xrange(min(m, n)): -// v, tau, beta = house(a[:, j], j) -// # Unusually, we apply the Householder transformation to the entirety of -// # a, wasting FLOPs to maintain the static shape invariant that XLA -// # requires. For columns that precede j this has no effect. -// a[:, :] -= tau * np.dot(v[:, np.newaxis], -// np.dot(v[np.newaxis, :], a[:, :])) -// # Form column j explicitly rather than relying on the precision of the -// # Householder update. -// a[j, j] = beta -// a[j+1:, j] = np.zeros([m - j - 1], dtype=a.dtype) -// vs[:, j] = v -// taus[j] = tau -// return (q, vs, taus) -struct QRBlockResult { - // The factored R value - XlaOp r; - - // Representation of the Householder matrices I - beta v v.T - XlaOp taus; // Shape: [..., n] - XlaOp vs; // Shape: [..., m, n] -}; -StatusOr QRBlock(XlaOp a, PrecisionConfig::Precision precision) { - XlaBuilder* builder = a.builder(); - TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); - const int num_dims = a_shape.rank(); - if (num_dims < 2) { - return InvalidArgument("Argument to QR must have rank >= 2; got shape %s", - a_shape.ToString()); - } - PrimitiveType type = a_shape.element_type(); - - const int64 m = ShapeUtil::GetDimension(a_shape, -2); - const int64 n = ShapeUtil::GetDimension(a_shape, -1); - - const int64 num_batch_dims = num_dims - 2; - std::vector batch_dims(num_batch_dims); - for (int i = 0; i < num_batch_dims; ++i) { - batch_dims[i] = ShapeUtil::GetDimension(a_shape, i); - } - - std::vector batch_dim_indices(num_batch_dims); - std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); - - auto qr_body_fn = [&](XlaOp j, absl::Span values, - XlaBuilder* builder) -> StatusOr> { - auto a = values[0]; - auto vs = values[1]; - auto taus = values[2]; - - // v, beta = house(a[:, j], j) - auto x = DynamicSliceInMinorDims(a, {j}, {1}); - XlaOp v, tau, beta; - TF_RETURN_IF_ERROR(House(Collapse(x, {num_dims - 2, num_dims - 1}), j, - batch_dims, m, &v, &tau, &beta)); - - std::vector shape = batch_dims; - shape.push_back(1); - shape.push_back(m); - auto v_broadcast = Reshape(v, shape); - // a[:, :] -= tau * np.dot(v[:, np.newaxis], - // np.dot(v[np.newaxis, :], a[:, :])) - auto vva = BatchDot(v_broadcast, a, precision); - vva = BatchDot(v_broadcast, true, vva, false, precision); - a = a - Mul(tau, vva, - /*broadcast_dimensions=*/batch_dim_indices); - - // It is more precise to populate column 'k' explicitly, rather than - // computing it implicitly by applying the Householder transformation. 
- // a[k,k] = beta - // a[k+1:,k] = np.zeros([m-k-1], dtype=a.dtype) - auto iota = Reshape(Iota(a.builder(), S32, m), {m, 1}); - auto predecessor_mask = ConvertElementType(Lt(iota, j), type); - auto mask = Broadcast(ConvertElementType(Eq(iota, j), type), - std::vector(batch_dims.size(), 1)); - auto new_x = Mul(x, predecessor_mask, - /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) + - Mul(beta, mask, /*broadcast_dimensions=*/batch_dim_indices); - // Update a[:,j] - std::vector dim_ids(num_dims); - std::iota(dim_ids.begin(), dim_ids.end(), 0); - new_x = BroadcastInDim(new_x, ConcatVectors(batch_dims, {m, n}), - /*broadcast_dimensions=*/dim_ids); - const int64 minor_dim = batch_dims.size(); - auto iota_mn = Iota( - builder, ShapeUtil::MakeShape(S32, ConcatVectors(batch_dims, {m, n})), - minor_dim + 1); - a = Select(Eq(iota_mn, j), new_x, a); - - // vs[:, j] = v - std::vector vs_broadcast_dims(batch_dims.size() + 1); - std::iota(vs_broadcast_dims.begin(), vs_broadcast_dims.end(), 0); - auto vs_zeros = ZerosLike(vs); - auto vs_update = Select( - Eq(iota_mn, j), - Add(vs_zeros, v, /*broadcast_dimensions=*/vs_broadcast_dims), vs_zeros); - vs = vs + vs_update; - - // taus[j] = tau - std::vector tau_broadcast_dims(batch_dims.size()); - std::iota(tau_broadcast_dims.begin(), tau_broadcast_dims.end(), 0); - - auto iota_n = - Iota(builder, ShapeUtil::MakeShape(S32, ConcatVectors(batch_dims, {n})), - minor_dim); - auto taus_zeros = ZerosLike(taus); - auto taus_update = Select( - Eq(iota_n, j), - Add(taus_zeros, tau, /*broadcast_dimensions=*/tau_broadcast_dims), - taus_zeros); - taus = taus + taus_update; - return std::vector{a, vs, taus}; - }; - - auto vs = Zeros( - builder, ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {m, n}))); - auto taus = Zeros(builder, - ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n}))); - - TF_ASSIGN_OR_RETURN(auto values, ForEachIndex(std::min(m, n), S32, qr_body_fn, - {a, vs, taus}, "qr", builder)); - - QRBlockResult result; - result.r = values[0]; - result.vs = values[1]; - result.taus = values[2]; - return result; -} - -// Computes W and Y such that I-WY is equivalent to the sequence of Householder -// transformations given by vs and taus. -// Golub and van Loan, "Matrix Computations", algorithm 5.1.2. -// Y = np.zeros([m, n]) -// W = np.zeros([m, n]) -// Y[:, 0] = vs[:, 0] -// W[:, 0] = -taus[0] * vs[:, 0] -// for j in xrange(1, n): -// v = vs[:, j] -// z = -taus[j] * v - taus[j] * np.dot(W, np.dot(Y.T, v)) -// W[:, j] = z -// Y[:, j] = v -// return W -// There is no need to return Y since at termination of the loop it is equal to -// vs. -StatusOr ComputeWYRepresentation(PrimitiveType type, - absl::Span batch_dims, - XlaOp vs, XlaOp taus, int64 m, int64 n, - PrecisionConfig::Precision precision) { - std::vector batch_dim_indices(batch_dims.size()); - std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); - int64 n_index = batch_dims.size() + 1; - - auto body_fn = [&](XlaOp j, absl::Span values, - XlaBuilder* builder) -> StatusOr> { - // w has shape [..., m, n] - auto w = values[0]; - const auto vs = values[1]; - const auto taus = values[2]; - - // Want j values in range [1, ... n). 
- j = j + ConstantR0(builder, 1); - // vs has shape [..., m, 1] - auto v = DynamicSliceInMinorDims(vs, {j}, {1}); - // beta has shape [..., 1] - auto beta = DynamicSliceInMinorDims(taus, {j}, {1}); - - auto iota_mn = Iota( - builder, ShapeUtil::MakeShape(S32, ConcatVectors(batch_dims, {m, n})), - n_index); - - // y has shape [..., m, n] - auto y = Select(Ge(iota_mn, j), ZerosLike(vs), vs); - - // yv has shape [..., n, 1] - auto yv = BatchDot(y, true, v, false, precision); - // wyv has shape [..., m, 1] - auto wyv = BatchDot(w, yv, precision); - - auto z = Mul( - -beta, v + wyv, - /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index})); - - w = DynamicUpdateSliceInMinorDims(w, z, {j}); - - return std::vector{w, vs, taus}; - }; - - XlaBuilder* builder = vs.builder(); - auto w = Zeros(builder, - ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {m, n}))); - auto v = SliceInMinorDims(vs, {0}, {1}); - auto beta = SliceInMinorDims(taus, {0}, {1}); - auto bv = - Mul(-beta, v, - /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index})); - w = UpdateSliceInMinorDims(w, bv, {0}); - - TF_ASSIGN_OR_RETURN(auto values, ForEachIndex(n - 1, S32, body_fn, - {w, vs, taus}, "wy", builder)); - return values[0]; -} } // namespace @@ -340,14 +46,12 @@ StatusOr ComputeWYRepresentation(PrimitiveType type, // q = np.eye(m) // for i in xrange(0, min(m, n), block_size): // k = min(block_size, min(m, n) - s) -// (a, vs, taus) = qr(a[i:, i:i+k]) -// y = vs -// w = ComputeWYRepresentation(vs, taus, m-i, k) -// a[i:, i+r:] += np.dot(y, np.dot(w.T, a[i:, i+k:])) -// q[:, i:] += np.dot(q[:, i:], np.dot(w, y.T)) +// (a, taus) = qr(a[i:, i:i+k]) +// y = np.eye(m, n) + np.tril(a, -1) +// t = CompactWYRepresentation(vs, taus, m-i, k) +// a[i:, i+k:] += (y @ t.T) @ (y.T @ a[i:, i+k:]) +// q[:, i:] += (q[:, i:] @ y) @ (y @ t.T).T // return (q, a) -// TODO(phawkins): consider using UT transformations (in the form I - V U V') -// rather than WY transformations. StatusOr QRDecomposition( XlaOp a, bool full_matrices, int64 block_size, PrecisionConfig::Precision precision) { @@ -358,8 +62,6 @@ StatusOr QRDecomposition( return InvalidArgument("Arguments to QR must have rank >= 2: got shape %s", a_shape.ToString()); } - PrimitiveType type = a_shape.element_type(); - const int64 m = ShapeUtil::GetDimension(a_shape, -2); const int64 n = ShapeUtil::GetDimension(a_shape, -1); const int64 p = std::min(m, n); @@ -369,53 +71,21 @@ StatusOr QRDecomposition( block_size); } - const int64 num_batch_dims = num_dims - 2; - std::vector batch_dims(num_batch_dims); - for (int i = 0; i < num_batch_dims; ++i) { - batch_dims[i] = ShapeUtil::GetDimension(a_shape, i); - } + Shape q_shape = a_shape; + q_shape.mutable_dimensions().back() = m; - auto q = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims); - for (int64 i = 0; i < p; i += block_size) { - int64 k = std::min(block_size, p - i); - - auto a_block = SliceInMinorDims(a, {i, i}, {m, i + k}); - TF_ASSIGN_OR_RETURN(auto qr_block, QRBlock(a_block, precision)); - - a = UpdateSliceInMinorDims(a, qr_block.r, {i, i}); - - // Compute the I-WY block representation of a product of Householder - // matrices. 
- TF_ASSIGN_OR_RETURN( - auto w, ComputeWYRepresentation(type, batch_dims, qr_block.vs, - qr_block.taus, m - i, k, precision)); - auto y = qr_block.vs; - - // a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, i+k:])) - auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n}); - auto a_update = BatchDot(w, true, a_panel, false, precision); - a_update = BatchDot(y, a_update, precision); - a_panel = a_panel + a_update; - a = UpdateSliceInMinorDims(a, a_panel, {i, i + k}); - - // q[:, i:] += np.dot(np.dot(q[:, i:], W), Y.T)) - auto q_panel = SliceInMinorDims(q, {0, i}, {m, m}); - auto q_update = BatchDot(q_panel, w, precision); - q_update = BatchDot(q_update, false, y, true, precision); - q_panel = q_panel + q_update; - q = UpdateSliceInMinorDims(q, q_panel, {0, i}); - } - QRDecompositionResult result; + Shape qr_shape = ShapeUtil::MakeTupleShape({q_shape, a_shape}); + auto qr = CustomCall(a.builder(), "QrDecomposition", {a}, qr_shape); + auto q = GetTupleElement(qr, 0); + auto r = GetTupleElement(qr, 1); // full_matrices is false when only a partial result in needed. Slice to the // needed dimensions here. if (!full_matrices) { q = SliceInMinorDims(q, {0, 0}, {m, p}); - a = SliceInMinorDims(a, {0, 0}, {p, n}); + r = SliceInMinorDims(r, {0, 0}, {p, n}); } - result.q = q; - result.r = a; - return result; + return QRDecompositionResult{q, r}; } } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/qr_test.cc b/tensorflow/compiler/xla/client/lib/qr_test.cc index a61f243e126..f1d2e4ddb1c 100644 --- a/tensorflow/compiler/xla/client/lib/qr_test.cc +++ b/tensorflow/compiler/xla/client/lib/qr_test.cc @@ -27,12 +27,15 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/tensor_float_32_utils.h" namespace { using QrTest = xla::ClientLibraryTestBase; XLA_TEST_F(QrTest, Simple) { + // Test fails with TensorFloat-32 enabled + tensorflow::enable_tensor_float_32_execution(false); xla::XlaBuilder builder(TestName()); xla::Array2D a_vals({ @@ -61,6 +64,8 @@ XLA_TEST_F(QrTest, Simple) { } XLA_TEST_F(QrTest, ZeroDiagonal) { + // Test fails with TensorFloat-32 enabled + tensorflow::enable_tensor_float_32_execution(false); xla::XlaBuilder builder(TestName()); xla::Array2D a_vals({ @@ -88,6 +93,8 @@ XLA_TEST_F(QrTest, ZeroDiagonal) { } XLA_TEST_F(QrTest, SimpleBatched) { + // Test fails with TensorFloat-32 enabled + tensorflow::enable_tensor_float_32_execution(false); xla::XlaBuilder builder(TestName()); xla::Array3D a_vals({ diff --git a/tensorflow/compiler/xla/client/lib/slicing_test.cc b/tensorflow/compiler/xla/client/lib/slicing_test.cc index 8e2e713c45c..10e27285f02 100644 --- a/tensorflow/compiler/xla/client/lib/slicing_test.cc +++ b/tensorflow/compiler/xla/client/lib/slicing_test.cc @@ -206,10 +206,12 @@ XLA_TEST_F(SlicingTest, DoubleEmptyIndexSelect) { xla::XlaOp input, index; Literal l(ShapeUtil::MakeShape(F32, {0, 1, 2, 0})); Literal i(ShapeUtil::MakeShape(S32, {0})); - auto input_data = - CreateParameterAndTransferLiteral(0, l, "input", &builder, &input); - auto index_data = - CreateParameterAndTransferLiteral(1, i, "index", &builder, &index); + TF_ASSERT_OK_AND_ASSIGN( + auto input_data, + CreateParameterAndTransferLiteral(0, l, "input", &builder, &input)); + TF_ASSERT_OK_AND_ASSIGN( + auto index_data, + CreateParameterAndTransferLiteral(1, i, "index", &builder, &index)); TorchIndexSelect(input, index, 0); 
ComputeAndCompareLiteral(&builder, l, {input_data.get(), index_data.get()}); } @@ -219,8 +221,9 @@ XLA_TEST_F(SlicingTest, EmptyIndexSelectNonZero) { xla::XlaOp input, index; Literal l(ShapeUtil::MakeShape(F32, {0, 2})); - auto input_data = - CreateParameterAndTransferLiteral(0, l, "input", &builder, &input); + TF_ASSERT_OK_AND_ASSIGN( + auto input_data, + CreateParameterAndTransferLiteral(0, l, "input", &builder, &input)); auto index_data = CreateR1Parameter({0, 0, 0}, 1, "index", &builder, &index); TorchIndexSelect(input, index, 0); diff --git a/tensorflow/compiler/xla/client/lib/sorting.cc b/tensorflow/compiler/xla/client/lib/sorting.cc index 750237c2000..abb0054558f 100644 --- a/tensorflow/compiler/xla/client/lib/sorting.cc +++ b/tensorflow/compiler/xla/client/lib/sorting.cc @@ -16,6 +16,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/sorting.h" #include "tensorflow/compiler/xla/client/lib/comparators.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" @@ -27,6 +30,20 @@ XlaOp TopK(XlaOp input, int64 k) { return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); int last_dim = input_shape.dimensions_size() - 1; + int64 last_dim_size = input_shape.dimensions(last_dim); + // TODO(b/148796364): tune these constants for better performance. + const int64 kPerPartitionSize = 8192; // 2^13 + const int64 kLastDimSizeThreshold = 524288; // 2^19 + const int64 kMinNumPartitions = 8; + const int64 kMinimalK = 1000; + if ((k >= kMinimalK) && (k < kPerPartitionSize) && + (kPerPartitionSize / k > 2) && last_dim_size >= kLastDimSizeThreshold) { + int64 num_partitions = + CeilOfRatio(last_dim_size - k, kPerPartitionSize - k); + if (num_partitions >= kMinNumPartitions) { + return TopKWithPartitions(input, k, num_partitions); + } + } Shape iota_shape = ShapeUtil::MakeShape(S32, AsInt64Slice(input_shape.dimensions())); @@ -80,30 +97,35 @@ XlaOp TopKWithPartitions(XlaOp input, int64 k, int64 num_partitions) { } } - XlaOp values, indices; - for (int64 partition = 0; partition < num_partitions; partition++) { - std::vector start_indices(input_shape.dimensions_size(), 0); - std::vector limit_indices(input_dims.begin(), input_dims.end()); - std::vector strides(input_shape.dimensions_size(), 1); - start_indices[last_dim] = partition * per_partition_size; - limit_indices[last_dim] = - std::min((partition + 1) * per_partition_size, last_dim_size); - // Slice value and indices for this partition.. - XlaOp sliced_input = Slice(input, start_indices, limit_indices, strides); + auto topk_body_fn = + [&](XlaOp partition, absl::Span values_and_indices, + XlaBuilder* builder) -> StatusOr> { + auto values = values_and_indices[0]; + auto indices = values_and_indices[1]; + auto input = values_and_indices[2]; + auto iota_s32 = values_and_indices[3]; + + // Slice value and indices for this partition. + XlaOp start = Mul(Add(partition, ConstantR0(builder, 1)), + ConstantR0(builder, per_partition_size)); + XlaOp sliced_input = + DynamicSliceInMinorDims(input, {start}, {per_partition_size}); XlaOp sliced_indices = - Slice(iota_s32, start_indices, limit_indices, strides); + DynamicSliceInMinorDims(iota_s32, {start}, {per_partition_size}); // Concat with previous results. 
- if (partition > 0) { - sliced_input = ConcatInDim(builder, {values, sliced_input}, last_dim); - sliced_indices = - ConcatInDim(builder, {indices, sliced_indices}, last_dim); - } + sliced_input = ConcatInDim(builder, {values, sliced_input}, last_dim); + sliced_indices = + ConcatInDim(builder, {indices, sliced_indices}, last_dim); // Sort this slice XlaOp sort_result = Sort({sliced_input, sliced_indices}, CreateScalarGtComputation({input_shape.element_type(), S32}, sliced_indices.builder()), - last_dim, /*is_stable=*/true); + last_dim, true); + + std::vector start_indices(input_shape.dimensions_size(), 0); + std::vector limit_indices(input_dims.begin(), input_dims.end()); + std::vector strides(input_shape.dimensions_size(), 1); // Slice topk. start_indices[last_dim] = 0; limit_indices[last_dim] = k; @@ -111,8 +133,42 @@ XlaOp TopKWithPartitions(XlaOp input, int64 k, int64 num_partitions) { limit_indices, strides); indices = Slice(GetTupleElement(sort_result, 1), start_indices, limit_indices, strides); - } - return Tuple(builder, {values, indices}); + return std::vector{values, indices, input, iota_s32}; + }; + + // Get the values and indices for the first topk so that they can + // be passed to the while loop. + std::vector start_indices(input_shape.dimensions_size(), 0); + std::vector limit_indices(input_dims.begin(), input_dims.end()); + std::vector strides(input_shape.dimensions_size(), 1); + start_indices[last_dim] = 0; + limit_indices[last_dim] = per_partition_size; + // Slice value and indices for the first partition. + XlaOp sliced_input = Slice(input, start_indices, limit_indices, strides); + XlaOp sliced_indices = + Slice(iota_s32, start_indices, limit_indices, strides); + // Sort this slice + XlaOp sort_result = + Sort({sliced_input, sliced_indices}, + CreateScalarGtComputation({input_shape.element_type(), S32}, + sliced_indices.builder()), + last_dim, /*is_stable=*/true); + + // Slice topk. + start_indices[last_dim] = 0; + limit_indices[last_dim] = k; + XlaOp values = Slice(GetTupleElement(sort_result, 0), start_indices, + limit_indices, strides); + XlaOp indices = Slice(GetTupleElement(sort_result, 1), start_indices, + limit_indices, strides); + + // Pass the result of the first TopK to the while loop and do + // num_partition - 1 iterations. 
+ TF_ASSIGN_OR_RETURN(auto values_and_indices, + ForEachIndex(num_partitions - 1, S32, topk_body_fn, + {values, indices, input, iota_s32}, + "topk_with_partition", builder)); + return Tuple(builder, {values_and_indices[0], values_and_indices[1]}); }); } diff --git a/tensorflow/compiler/xla/client/lib/sorting_test.cc b/tensorflow/compiler/xla/client/lib/sorting_test.cc index e01f6faf59e..e820d5bfe6f 100644 --- a/tensorflow/compiler/xla/client/lib/sorting_test.cc +++ b/tensorflow/compiler/xla/client/lib/sorting_test.cc @@ -118,6 +118,19 @@ XLA_TEST_F(SortingTest, TopK3From8Values5Partitions) { ComputeAndCompareR1(&builder, {7.0, 6.0, 5.0}, {}); } +XLA_TEST_F(SortingTest, DISABLED_TopKLargeInput) { + XlaBuilder builder(TestName()); + Array input({2, 1000000}); + input.FillRandom(1.0f, 2.0f); + auto x = + CreateConstantFromLiteral(LiteralUtil::CreateFromArray(input), &builder); + Array2D expected_array(2, 1000); + expected_array.Fill(2.0f); + xla::GetTupleElement(xla::TopK(x, 1000), 0); + ErrorSpec error_spec(10.0f, 10.0f); + ComputeAndCompareR2(&builder, expected_array, {}, error_spec); +} + XLA_TEST_F(SortingTest, TopK3From8Indices5Partitions) { XlaBuilder builder(TestName()); auto x_rev = diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 1389f548c5d..82a6128025f 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -267,9 +267,9 @@ StatusOr LocalExecutable::RunAsync( } static ShapedBuffer MaybeOwningShapeTreeToShapedBuffer( - Shape const& on_host_shape, const ShapeTree& tree, - se::Platform* platform, int device_ordinal) { - ShapedBuffer result(on_host_shape, tree.shape(), platform, device_ordinal); + const ShapeTree& tree, se::Platform* platform, + int device_ordinal) { + ShapedBuffer result(tree.shape(), platform, device_ordinal); auto it = tree.begin(); auto out_it = result.buffers().begin(); for (; it != tree.end(); ++it, ++out_it) { @@ -299,8 +299,8 @@ StatusOr LocalExecutable::RunAsync( shaped_buffer_ptrs.reserve(arguments.size()); for (size_t i = 0; i < arguments.size(); ++i) { shaped_buffers.push_back(MaybeOwningShapeTreeToShapedBuffer( - *argument_host_shapes[i], arguments[i].Buffers(), - backend_->platform(), stream->parent()->device_ordinal())); + arguments[i].Buffers(), backend_->platform(), + stream->parent()->device_ordinal())); shaped_buffer_ptrs.push_back(&shaped_buffers.back()); } diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 34d78f9d933..41212e69b2e 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/match.h" +#include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" @@ -42,6 +43,7 @@ limitations under the License. 
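
// Illustrative sketch, not part of the patch: the state-threading shape of the
// ForEachIndex call above, as a toy eager combinator. The name ForEachIndexToy is
// invented; the real ForEachIndex in client/lib/loops.h builds an HLO While loop
// over XlaOps instead of running on the host. It shows why the TopK rewrite passes
// {values, indices, input, iota} in and gets the updated tuple back each iteration.
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

using State = std::vector<int64_t>;

State ForEachIndexToy(int64_t num_iterations,
                      const std::function<State(int64_t, const State&)>& body,
                      State initial_state) {
  State state = std::move(initial_state);
  for (int64_t i = 0; i < num_iterations; ++i) state = body(i, state);
  return state;
}

int main() {
  // Carry {running_sum, last_index} through the loop, the way the TopK body
  // carries {values, indices, input, iota_s32}.
  auto body = [](int64_t i, const State& s) -> State { return {s[0] + i, i}; };
  State result = ForEachIndexToy(/*num_iterations=*/5, body, /*initial_state=*/{0, -1});
  std::cout << result[0] << " " << result[1] << "\n";  // prints "10 4"
}
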
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" namespace xla { @@ -117,14 +119,23 @@ HloComputationProto CreateReduceOr(int64 reducer_id, } return reducer; } + +bool InstrIsSetBound(const HloInstructionProto* instr_proto) { + HloOpcode opcode = StringToHloOpcode(instr_proto->opcode()).ValueOrDie(); + if (opcode == HloOpcode::kCustomCall && + instr_proto->custom_call_target() == "SetBound") { + return true; + } + return false; +} } // namespace namespace internal { -XlaOp XlaBuilderBuildFusion(XlaBuilder* builder, - absl::Span operands, - absl::string_view fusion_kind, - const XlaComputation& fused_computation) { +XlaOp XlaBuilderFriend::BuildFusion(XlaBuilder* builder, + absl::Span operands, + absl::string_view fusion_kind, + const XlaComputation& fused_computation) { return builder->ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; instr.set_fusion_kind(std::string(fusion_kind)); @@ -138,6 +149,21 @@ XlaOp XlaBuilderBuildFusion(XlaBuilder* builder, }); } +XlaOp XlaBuilderFriend::BuildBitcast(XlaBuilder* builder, XlaOp operand, + const Shape& shape) { + return builder->ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + return builder->AddInstruction(std::move(instr), HloOpcode::kBitcast, + {operand}); + }); +} + +HloInstructionProto* XlaBuilderFriend::GetInstruction(XlaOp op) { + return &op.builder() + ->instructions_[op.builder()->handle_to_index_[op.handle_]]; +} + } // namespace internal XlaOp operator-(XlaOp x) { return Neg(x); } @@ -293,7 +319,6 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle, // GetDimensionSize is always considered constant in XLA -- If a dynamic // dimension is presented, -1 is returned. break; - // Non functional ops. case HloOpcode::kRng: case HloOpcode::kAllReduce: @@ -306,6 +331,11 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle, // cannot be constant. We cannot set is_functional=false in other similar // cases since we're already relying on IsConstant to return true. case HloOpcode::kCustomCall: + if (instr.custom_call_target() == "SetBound") { + // Set bound is considered constant -- the bound is used as the value. + break; + } + TF_FALLTHROUGH_INTENDED; case HloOpcode::kWhile: // TODO(b/32495713): We aren't checking the condition and body // computations themselves. 
@@ -661,8 +691,10 @@ XlaOp XlaBuilder::BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape, StatusOr XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, ComparisonDirection direction) { - return Compare(shape, lhs, rhs, direction, - Comparison::DefaultComparisonType(shape.element_type())); + TF_ASSIGN_OR_RETURN(auto operand_shape, GetShape(lhs)); + return Compare( + shape, lhs, rhs, direction, + Comparison::DefaultComparisonType(operand_shape.element_type())); } StatusOr XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, @@ -1692,7 +1724,9 @@ XlaOp XlaBuilder::CustomCall( const string& call_target_name, absl::Span operands, const Shape& shape, const string& opaque, absl::optional> operand_shapes_with_layout, - bool has_side_effect) { + bool has_side_effect, + absl::Span>> + output_operand_aliasing) { return ReportErrorOrReturn([&]() -> StatusOr { if (absl::StartsWith(call_target_name, "$")) { return InvalidArgument( @@ -1724,7 +1758,8 @@ XlaOp XlaBuilder::CustomCall( } } return CustomCallInternal(call_target_name, operands, shape, opaque, - operand_shapes_with_layout, has_side_effect); + operand_shapes_with_layout, has_side_effect, + output_operand_aliasing); }); } @@ -1732,7 +1767,9 @@ StatusOr XlaBuilder::CustomCallInternal( const string& call_target_name, absl::Span operands, const Shape& shape, const string& opaque, absl::optional> operand_shapes_with_layout, - bool has_side_effect) { + bool has_side_effect, + absl::Span>> + output_operand_aliasing) { HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); instr.set_custom_call_target(call_target_name); @@ -1744,6 +1781,16 @@ StatusOr XlaBuilder::CustomCallInternal( } } instr.set_custom_call_has_side_effect(has_side_effect); + for (const auto& pair : output_operand_aliasing) { + auto aliasing = instr.add_custom_call_output_operand_aliasing(); + aliasing->set_operand_index(pair.second.first); + for (int64 index : pair.second.second) { + aliasing->add_operand_shape_index(index); + } + for (int64 index : pair.first) { + aliasing->add_output_shape_index(index); + } + } return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands); } @@ -1751,7 +1798,9 @@ XlaOp XlaBuilder::CustomCall( const string& call_target_name, absl::Span operands, const XlaComputation& computation, const Shape& shape, const string& opaque, absl::optional> operand_shapes_with_layout, - bool has_side_effect) { + bool has_side_effect, + absl::Span>> + output_operand_aliasing) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; if (absl::StartsWith(call_target_name, "$")) { @@ -1789,6 +1838,16 @@ XlaOp XlaBuilder::CustomCall( } } AddCalledComputation(computation, &instr); + for (const auto& pair : output_operand_aliasing) { + auto aliasing = instr.add_custom_call_output_operand_aliasing(); + aliasing->set_operand_index(pair.second.first); + for (int64 index : pair.second.second) { + aliasing->add_operand_shape_index(index); + } + for (int64 index : pair.first) { + aliasing->add_output_shape_index(index); + } + } return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands); }); } @@ -3086,6 +3145,15 @@ StatusOr XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) { case HloOpcode::kConstant: SetInstructionAsConstant(new_instr, id, new_shape, false); break; + case HloOpcode::kCustomCall: + if (instr_proto->custom_call_target() == "SetBound") { + SetInstructionAsConstant(new_instr, id, new_shape, true); + break; + } else { + return InvalidArgument( + "Dynamic inferencing on custom call %s 
is not supported", + instr_proto->DebugString()); + } case HloOpcode::kParameter: SetInstructionAsConstant(new_instr, id, new_shape, true); break; @@ -3149,7 +3217,8 @@ StatusOr XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) { TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(instr_proto->opcode())); if (next_operand >= instr_proto->operand_ids_size() || - opcode == HloOpcode::kGetDimensionSize) { + opcode == HloOpcode::kGetDimensionSize || + InstrIsSetBound(instr_proto)) { // No more operands to process, process self. int64 new_id = ++global_id; VLOG(3) << "new_id: " << new_id << "instr: " << instr_proto->name(); @@ -3235,26 +3304,33 @@ StatusOr XlaBuilder::BuildConstantSubGraph( LookUpInstructionByHandle(handle)); if (instr_proto->opcode() == - HloOpcodeString(HloOpcode::kGetDimensionSize)) { - // At this point, BuildConstantSubGraph should never encounter a - // GetDimensionSize with a dynamic dimension. IsConstant check would have - // failed at the beginning of this function. - // - // Replace GetDimensionSize with a Constant representing the static bound - // of the shape. - int64 dimension = instr_proto->dimensions(0); - int64 operand_handle = instr_proto->operand_ids(0); - TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto, - LookUpInstructionByHandle(operand_handle)); + HloOpcodeString(HloOpcode::kGetDimensionSize) || + InstrIsSetBound(instr_proto)) { + int32 constant_value = -1; + if (instr_proto->opcode() == + HloOpcodeString(HloOpcode::kGetDimensionSize)) { + // At this point, BuildConstantSubGraph should never encounter a + // GetDimensionSize with a dynamic dimension. IsConstant check would + // have failed at the beginning of this function. + // + // Replace GetDimensionSize with a Constant representing the static + // bound of the shape. 
+ int64 dimension = instr_proto->dimensions(0); + int64 operand_handle = instr_proto->operand_ids(0); + TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto, + LookUpInstructionByHandle(operand_handle)); - int32 constant_dimension_size = -1; - if (!(operand_proto->shape().is_dynamic_dimension(dimension) && - dynamic_dimension_is_minus_one)) { - constant_dimension_size = - static_cast(operand_proto->shape().dimensions(dimension)); + if (!(operand_proto->shape().is_dynamic_dimension(dimension) && + dynamic_dimension_is_minus_one)) { + constant_value = + static_cast(operand_proto->shape().dimensions(dimension)); + } + } else { + TF_RET_CHECK( + absl::SimpleAtoi(instr_proto->backend_config(), &constant_value)); } - Literal literal = LiteralUtil::CreateR0(constant_dimension_size); + Literal literal = LiteralUtil::CreateR0(constant_value); HloInstructionProto const_instr; *const_instr.mutable_shape() = literal.shape().ToProto(); @@ -3286,6 +3362,9 @@ StatusOr XlaBuilder::BuildConstantSubGraph( if (instr_src->opcode() == HloOpcodeString(HloOpcode::kGetDimensionSize)) { continue; } + if (InstrIsSetBound(instr_src)) { + continue; + } auto* instr = entry.add_instructions(); *instr = *instr_src; @@ -3826,31 +3905,39 @@ XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, return builder->Call(computation, operands); } -XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, - absl::Span operands, const Shape& shape, - const string& opaque, bool has_side_effect) { +XlaOp CustomCall( + XlaBuilder* builder, const string& call_target_name, + absl::Span operands, const Shape& shape, const string& opaque, + bool has_side_effect, + absl::Span>> + output_operand_aliasing) { return builder->CustomCall(call_target_name, operands, shape, opaque, /*operand_shapes_with_layout=*/absl::nullopt, - has_side_effect); + has_side_effect, output_operand_aliasing); } -XlaOp CustomCallWithComputation(XlaBuilder* builder, - const string& call_target_name, - absl::Span operands, - const XlaComputation& computation, - const Shape& shape, const string& opaque, - bool has_side_effect) { - return builder->CustomCall( - call_target_name, operands, computation, shape, opaque, - /*operand_shapes_with_layout=*/absl::nullopt, has_side_effect); +XlaOp CustomCallWithComputation( + XlaBuilder* builder, const string& call_target_name, + absl::Span operands, const XlaComputation& computation, + const Shape& shape, const string& opaque, bool has_side_effect, + absl::Span>> + output_operand_aliasing) { + return builder->CustomCall(call_target_name, operands, computation, shape, + opaque, + /*operand_shapes_with_layout=*/absl::nullopt, + has_side_effect, output_operand_aliasing); } -XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name, - absl::Span operands, const Shape& shape, - absl::Span operand_shapes_with_layout, - const string& opaque, bool has_side_effect) { +XlaOp CustomCallWithLayout( + XlaBuilder* builder, const string& call_target_name, + absl::Span operands, const Shape& shape, + absl::Span operand_shapes_with_layout, const string& opaque, + bool has_side_effect, + absl::Span>> + output_operand_aliasing) { return builder->CustomCall(call_target_name, operands, shape, opaque, - operand_shapes_with_layout, has_side_effect); + operand_shapes_with_layout, has_side_effect, + output_operand_aliasing); } XlaOp Complex(const XlaOp lhs, const XlaOp rhs, diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 
f841a1a75a0..f736ae1d470 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -47,13 +47,21 @@ namespace xla { class XlaBuilder; class XlaOp; +class HloInstruction; namespace internal { -XlaOp XlaBuilderBuildFusion(XlaBuilder* builder, - absl::Span operands, - absl::string_view fusion_kind, - const XlaComputation& fused_computation); +struct XlaBuilderFriend { + static XlaOp BuildFusion(XlaBuilder* builder, + absl::Span operands, + absl::string_view fusion_kind, + const XlaComputation& fused_computation); + + static XlaOp BuildBitcast(XlaBuilder* builder, XlaOp operand, + const Shape& shape); + + static HloInstructionProto* GetInstruction(XlaOp op); +}; } // namespace internal @@ -107,6 +115,7 @@ class XlaOp { friend class XlaBuilder; friend class MlirHloBuilder; + friend struct internal::XlaBuilderFriend; // < 0 means "invalid handle". int64 handle_; @@ -164,6 +173,15 @@ class XlaBuilder { // OpMetadata attached until a call to ClearOpMetadata. void SetOpMetadata(OpMetadata metadata) { metadata_ = std::move(metadata); } + // Swaps the passed op metadata with the ones currently set. + // + // Returns the old op metadata. + OpMetadata SwapOpMetadata(OpMetadata metadata) { + OpMetadata old_metadata = std::move(metadata_); + metadata_ = std::move(metadata); + return old_metadata; + } + // Similar to SetOpMetadata, but only set the metadata for the next op. void SetOneShotOpMetadata(OpMetadata metadata) { metadata_ = std::move(metadata); @@ -584,7 +602,9 @@ class XlaBuilder { const string& call_target_name, absl::Span operands, const Shape& shape_with_layout, const string& opaque, absl::optional> operand_shapes_with_layout, - bool has_side_effect); + bool has_side_effect, + absl::Span>> + output_operand_aliasing); // Internal version of CustomCall without computation that doesn't do op // specific error handling and expects arguments to be legal. 
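
// Illustrative sketch, not part of the patch: the shape of the new
// output_operand_aliasing parameter threaded through the CustomCall overloads. Each
// entry says "this output sub-buffer (a ShapeIndex path) aliases that sub-buffer of
// operand N", which is exactly how the builder loops above copy it into the
// instruction proto. AliasRecord is an invented stand-in for that proto message.
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

using ShapeIndex = std::vector<int64_t>;  // path into a (possibly nested) tuple shape
using Aliasing =
    std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>;  // {output, {operand#, operand index}}

struct AliasRecord {
  ShapeIndex output_shape_index;
  int64_t operand_number;
  ShapeIndex operand_shape_index;
};

std::vector<AliasRecord> Flatten(const std::vector<Aliasing>& output_operand_aliasing) {
  std::vector<AliasRecord> records;
  for (const Aliasing& pair : output_operand_aliasing) {
    records.push_back({pair.first, pair.second.first, pair.second.second});
  }
  return records;
}

int main() {
  // Output tuple element {0} reuses the (whole) buffer of operand 1.
  std::vector<Aliasing> aliasing = {{{0}, {1, {}}}};
  for (const AliasRecord& r : Flatten(aliasing)) {
    std::cout << "output path length " << r.output_shape_index.size()
              << " aliases operand " << r.operand_number << "\n";
  }
}
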
CustomCall @@ -593,14 +613,18 @@ class XlaBuilder { const string& call_target_name, absl::Span operands, const Shape& shape_with_layout, const string& opaque, absl::optional> operand_shapes_with_layout, - bool has_side_effect); + bool has_side_effect, + absl::Span>> + output_operand_aliasing); XlaOp CustomCall( const string& call_target_name, absl::Span operands, const XlaComputation& computation, const Shape& shape_with_layout, const string& opaque, absl::optional> operand_shapes_with_layout, - bool has_side_effect); + bool has_side_effect, + absl::Span>> + output_operand_aliasing); XlaOp Reduce(XlaOp operand, XlaOp init_value, const XlaComputation& computation, @@ -1049,18 +1073,25 @@ class XlaBuilder { const string& outfeed_config); friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, absl::Span operands); - friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, - absl::Span operands, const Shape& shape, - const string& opaque, bool has_side_effect); + friend XlaOp CustomCall( + XlaBuilder* builder, const string& call_target_name, + absl::Span operands, const Shape& shape, + const string& opaque, bool has_side_effect, + absl::Span>> + output_operand_aliasing); friend XlaOp CustomCallWithComputation( XlaBuilder* builder, const string& call_target_name, absl::Span operands, const XlaComputation& computation, - const Shape& shape, const string& opaque, bool has_side_effect); + const Shape& shape, const string& opaque, bool has_side_effect, + absl::Span>> + output_operand_aliasing); friend XlaOp CustomCallWithLayout( XlaBuilder* builder, const string& call_target_name, absl::Span operands, const Shape& shape_with_layout, absl::Span operand_shapes_with_layout, const string& opaque, - bool has_side_effect); + bool has_side_effect, + absl::Span>> + output_operand_aliasing); friend XlaOp Complex(XlaOp real, XlaOp imag, absl::Span broadcast_dimensions); friend XlaOp Conj(XlaOp operand); @@ -1284,9 +1315,7 @@ class XlaBuilder { return LookUpInstructionByHandleInternal(op.handle()); } - friend XlaOp internal::XlaBuilderBuildFusion( - XlaBuilder* builder, absl::Span operands, - absl::string_view fusion_kind, const XlaComputation& fused_computation); + friend struct internal::XlaBuilderFriend; }; // RAII-style object: sets the current sharding assignment in builder on @@ -1339,6 +1368,25 @@ class XlaScopedFrontendAttributesAssignment { TF_DISALLOW_COPY_AND_ASSIGN(XlaScopedFrontendAttributesAssignment); }; + +// RAII-style object: sets the current op metadata in builder on construction, +// and sets back to the previous assignment on destruction. +class XlaScopedOpMetadataAssignment { + public: + XlaScopedOpMetadataAssignment(xla::XlaBuilder* builder, OpMetadata metadata) + : builder_(builder) { + saved_ = builder_->SwapOpMetadata(metadata); + } + + ~XlaScopedOpMetadataAssignment() { builder_->SwapOpMetadata(saved_); } + + private: + xla::XlaBuilder* const builder_; + OpMetadata saved_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaScopedOpMetadataAssignment); +}; + // Free functions for building XlaOps. The intention is that these will // become the public API for building XlaOps rather than calling methods on // XlaBuilder directly. @@ -1777,30 +1825,39 @@ XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, // backend, a call instruction is emitted which targets a symbol with the name // |call_target_name|. |call_target_name| and |opaque| can arbitrary strings, // but |call_target_name| should be short as it may be used in labels. 
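
// Illustrative sketch, not part of the patch: the swap-on-construction /
// restore-on-destruction pattern behind SwapOpMetadata and
// XlaScopedOpMetadataAssignment above, written against a generic holder so it
// compiles stand-alone. ScopedSwap is an invented name.
#include <cassert>
#include <string>
#include <utility>

template <typename T>
class ScopedSwap {
 public:
  ScopedSwap(T* slot, T value)
      : slot_(slot), saved_(std::exchange(*slot, std::move(value))) {}
  ~ScopedSwap() { *slot_ = std::move(saved_); }

  ScopedSwap(const ScopedSwap&) = delete;
  ScopedSwap& operator=(const ScopedSwap&) = delete;

 private:
  T* slot_;
  T saved_;
};

int main() {
  std::string metadata = "outer";
  {
    ScopedSwap<std::string> scope(&metadata, "inner");
    assert(metadata == "inner");  // ops built in this scope would carry "inner"
  }
  assert(metadata == "outer");    // restored on scope exit, like the builder's metadata_
}
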
|opaque| -// can encode arbitrarily large amounts of information. -XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, - absl::Span operands, const Shape& shape, - const string& opaque = "", bool has_side_effect = false); +// can encode arbitrarily large amounts of information. |has_side_effect| +// specifies whether the instruction can have side effects. +// |output_operand_aliasing| specifies a list of output/operand buffer pairs +// that alias each other, where the output buffer is represented as a +// ShapeIndex, and the operand buffer is represented as the operand index and +// the ShapeIndex. +XlaOp CustomCall( + XlaBuilder* builder, const string& call_target_name, + absl::Span operands, const Shape& shape, + const string& opaque = "", bool has_side_effect = false, + absl::Span>> + output_operand_aliasing = {}); // Overload which constructs a custom call that applies an Xla computation. -XlaOp CustomCallWithComputation(XlaBuilder* builder, - const string& call_target_name, - absl::Span operands, - const XlaComputation& computation, - const Shape& shape, const string& opaque = "", - bool has_side_effect = false); +XlaOp CustomCallWithComputation( + XlaBuilder* builder, const string& call_target_name, + absl::Span operands, const XlaComputation& computation, + const Shape& shape, const string& opaque = "", bool has_side_effect = false, + absl::Span>> + output_operand_aliasing = {}); // Overload which constructs a custom call with fixed layouts. The operands will // have the layouts specified by |operand_shapes_with_layout| when provided to // external code, and the external code is expected to produce a result with the // layout specified by |shape_with_layout|. All shapes in |shape_with_layout| // and |operand_shapes_with_layout| must have layouts. -XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name, - absl::Span operands, - const Shape& shape_with_layout, - absl::Span operand_shapes_with_layout, - const string& opaque = "", - bool has_side_effect = false); +XlaOp CustomCallWithLayout( + XlaBuilder* builder, const string& call_target_name, + absl::Span operands, const Shape& shape_with_layout, + absl::Span operand_shapes_with_layout, + const string& opaque = "", bool has_side_effect = false, + absl::Span>> + output_operand_aliasing = {}); // The following methods enqueue element-wise binary arithmetic operations // onto the computation. The shapes of the operands have to match unless one diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index 7011c946203..bfd13c8ddf5 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -19,6 +19,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -1203,5 +1205,16 @@ TEST_F(XlaBuilderTest, AddFrontendAttribute) { TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); ExpectInstructionsAttributesMatch(*module, expected); } + +TEST_F(XlaBuilderTest, ComparisonType) { + XlaBuilder b(TestName()); + (void)Le(ConstantR0(&b, 1), ConstantR0(&b, 2)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + ASSERT_THAT(root, op::Compare(op::Constant(), op::Constant())); + EXPECT_EQ(Comparison::Type::kSigned, + DynCast(root)->type()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/g3doc/index.md b/tensorflow/compiler/xla/g3doc/index.md index 51d666fba9a..45abd9b4c92 100644 --- a/tensorflow/compiler/xla/g3doc/index.md +++ b/tensorflow/compiler/xla/g3doc/index.md @@ -121,8 +121,8 @@ example. ### AOT (Ahead-of-time) compilation for CPU with `tfcompile` -You can also use a standalone [`tfcompile`](./tfcompile) tool, -which converts TensorFlow graph into executable code (for x86-64 CPU only). +You can also use a standalone [`tfcompile`](./tfcompile.md) tool, which converts +TensorFlow graph into executable code (for x86-64 CPU only). ## Inspect compiled programs @@ -196,7 +196,7 @@ Apart from TensorFlow, XLA programs can be generated by: [XLA source](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla) on Github! 
- diff --git a/tensorflow/compiler/xla/g3doc/tutorials/autoclustering_xla.ipynb b/tensorflow/compiler/xla/g3doc/tutorials/autoclustering_xla.ipynb index c0160f2766c..d7799093583 100644 --- a/tensorflow/compiler/xla/g3doc/tutorials/autoclustering_xla.ipynb +++ b/tensorflow/compiler/xla/g3doc/tutorials/autoclustering_xla.ipynb @@ -169,7 +169,7 @@ " model.set_weights(initial_weights)\n", "\n", "warmup(model, x_train, y_train, x_test, y_test)\n", - "%time train_model(model, x_train, y_train, x_test, y_test)\n", + "train_model(model, x_train, y_train, x_test, y_test)\n", "\n", "scores = model.evaluate(x_test, y_test, verbose=1)\n", "print('Test loss:', scores[0])\n", diff --git a/tensorflow/compiler/xla/pjrt/BUILD b/tensorflow/compiler/xla/pjrt/BUILD index 5b3b75eb352..1ff96db8637 100644 --- a/tensorflow/compiler/xla/pjrt/BUILD +++ b/tensorflow/compiler/xla/pjrt/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") @@ -25,7 +26,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor", + "//tensorflow/core/platform:stream_executor", "@com_google_absl//absl/memory", "@com_google_absl//absl/synchronization", ], @@ -108,7 +109,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor", + "//tensorflow/core/platform:stream_executor", "//tensorflow/stream_executor:event", "@com_google_absl//absl/memory", "@com_google_absl//absl/synchronization", @@ -140,11 +141,12 @@ cc_library( "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:maybe_owning_device_memory", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options", - "//tensorflow/core:allocator", "//tensorflow/core:lib", + "//tensorflow/core/framework:allocator", "//tensorflow/core/profiler/lib:connected_traceme", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/core/profiler/lib:traceme_encode", @@ -166,6 +168,47 @@ cc_library( ], ) +cc_library( + name = "tpu_client", + srcs = ["tpu_client.cc"], + hdrs = ["tpu_client.h"], + visibility = [ + "//learning/brain/research/jax:__subpackages__", + "//learning/deepmind/tensorflow/tensorfn:__subpackages__", + "//learning/pathways:__subpackages__", + "//tensorflow/compiler/xla:friends", + ], + deps = [ + ":local_device_state", + ":pjrt_client", + ":tracked_device_buffer", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/service:computation_placer", + "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/core:lib", + "//tensorflow/core/tpu:tpu_executor_dlsym_initializer", + "//tensorflow/core/tpu:tpu_on_demand_compiler", + "//tensorflow/stream_executor:device_memory", + "//tensorflow/stream_executor:stream", + "//tensorflow/stream_executor/lib", + "//tensorflow/stream_executor/tpu:tpu_computation_placer", + "//tensorflow/stream_executor/tpu:tpu_executable_interface", + "//tensorflow/stream_executor/tpu:tpu_executor", + 
"//tensorflow/stream_executor/tpu:tpu_executor_interface", + "//tensorflow/stream_executor/tpu:tpu_platform_interface", + "//tensorflow/stream_executor/tpu:tpu_topology_external", + "//tensorflow/stream_executor/tpu:tpu_transfer_manager", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + ], +) + cc_library( name = "interpreter_device", srcs = ["interpreter_device.cc"], @@ -208,6 +251,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core/common_runtime:bfc_allocator", "//tensorflow/core/common_runtime/gpu:gpu_mem_allocator", + "//tensorflow/core:lib_internal", "//tensorflow/stream_executor:tf_allocator_adapter", ] + if_cuda(["@local_config_nccl//:nccl"]), ) diff --git a/tensorflow/compiler/xla/pjrt/cpu_device.cc b/tensorflow/compiler/xla/pjrt/cpu_device.cc index e2543bda7df..c571ef2a4df 100644 --- a/tensorflow/compiler/xla/pjrt/cpu_device.cc +++ b/tensorflow/compiler/xla/pjrt/cpu_device.cc @@ -28,7 +28,7 @@ CpuDevice::CpuDevice(int id, : PjRtDevice(id, std::move(local_device_state), kCpuPlatformName, /*device_kind=*/kCpuPlatformName) {} -StatusOr> GetCpuClient(bool asynchronous) { +StatusOr> GetCpuClient(bool asynchronous) { TF_ASSIGN_OR_RETURN(se::Platform * platform, PlatformUtil::GetPlatform("Host")); if (platform->VisibleDeviceCount() <= 0) { @@ -56,7 +56,7 @@ StatusOr> GetCpuClient(bool asynchronous) { devices.push_back(std::move(device)); } - return std::make_shared( + return std::make_unique( kCpuPlatformName, client, std::move(devices), /*host_id=*/0, /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr, /*should_stage_host_to_device_transfers=*/false, diff --git a/tensorflow/compiler/xla/pjrt/cpu_device.h b/tensorflow/compiler/xla/pjrt/cpu_device.h index ad0079b1c4a..1036d8fedbb 100644 --- a/tensorflow/compiler/xla/pjrt/cpu_device.h +++ b/tensorflow/compiler/xla/pjrt/cpu_device.h @@ -28,7 +28,7 @@ class CpuDevice : public PjRtDevice { CpuDevice(int id, std::unique_ptr local_device_state); }; -StatusOr> GetCpuClient(bool asynchronous); +StatusOr> GetCpuClient(bool asynchronous); } // namespace xla diff --git a/tensorflow/compiler/xla/pjrt/distributed/BUILD b/tensorflow/compiler/xla/pjrt/distributed/BUILD index 175b4268dda..4cd6093dc48 100644 --- a/tensorflow/compiler/xla/pjrt/distributed/BUILD +++ b/tensorflow/compiler/xla/pjrt/distributed/BUILD @@ -1,4 +1,5 @@ -load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library_cc") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") +load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library") load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency") load("//tensorflow:tensorflow.bzl", "tf_cc_test") @@ -6,7 +7,7 @@ licenses(["notice"]) package(default_visibility = ["//tensorflow:internal"]) -tf_proto_library_cc( +tf_proto_library( name = "protocol_proto", srcs = ["protocol.proto"], has_services = 1, diff --git a/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc b/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc index 298c41c7f58..c56b41861b0 100644 --- a/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc +++ b/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc @@ -28,7 +28,7 @@ namespace { // computation wait for the inputs to be produced before executing. 
TEST(GpuMultiStream, Basics) { TF_ASSERT_OK_AND_ASSIGN( - std::shared_ptr client, + std::unique_ptr client, GetNvidiaGpuClient(/*asynchronous=*/true, GpuAllocatorConfig(), /*distributed_client=*/nullptr, /*node_id=*/0)); diff --git a/tensorflow/compiler/xla/pjrt/interpreter_device.cc b/tensorflow/compiler/xla/pjrt/interpreter_device.cc index c1149f2dbf9..376d8687892 100644 --- a/tensorflow/compiler/xla/pjrt/interpreter_device.cc +++ b/tensorflow/compiler/xla/pjrt/interpreter_device.cc @@ -28,7 +28,7 @@ InterpreterDevice::InterpreterDevice( : PjRtDevice(id, std::move(local_device_state), kInterpreterPlatformName, /*device_kind=*/kInterpreterPlatformName) {} -StatusOr> GetInterpreterClient() { +StatusOr> GetInterpreterClient() { TF_ASSIGN_OR_RETURN(se::Platform * platform, PlatformUtil::GetPlatform("Interpreter")); if (platform->VisibleDeviceCount() != 1) { @@ -50,7 +50,7 @@ StatusOr> GetInterpreterClient() { absl::make_unique(0, std::move(device_state)); devices.push_back(std::move(device)); - return std::make_shared( + return std::make_unique( kInterpreterPlatformName, client, std::move(devices), /*host_id=*/0, /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr, /*should_stage_host_to_device_transfers=*/false, diff --git a/tensorflow/compiler/xla/pjrt/interpreter_device.h b/tensorflow/compiler/xla/pjrt/interpreter_device.h index cf732f70124..4038d8dbf11 100644 --- a/tensorflow/compiler/xla/pjrt/interpreter_device.h +++ b/tensorflow/compiler/xla/pjrt/interpreter_device.h @@ -29,7 +29,7 @@ class InterpreterDevice : public PjRtDevice { std::unique_ptr local_device_state); }; -StatusOr> GetInterpreterClient(); +StatusOr> GetInterpreterClient(); } // namespace xla diff --git a/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc index 512ff81ef6e..df92921c39d 100644 --- a/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc +++ b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/gpu/gpu_host_allocator.h" #include "tensorflow/core/common_runtime/gpu/gpu_mem_allocator.h" +#include "tensorflow/core/util/env_var.h" #include "tensorflow/stream_executor/tf_allocator_adapter.h" namespace xla { @@ -89,12 +90,20 @@ StatusOr> CreateBFCAllocator( CHECK_GT(local_devices.size(), 0); const se::Platform* platform = local_devices.front()->executor()->platform(); std::vector allocators; + bool enable_unified_memory; + Status status = tensorflow::ReadBoolFromEnvVar("TF_FORCE_UNIFIED_MEMORY", + false, &enable_unified_memory); + if (!status.ok()) { + LOG(ERROR) << "Unable to read TF_FORCE_UNIFIED_MEMORY: " + << status.error_message(); + } + for (auto& local_device : local_devices) { se::StreamExecutor* executor = local_device->executor(); int device_ordinal = executor->device_ordinal(); auto sub_allocator = absl::make_unique( executor, tensorflow::PlatformGpuId(device_ordinal), - /*use_unified_memory=*/false, + /*use_unified_memory=*/enable_unified_memory, /*alloc_visitors=*/std::vector(), /*free_visitors=*/std::vector()); @@ -104,7 +113,10 @@ StatusOr> CreateBFCAllocator( return Unavailable("Failed to query available memory from device %i", device_ordinal); } - size_t allocator_memory = free_memory * memory_fraction; + // To allow full GPU memory to be visible to the BFC allocator if using + // unified memory. + size_t allocator_memory = + enable_unified_memory ? 
total_memory : free_memory * memory_fraction; if (preallocate) { LOG(INFO) << "XLA backend allocating " << allocator_memory << " bytes on device " << device_ordinal @@ -289,7 +301,7 @@ GpuDevice::GpuDevice(int id, : PjRtDevice(id, std::move(local_device_state), kGpuPlatformName, std::move(device_kind), node_id) {} -StatusOr> GetNvidiaGpuClient( +StatusOr> GetNvidiaGpuClient( bool asynchronous, const GpuAllocatorConfig& allocator_config, std::shared_ptr distributed_client, int node_id) { TF_ASSIGN_OR_RETURN(LocalClient * xla_client, GetGpuXlaClient()); @@ -312,13 +324,12 @@ StatusOr> GetNvidiaGpuClient( devices = BuildLocalDevices(std::move(local_device_states)); } - std::shared_ptr pyclient = std::make_shared( + return std::unique_ptr(std::make_unique( "gpu", xla_client, std::move(devices), /*node_id=*/node_id, std::move(allocator), std::move(host_memory_allocator), /*should_stage_host_to_device_transfers=*/true, - /*gpu_run_options=*/std::move(gpu_run_options)); - return pyclient; + /*gpu_run_options=*/std::move(gpu_run_options))); } } // namespace xla diff --git a/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h index 4f22a169bd8..f480a37429a 100644 --- a/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h +++ b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h @@ -53,7 +53,7 @@ struct GpuAllocatorConfig { // distributed_client may be nullptr in non-distributed settings. // distributed_client should not be Open()ed before calling this function. -StatusOr> GetNvidiaGpuClient( +StatusOr> GetNvidiaGpuClient( bool asynchronous, const GpuAllocatorConfig& allocator_config, std::shared_ptr distributed_client, int node_id); diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_client.cc index 099c7729679..02ae37b71db 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.cc @@ -90,6 +90,7 @@ limitations under the License. #include "tensorflow/compiler/xla/pjrt/local_device_state.h" #include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h" #include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" #include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" @@ -282,6 +283,11 @@ StatusOr> PjRtClient::GetParametersThatMustBeDonated( return parameters_to_donate; } +std::unique_ptr PjRtClient::GetHloCostAnalysis() { + return absl::make_unique( + client_->backend().compiler()->ShapeSizeBytesFunction()); +} + namespace { // Ensures that it is safe to deallocate any buffers that have been enqueued in @@ -894,6 +900,7 @@ void PjRtBuffer::WaitForOutstandingDonationHold() { StatusOr> PjRtBuffer::Release( bool wait_for_operations_to_complete) { + tensorflow::profiler::TraceMe trace_me("PjRtBuffer::Release"); std::shared_ptr device_buffer; TrackedDeviceBuffer::StreamAndEventContainer events; { @@ -1257,6 +1264,14 @@ StatusOr> PjRtBuffer::CopyToDevice( "CopyToDevice cannot accept the same source and destination devices"); } + // Copying across PjRtClients involves a copy through the host. 
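
// Illustrative sketch, not part of the patch: the host-staged copy that the new
// cross-client branch of PjRtBuffer::CopyToDevice performs (ToLiteral on the source
// client, then FromHostBuffer on the destination client). FakeDevice and the byte
// buffers are invented; the point is only that two runtimes with no shared transfer
// path still interoperate via a device-to-host copy followed by host-to-device.
#include <cassert>
#include <vector>

struct FakeDevice {
  std::vector<char> memory;  // pretend device memory

  void ToHost(std::vector<char>* host) const { *host = memory; }
  void FromHost(const std::vector<char>& host) { memory = host; }
};

// Direct d2d is only possible inside one runtime; across runtimes we go d2h + h2d.
void CopyAcrossClients(const FakeDevice& src, FakeDevice* dst) {
  std::vector<char> host_staging;
  src.ToHost(&host_staging);    // analogous to PjRtBuffer::ToLiteral()
  dst->FromHost(host_staging);  // analogous to FromHostBuffer(...) on the other client
}

int main() {
  FakeDevice gpu{{'x', 'l', 'a'}};
  FakeDevice tpu;
  CopyAcrossClients(gpu, &tpu);
  assert(tpu.memory == gpu.memory);
}
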
+ if (dst_device->client() != client_) { + TF_ASSIGN_OR_RETURN(std::shared_ptr literal, ToLiteral()); + return FromHostBuffer(literal->untyped_data(), literal->shape(), + HostBufferSemantics::kZeroCopy, nullptr, + dst_device->client(), dst_device); + } + TF_ASSIGN_OR_RETURN(LocalDeviceState * dst_local_device, dst_device->GetLocalDeviceState()); LocalDeviceState* transfer_local_device = diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h index 1bed959e3e6..cb4ef9da85b 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h @@ -195,6 +195,9 @@ class PjRtClient { return absl::optional(); } + // Returns a backend-specific HLO cost analysis visitor. + virtual std::unique_ptr GetHloCostAnalysis(); + protected: friend class PjRtBuffer; virtual void EnqueueCrossHostReceive( @@ -560,8 +563,10 @@ class PjRtBuffer { return GetBufferWithHold(ScopedHold::kExternalReference); } - // Copies the buffer to device `dst_device`. Returns an error if the buffer is - // already on dst_device. + // Copies the buffer to device `dst_device`, performing a d2d transfer when + // `dst_device` is sharing the same Client, and performing a d2h and h2d copy + // if `dst_device` lives on a different Client. + // Returns an error if the buffer is already on dst_device. StatusOr> CopyToDevice(PjRtDevice* dst_device); // Copies the buffer to the remote device encoded in serialized_descriptor. @@ -695,7 +700,7 @@ struct ExecuteOptions { int32 launch_id = 0; // If non-null, an opaque context passed to an execution that may be used to // supply additional arguments to a derived class of PjRtExecutable. - ExecuteContext* context = nullptr; + const ExecuteContext* context = nullptr; }; // Represents a compiled computation that can be executed given handles to diff --git a/tensorflow/compiler/xla/pjrt/tpu_client.cc b/tensorflow/compiler/xla/pjrt/tpu_client.cc new file mode 100644 index 00000000000..b2af6e79980 --- /dev/null +++ b/tensorflow/compiler/xla/pjrt/tpu_client.cc @@ -0,0 +1,247 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/pjrt/tpu_client.h" + +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/pjrt/local_device_state.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/platform/casts.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/stream_executor/device_memory.h" +#include "tensorflow/stream_executor/lib/statusor.h" +#include "tensorflow/stream_executor/stream.h" +#include "tensorflow/stream_executor/tpu/tpu_computation_placer.h" +#include "tensorflow/stream_executor/tpu/tpu_executable_interface.h" +#include "tensorflow/stream_executor/tpu/tpu_executor_interface.h" +#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h" +#include "tensorflow/stream_executor/tpu/tpu_stream.h" + +namespace tf_tpu = tensorflow::tpu; + +namespace xla { +namespace { + +class TpuDeviceState : public LocalDeviceState { + public: + TpuDeviceState(se::StreamExecutor* executor, LocalClient* client, + bool asynchronous); + + Status ThenMemcpyDeviceToDevice(se::Stream* transfer_stream, + se::Stream* dst_stream, + se::DeviceMemoryBase src_buffer, + se::DeviceMemoryBase dst_buffer) override; +}; + +TpuDeviceState::TpuDeviceState(se::StreamExecutor* executor, + LocalClient* client, bool asynchronous) + : LocalDeviceState(executor, client, LocalDeviceState::kAsynchronous, + asynchronous, + /*allow_event_reuse=*/false) {} + +Status TpuDeviceState::ThenMemcpyDeviceToDevice( + se::Stream* transfer_stream, se::Stream* dst_stream, + se::DeviceMemoryBase src_buffer, se::DeviceMemoryBase dst_buffer) { + auto* transfer_tpu_stream = tensorflow::down_cast( + transfer_stream->implementation()); + tf_tpu::TpuTopologyExternal topology = + tf_tpu::TpuPlatformInterface::GetRegisteredPlatform()->topology(); + // TODO(b/157179600): use device-to-device transfers when implemented instead + // of copying via host. 
+ if (topology.version() == kTpuV4) { + LOG(WARNING) + << "device-to-device transfers not yet implemented, copying via host"; + auto* dst_tpu_stream = + tensorflow::down_cast(dst_stream->implementation()); + TF_RET_CHECK(src_buffer.size() == dst_buffer.size()); + auto host_tmp = std::make_unique(src_buffer.size()); + TF_RETURN_IF_ERROR(transfer_tpu_stream->EnqueueTransferDeviceToHost( + src_buffer, host_tmp.get(), src_buffer.size())); + dst_stream->ThenWaitFor(transfer_stream); + TF_RETURN_IF_ERROR(dst_tpu_stream->EnqueueTransferHostToDevice( + dst_buffer, host_tmp.get(), dst_buffer.size())); + transfer_stream->ThenWaitFor(dst_stream); + char* tmp = host_tmp.release(); + dst_stream->ThenDoHostCallback([tmp] { delete[] tmp; }); + } else { + TF_RETURN_IF_ERROR(transfer_tpu_stream->EnqueueOnTpuDeviceSendRecvLocal( + src_buffer, dst_buffer)); + } + return Status::OK(); +} + +class PjRtTpuClient : public PjRtClient { + public: + PjRtTpuClient(LocalClient* client, + std::vector> devices, int host_id, + tf_tpu::TpuPlatformInterface* tpu_platform); + + StatusOr GetDefaultDeviceAssignment( + int num_replicas, int num_partitions) const override; + + bool EnqueueD2DTransfersOnSrcStream() const override { + return tpu_platform_->topology().version() == kTpuV4; + } + + StatusOr> ExecutableFingerprint( + const PjRtExecutable& executable) const override; + + private: + tf_tpu::TpuPlatformInterface* tpu_platform_; +}; + +PjRtTpuClient::PjRtTpuClient(LocalClient* client, + std::vector> devices, + int host_id, + tf_tpu::TpuPlatformInterface* tpu_platform) + : PjRtClient("tpu", client, std::move(devices), host_id, + /*allocator=*/nullptr, + /*host_memory_allocator=*/nullptr, + /*should_stage_host_to_device_transfers=*/false, + /*gpu_run_options=*/nullptr), + tpu_platform_(tpu_platform) {} + +StatusOr PjRtTpuClient::GetDefaultDeviceAssignment( + int num_replicas, int num_partitions) const { + tf_tpu::TpuPlatformInterface* platform = + tf_tpu::TpuPlatformInterface::GetRegisteredPlatform(); + tf_tpu::TpuHostLocationExternal host = platform->GetTpuHostLocation(); + int num_local_devices = host.Cores(kTensorCore).size(); + if (num_replicas * num_partitions <= num_local_devices) { + return tf_tpu::TpuComputationPlacer::AssignLocalDevices(host, num_replicas, + num_partitions); + } + // Fallback to default global device assignment if we can't run locally. 
+ return PjRtClient::GetDefaultDeviceAssignment(num_replicas, num_partitions); +} + +StatusOr> PjRtTpuClient::ExecutableFingerprint( + const PjRtExecutable& executable) const { + if (executable.client() != this) { + return InvalidArgument( + "Passed executable from different client (platform '%s') to " + "PjRtTpuClient::ExecutableFingerprint", + executable.client()->platform_name()); + } + if (executable.executables().size() > 1) { + LOG(INFO) << "ExecutableFingerprint not fully implemented for MPMD " + "executables, fingerprint may not be unique."; + } + xla::TpuExecutableInterface* tpu_executable = + tensorflow::down_cast( + executable.executables()[0]->executable()); + return absl::optional(tpu_executable->fingerprint()); +} + +StatusOr>> GetTpuDevices( + LocalClient* client, + std::vector> local_device_states) { + std::vector> devices; + tf_tpu::TpuTopologyExternal topology = + tf_tpu::TpuPlatformInterface::GetRegisteredPlatform()->topology(); + + std::map core_id_to_device_ordinal; + for (int i = 0; i < client->device_count(); ++i) { + se::StreamExecutor* executor = + client->backend().stream_executor(i).ValueOrDie(); + tf_tpu::TpuExecutorInterface* tpu_executor = + tensorflow::down_cast( + executor->implementation()); + core_id_to_device_ordinal[tpu_executor->GetCoreLocationExternal().Id()] = i; + } + + for (const tf_tpu::TpuCoreLocationExternal& core : + topology.cores(TpuCoreTypeEnum::kTensorCore)) { + auto it = core_id_to_device_ordinal.find(core.Id()); + int device_ordinal = + (it != core_id_to_device_ordinal.end()) ? it->second : -1; + int host_id = topology.IdForHost(core.host_coordinates()); + const tf_tpu::TpuDimensionsExternal coords = core.chip_coordinates(); + std::array coords_array = {coords.x, coords.y, coords.z}; + std::unique_ptr local_device_state; + if (device_ordinal >= 0) { + local_device_state = std::move(local_device_states[device_ordinal]); + } + auto device = absl::make_unique( + core, std::move(local_device_state), host_id, coords_array, + std::string(tf_tpu::TpuVersionEnumToString(topology.version()))); + devices.push_back(std::move(device)); + } + return devices; +} + +} // namespace + +StatusOr> GetTpuClient( + bool asynchronous, absl::Duration init_retry_timeout) { + tf_tpu::TpuPlatformInterface* platform = + tf_tpu::TpuPlatformInterface::GetRegisteredPlatform( + /*initialize_platform=*/true, /*num_tries=*/1); + if (platform == nullptr) { + return InvalidArgument("TpuPlatform is not available."); + } + // NOTE: We retry in a loop since some pod failures are transient (e.g. some + // RPCs may timeout waiting for other hosts to come up, but will succeed + // at a later point if retried). + auto start = absl::Now(); + // TODO(b/165870356): TpuPlatform::Initialized() always returns true! 
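
// Illustrative sketch, not part of the patch: the retry-until-deadline shape of the
// platform initialization loop below, with a std::function standing in for
// TpuPlatformInterface::Initialize. The absl time types match the real code; the
// SleepFor between attempts is my own addition (the real loop retries immediately),
// added only to keep the sketch from spinning.
#include <functional>
#include <iostream>

#include "absl/time/clock.h"
#include "absl/time/time.h"

bool InitializeWithRetry(const std::function<bool()>& try_init,
                         absl::Duration retry_timeout) {
  const absl::Time start = absl::Now();
  while (!try_init()) {
    if (absl::Now() - start >= retry_timeout) {
      return false;  // give up and surface the last error, as the patch does
    }
    absl::SleepFor(absl::Milliseconds(100));  // transient pod bring-up: wait and retry
  }
  return true;
}

int main() {
  int attempts = 0;
  // Succeeds on the third attempt, well inside the timeout.
  bool ok = InitializeWithRetry([&] { return ++attempts >= 3; }, absl::Seconds(30));
  std::cout << ok << " after " << attempts << " attempts\n";
}
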
+ auto status = platform->Initialize({}); + while (!platform->Initialized()) { + status = platform->Initialize({}); + if (!status.ok()) { + LOG(ERROR) << "Platform initialization failed: " << status; + if ((absl::Now() - start) >= init_retry_timeout) { + return status; + } + } + } + if (platform->VisibleDeviceCount() <= 0) { + return InvalidArgument("No TPU devices found."); + } + LocalClientOptions options; + options.set_platform(platform); + TF_ASSIGN_OR_RETURN(LocalClient * client, + ClientLibrary::GetOrCreateLocalClient(options)); + + std::vector> local_device_states; + local_device_states.reserve(client->device_count()); + for (int i = 0; i < client->device_count(); ++i) { + se::StreamExecutor* executor = + client->backend().stream_executor(i).ValueOrDie(); + local_device_states.push_back( + absl::make_unique(executor, client, asynchronous)); + } + + TF_ASSIGN_OR_RETURN(auto devices, + GetTpuDevices(client, std::move(local_device_states))); + int host_id = platform->GetTpuHostLocation().Id(); + + return std::shared_ptr(absl::make_unique( + client, std::move(devices), host_id, platform)); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/pjrt/tpu_client.h b/tensorflow/compiler/xla/pjrt/tpu_client.h new file mode 100644 index 00000000000..1a458c1480b --- /dev/null +++ b/tensorflow/compiler/xla/pjrt/tpu_client.h @@ -0,0 +1,60 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_TPU_CLIENT_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_TPU_CLIENT_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/stream_executor/tpu/tpu_topology.h" + +namespace xla { + +class PjRtTpuDevice : public PjRtDevice { + public: + PjRtTpuDevice(const tensorflow::tpu::TpuCoreLocationExternal core, + std::unique_ptr local_device_state, + int host_id, const std::array& coords, + std::string device_kind) + : PjRtDevice(core.Id(), std::move(local_device_state), + /*platform_name=*/"tpu", std::move(device_kind), host_id), + core_(core), + coords_(coords) {} + + const std::array& coords() const { return coords_; } + int core_on_chip() const { return core_.index(); } + const tensorflow::tpu::TpuCoreLocationExternal core() const { return core_; } + + std::string DebugString() const override { + return absl::StrFormat("TPU_%i(host=%i,(%i,%i,%i,%i))", id(), host_id(), + coords_[0], coords_[1], coords_[2], core_.index()); + } + + private: + const tensorflow::tpu::TpuCoreLocationExternal core_; + const std::array coords_; +}; + +StatusOr> GetTpuClient( + bool asynchronous, + absl::Duration init_retry_timeout = absl::ZeroDuration()); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PJRT_TPU_CLIENT_H_ diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 046fadb405b..2db43727fbd 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow/core/platform:build_config.bzl", "pyx_library") load("//tensorflow/compiler/xla:xla.bzl", "xla_py_test_deps") load("//tensorflow:tensorflow.bzl", "tf_cc_test") @@ -6,7 +7,10 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "pybind_extension") package( - default_visibility = ["//tensorflow:internal"], + default_visibility = [ + "//learning/pathways/data_parallel/jax:__subpackages__", + "//tensorflow:internal", + ], licenses = ["notice"], # Apache 2.0 ) @@ -24,6 +28,18 @@ pyx_library( srcs = ["custom_call_for_test.pyx"], ) +py_test( + name = "xla_client_backend_independent_test", + srcs = ["xla_client_backend_independent_test.py"], + python_version = "PY3", + tags = ["no_oss"], # TODO(phawkins): This test passes, but requires --config=monolithic. 
+ deps = [ + ":xla_client", + ":xla_extension", + "@absl_py//absl/testing:absltest", + ] + xla_py_test_deps(), +) + py_library( name = "xla_client_test", testonly = 1, @@ -264,6 +280,7 @@ cc_library( "//tensorflow/core/platform:status", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:optional", "@pybind11", ], @@ -284,6 +301,7 @@ cc_library( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:comparators", + "//tensorflow/compiler/xla/client/lib:lu_decomposition", "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/client/lib:qr", "//tensorflow/compiler/xla/client/lib:self_adjoint_eig", @@ -428,6 +446,7 @@ pybind_extension( "//tensorflow/compiler/xla/pjrt:interpreter_device", "//tensorflow/compiler/xla/pjrt:nvidia_gpu_device", "//tensorflow/compiler/xla/pjrt:pjrt_client", + "//tensorflow/compiler/xla/pjrt:tpu_client", "//tensorflow/compiler/xla/pjrt:tracked_device_buffer", "//tensorflow/compiler/xla/pjrt/distributed", "//tensorflow/compiler/xla/pjrt/distributed:client", diff --git a/tensorflow/compiler/xla/python/dlpack.cc b/tensorflow/compiler/xla/python/dlpack.cc index 974816407ee..67afa25d23e 100644 --- a/tensorflow/compiler/xla/python/dlpack.cc +++ b/tensorflow/compiler/xla/python/dlpack.cc @@ -321,7 +321,8 @@ StatusOr> DLPackManagedTensorToBuffer( DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype)); std::vector minor_to_major; - if (dlmt->dl_tensor.strides && !absl::c_find(dimensions, 0)) { + if (dlmt->dl_tensor.strides && + absl::c_find(dimensions, 0) == dimensions.end()) { absl::Span strides( reinterpret_cast(dlmt->dl_tensor.strides), dlmt->dl_tensor.ndim); diff --git a/tensorflow/compiler/xla/python/jax_jit.cc b/tensorflow/compiler/xla/python/jax_jit.cc index 6594125d493..944b4c20a8a 100644 --- a/tensorflow/compiler/xla/python/jax_jit.cc +++ b/tensorflow/compiler/xla/python/jax_jit.cc @@ -28,11 +28,12 @@ limitations under the License. #include #include -#include #include +#include #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" +#include "absl/synchronization/notification.h" #include "absl/types/optional.h" #include "pybind11/cast.h" #include "pybind11/numpy.h" @@ -90,9 +91,7 @@ struct ArgSignature { template H AbslHashValue(H h, const ArgSignature& s) { h = H::combine(std::move(h), s.dtype); - if (!s.shape.empty()) { - h = H::combine_contiguous(std::move(h), &s.shape.front(), s.shape.size()); - } + h = H::combine_contiguous(std::move(h), s.shape.data(), s.shape.size()); return h; } @@ -123,17 +122,25 @@ struct CallSignature { std::vector static_args; // A PyTreeDef for each positional dynamic (i.e. not static) argument. std::vector dynamic_positional_args_treedef; - // Keyword arguments. Sorted by the interned keyword pointers. + // Keyword arguments. Sorted by the keyword name. std::vector keyword_args; // Shape and dtype for both the dynamic positional arguments and the keyword - // arguments (sorted by interned keyword pointers). + // arguments (sorted by keyword name). 
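
// Illustrative sketch, not part of the patch: the bug fixed in dlpack.cc above.
// absl::c_find returns an iterator; when that iterator is a raw pointer (as for a
// span or C array), negating it only tests pointer-null-ness, so the old
// `!absl::c_find(dimensions, 0)` never meant "0 is absent". Comparing against
// end(), as the new code does, is the correct test. Shown here with std::find.
#include <algorithm>
#include <iostream>
#include <iterator>

int main() {
  const long long dimensions[] = {2, 3, 4};  // no zero-sized dimension
  const long long* found =
      std::find(std::begin(dimensions), std::end(dimensions), 0LL);

  // Buggy form: a valid (non-null) pointer is always "truthy", so this prints 0
  // even though no zero dimension exists.
  std::cout << "!find(...):   " << (!found) << "\n";

  // Fixed form: "no zero dimension" is found == end; this prints 1.
  std::cout << "find == end:  " << (found == std::end(dimensions)) << "\n";
}
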
std::vector dynamic_args_signatures; + PjRtDevice* device; bool operator==(const CallSignature& other) const { - return std::tie(dynamic_positional_args_treedef, static_args, keyword_args, - dynamic_args_signatures) == - std::tie(other.dynamic_positional_args_treedef, other.static_args, - other.keyword_args, other.dynamic_args_signatures); + return std::tie(dynamic_positional_args_treedef, keyword_args, + dynamic_args_signatures, device) == + std::tie(other.dynamic_positional_args_treedef, + other.keyword_args, other.dynamic_args_signatures, + other.device) && + // `==` on py:objects is the Python `is`. We need equal. + std::equal(static_args.begin(), static_args.end(), + other.static_args.begin(), other.static_args.end(), + [](const py::object& a, const py::object& b) { + return a.equal(b); + }); } bool operator!=(const CallSignature& other) const { return !(*this == other); @@ -175,12 +182,13 @@ H AbslHashValue(H h, const CallSignature& s) { // TODO(jblespiau): We should either ban non-hashable objects from jit or we // should hash them by object identity. h = H::combine_contiguous(std::move(h), - &s.dynamic_positional_args_treedef.front(), + s.dynamic_positional_args_treedef.data(), s.dynamic_positional_args_treedef.size()); - h = H::combine_contiguous(std::move(h), &s.keyword_args.front(), + h = H::combine_contiguous(std::move(h), s.keyword_args.data(), s.keyword_args.size()); - h = H::combine_contiguous(std::move(h), &s.dynamic_args_signatures.front(), + h = H::combine_contiguous(std::move(h), s.dynamic_args_signatures.data(), s.dynamic_args_signatures.size()); + h = H::combine(std::move(h), s.device); return h; } @@ -188,7 +196,7 @@ std::string CallSignature::DebugString() const { std::vector static_args_str; static_args_str.reserve(static_args.size()); for (auto& static_arg : static_args) { - static_args_str.emplace_back(py::cast(static_arg.str())); + static_args_str.emplace_back(py::cast(py::str(static_arg))); } std::vector signature_str; @@ -222,27 +230,35 @@ std::string CallSignature::DebugString() const { struct CacheEntry { std::shared_ptr executable; - xla::PjRtDevice* device; PyTreeDef out_pytree_def; - // These are the objects required to create a `DeviceArray` object. - // We use Python types within the vector because this is what we will be - // returning to Python. No need to convert back and forth. - // We need py::object to maintain the objects alive. - std::vector out_avals; - std::vector out_lazy_exprs; + // Callables (one for each output) to call on each output to get the Python + // object (usually a DeviceArray) that we should return. + // TODO(jblespiau): The goal of the C++ codepath being to be fast, thus, we + // should not call into Python. It will be trivial to fix this when + // omnistaging is the only option & when DeviceArray and PyBuffer are merged). + std::vector handlers; + + // Ensures a single thread performs the compilation for a given executable. + // + // The first thread (holding the GIL) will create the CacheEntry associated to + // a signature and if the object has been insterted already, other threads + // will wait for the notification. + absl::Notification compilation_complete; + absl::optional compilation_error = absl::nullopt; + // Trivial computation will fallback to Python. + // Running a jax(pmap) will also fallback to Python. 
+ bool fall_back_to_python = false; }; // A `CompiledFunction` is associated to a `jax.jit(f)` and takes care of the // bookkeeping of the different signatures used and the dispatch of calls to -// the correct underlying `PyExecutable`. -// TODO(jblespiau): This class is thread-unsafe. Note that using a mutex for the -// full `Call` will lead to a deadlock because it goes back to Python which will -// release the GIL. +// the correct underlying `PyExecutable`. This class is thread-safe. class CompiledFunction { public: - CompiledFunction(py::function fun, py::function cache_miss_fun, - py::function python_f_jitted, bool jax_enable_x64, - bool jax_disable_jit, std::vector static_argnums); + CompiledFunction(py::function fun, py::function cache_miss, + py::function get_device, py::function get_jax_enable_x64, + py::function get_jax_disable_jit, + std::vector static_argnums); ~CompiledFunction(); // This function will: @@ -259,28 +275,22 @@ class CompiledFunction { return inspect->attr("signature")(fun_); } + int cache_size() const { return executables_.size(); } + private: - CacheEntry& GetCacheEntry(const py::args& args, const py::kwargs& kwargs, + // Returns nullptr if not present in the cache. + CacheEntry* GetCacheEntryIfPresent(const CallSignature& signature); + // Should never return nullptr. + CacheEntry* AddCacheEntry(const py::args& args, const py::kwargs& kwargs, const CallSignature& signature, - absl::optional cache_miss_return); - CacheEntry& SetAndReturnCacheEntry( - const py::args& args, const py::kwargs& kwargs, - const CallSignature& signature, - absl::optional cache_miss_return = absl::nullopt); - bool JitIsDisabled() { return GetDisableJit() || jax_disable_jit_; } + py::object out_and_fastpath_data); + bool JitIsDisabled() { return GetDisableJit() || jax_disable_jit_.value(); } + + bool always_fallback_to_python_ = false; const py::function fun_; // The Python function to jit. - // The Python function in charge of returning a `xla::PyExecutable` from - // the arguments passed to `jitted_f`. - const py::function cache_miss_fun_; - // A function to call as fallback. This is the result of calling the Python - // `jax.jit`. - // TODO(jblespiau): Delete this when the C++ codepath supports all features. - const py::function python_f_jitted_; - - // The value of the Python flag when the object was created. - const bool jax_enable_x64_; - const bool jax_disable_jit_; + // See JAX _cpp_jit in api.py for documentation. + const py::function cache_miss_; // We need to know the static arguments to remove them from the arguments // passed to the underlying PyExecutable. In sorted order. @@ -292,21 +302,43 @@ class CompiledFunction { // `CompiledFunction` is being instantiated from Python, the clients are not // yet available (done after GoogleInit). They will be during the first call // to `Call`. - std::shared_ptr pyclient_ = nullptr; + // A function taking no arguments and returning the default device and whether + // jax.jit has been committed to it. + const py::function get_jax_enable_x64_; + const py::function get_jax_disable_jit_; + const py::function get_device_; + + // The writing of the following is protected by the mutex. + absl::Mutex mu_; + // The value of the Python flag. The value will be computed only during the + // first object call, because GoogleInit must have been executed. 
+ absl::optional jax_enable_x64_ = absl::nullopt; + absl::optional jax_disable_jit_ = absl::nullopt; + + // The logic is the following: + // - if `device` or `backend` are not specified to `jax.jit`, we will use + // the input sticky buffer device, or `default_device_` if there is no + // such sticky buffer. + // - When one of `device` or `backend` is specified, this will determine + // the `default_device_` which will be used as the targeted device. In + // that case, we will always copy input buffers to this device. + std::shared_ptr default_pyclient_ = nullptr; + xla::ClientAndPtr default_pydevice_; xla::PjRtDevice* default_device_ = nullptr; + bool is_committed_; }; -CompiledFunction::CompiledFunction(py::function fun, - py::function cache_miss_fun, - py::function python_f_jitted, - bool jax_enable_x64, bool jax_disable_jit, +CompiledFunction::CompiledFunction(py::function fun, py::function cache_miss, + py::function get_device, + py::function get_jax_enable_x64, + py::function get_jax_disable_jit, std::vector static_argnums) : fun_(std::move(fun)), - cache_miss_fun_(std::move(cache_miss_fun)), - python_f_jitted_(std::move(python_f_jitted)), - jax_enable_x64_(jax_enable_x64), - jax_disable_jit_(jax_disable_jit), - static_argnums_(std::move(static_argnums)) { + cache_miss_(std::move(cache_miss)), + static_argnums_(std::move(static_argnums)), + get_jax_enable_x64_(get_jax_enable_x64), + get_jax_disable_jit_(get_jax_disable_jit), + get_device_(std::move(get_device)) { std::sort(static_argnums_.begin(), static_argnums_.end()); } @@ -318,6 +350,34 @@ CompiledFunction::~CompiledFunction() { namespace { +// The equivalent of the Python jax/lazy.py::is_trivial: +// return (type(lexpr.input) is ArrayVar and +// lexpr.dims == tuple(range(len(lexpr.shape)))) +// +// Expects *only* instances of `DeviceArray`. +bool HasTrivialLazyExpr(py::handle device_array) { + static const auto* lazy_module = + new py::module(py::module::import("jax.lazy")); + + auto lexpr = py::getattr(device_array, "_lazy_expr"); + auto input = py::getattr(lexpr, "input"); + if (!input.get_type().is(lazy_module->attr("ArrayVar"))) { + return false; + } + py::tuple dims = py::cast(lexpr.attr("dims")); + py::tuple shape = py::cast(lexpr.attr("shape")); + + for (int i = 0; i < shape.size(); ++i) { + if (dims[i].is_none()) { + return false; + } + if (py::cast(dims[i]) != i) { + return false; + } + } + return true; +} + // The resulting information of the parsing and conversion of the arguments. struct ParsedArgumentsAsBuffers { // The call signature will be filled during 2 steps: @@ -370,8 +430,10 @@ void FlattenArguments(const py::args& args, const py::kwargs& py_kwargs, // Keyword arguments. std::vector> kwargs(py_kwargs.begin(), py_kwargs.end()); - // We first intern the keys, then sort them (by pointer) and then create - // the signatures. + // We first intern the keys, then sort them (by name, as in the Python path) + // (see also PyTreeDef::Flatten) and then create the signatures. + // TODO(jblespiau): We should be able to sort the keys by interned-key + // pointers, but this requires the Python compilation to do the same. arguments.signature.keyword_args.resize(kwargs.size()); for (size_t i = 0; i < kwargs.size(); ++i) { // Intern the key if not already interned. 
@@ -388,7 +450,7 @@ void FlattenArguments(const py::args& args, const py::kwargs& py_kwargs, std::sort(kwargs.begin(), kwargs.end(), [](const std::pair& a, const std::pair& b) { - return a.first.ptr() < b.first.ptr(); + return a.first < b.first; }); for (size_t i = 0; i < kwargs.size(); ++i) { arguments.signature.keyword_args[i].key = kwargs[i].first; @@ -457,7 +519,7 @@ StatusOr> ScalarToBuffer( "%s", absl::StrCat( "Not supported: The C++ jax jit execution path, only accepts " "DeviceArray, Numpy arrays, or Python scalars. Got type ", - py::cast(scalar.get_type().str()))); + py::cast(py::str(scalar.get_type())))); } const py::dtype* DtypeTo32BitDtype(const py::dtype& dtype) { @@ -470,28 +532,37 @@ const py::dtype* DtypeTo32BitDtype(const py::dtype& dtype) { static const auto* complex64_dt = new py::dtype("complex64"); static const auto* complex128_dt = new py::dtype("complex128"); - if (dtype == *int64_dt) { + if (dtype.equal(*int64_dt)) { return int32_dt; } - if (dtype == *float64_dt) { + if (dtype.equal(*float64_dt)) { return float32_dt; } - if (dtype == *uint64_dt) { + if (dtype.equal(*uint64_dt)) { return uint32_dt; } - if (dtype == *complex128_dt) { + if (dtype.equal(*complex128_dt)) { return complex64_dt; } return nullptr; } +bool IsFloat0(py::array arg) { + static const auto* dtypes_module = + new py::module(py::module::import("jax.dtypes")); + static const auto* float0_dtype = + new py::handle(dtypes_module->attr("float0")); + return float0_dtype->is(arg.attr("dtype")); +} + // Converts flattened arguments contained in ParsedArgumentsAsBuffers in // place. If arguments are `DeviceArray`, they must all be on the same `Device`. // -// Returns `OkStatus()` on success. +// Returns `OkStatus()` on success. Returning an error should lead to calling +// the Python fallback. Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient, - xla::PjRtDevice* default_device, + xla::PjRtDevice* default_device, bool is_committed, ParsedArgumentsAsBuffers& arguments) { std::vector& arg_buffers = arguments.arg_buffers; auto& keep_alive = arguments.keep_alive; @@ -505,44 +576,49 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient, const auto& device_array = xla_module->attr("DeviceArray"); static const auto* numpy_module = new py::module(py::module::import("numpy")); - const auto& array = numpy_module->attr("array"); + const auto& np_array = numpy_module->attr("array"); - // TODO(phawkins): consider device stickiness. - // We first check whether any `DeviceArray` is present and whether they are - // attached to any specific device. See also + // When the jitted function is not committed, we first check whether any + // sticky `DeviceArray` is present and on which device they live. See also: // https://github.com/google/jax/pull/1884 // https://github.com/google/jax/pull/1916 for the rationale why the // computation follows the data locality. // It's also similar to PyTorch's behavior. xla::PjRtDevice* data_device = nullptr; - for (py::handle arg : arguments.flat_dynamic_args) { - if (py::isinstance(arg, device_array)) { - xla::PyBuffer* buffer; - try { - // This can fail, e.g. when device_buffer is a `DeviceConstant`. - buffer = py::cast(arg.attr("device_buffer")); - } catch (const py::cast_error& e) { - return InvalidArgument( - "%s", - absl::StrCat("[jaxjit] Unsupported subclass of `DeviceArray`: " - "`device_buffer` field is of type ", - py::cast( - arg.attr("device_buffer").get_type().str()), - " while a `PyBuffer` was expected." 
+ if (is_committed) { + data_device = default_device; + } else { + for (py::handle arg : arguments.flat_dynamic_args) { + // We specically only deal with DeviceArray (not ShardedDeviceArray). + // (Can happen in jit(pmap), e.g. "test_jit_nested_donate_ignored"). + if (arg.get_type().is(device_array)) { + xla::PyBuffer* buffer; + if (arg.attr("_device").is_none()) { // Skip non-sticky devices. + continue; + } + try { + // This can fail, e.g. when device_buffer is a `DeviceConstant`. + buffer = py::cast(arg.attr("device_buffer")); + } catch (const py::cast_error& e) { + return InvalidArgument( + "%s", + absl::StrCat("[jaxjit] Unsupported subclass of `DeviceArray`: " + "`device_buffer` field is of type ", + py::cast( + arg.attr("device_buffer").get_type().str()), + " while a `PyBuffer` was expected." - )); - } - xla::PjRtDevice* device = buffer->buffer()->device(); - if (data_device && (device != data_device)) { - return InvalidArgument( - "%s", - absl::StrCat( - "Arguments to a jit-compiled function must be colocated on the " - "same device. Arguments were found to be on the two following " - "different devices: ", - device->DebugString(), " and ", data_device->DebugString())); - } else { - data_device = device; + )); + } + xla::PjRtDevice* device = buffer->buffer()->device(); + if (data_device && (device != data_device)) { + throw std::invalid_argument(absl::StrCat( + "primitive arguments must be colocated on the same device (" + "C++ jax.jit). Arguments are on devices: ", + device->DebugString(), " and ", data_device->DebugString())); + } else { + data_device = device; + } } } } @@ -550,16 +626,31 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient, // No `DeviceArray` were found default to `default_device`. data_device = default_device; } + CHECK(data_device); + arguments.signature.device = data_device; xla::PjRtClient* pjrt_client = data_device->client(); for (py::handle arg : arguments.flat_dynamic_args) { - // We do not support here d2d transparent transfers. - // We assumes all the `DeviceArray` are already on the correct and shared - // device. - if (py::isinstance(arg, device_array)) { - xla::PyBuffer* buffer = - py::cast(arg.attr("device_buffer")); - arg_buffers.push_back(buffer->buffer()); + if (arg.get_type().is(device_array)) { + if (!HasTrivialLazyExpr(arg)) { + return InvalidArgument( + "Non-trivial lazy expression not supported in C++. " + "Falling back to Python."); + } + + PyBuffer* buffer = py::cast(arg.attr("device_buffer")); + if (buffer->device().contents == data_device) { + arg_buffers.push_back(buffer->buffer()); + } else { + // source and target platforms are the same, but different device. + // Perform a device-to-device copy. + // buffers from different XLA backends are passed through the host. + std::unique_ptr copied_buffer = + ValueOrThrow(buffer->buffer()->CopyToDevice(data_device)); + arg_buffers.push_back(copied_buffer.get()); + keep_alive.emplace_back(std::move(copied_buffer)); + } + ArgSignature sig; sig.dtype = buffer->shape().element_type(); sig.shape.assign(buffer->shape().dimensions().begin(), @@ -570,12 +661,17 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient, // TODO(jblespiau): Can we improve this call? Do we need the underlying // GlobalPyRefManager() and co? py::array numpy_array = py::cast(arg); + if (IsFloat0(numpy_array)) { + return InvalidArgument( + "float0 numpy arrays not supported in C++. 
" + "It will fallback to Python."); + } // If jax_enable_x64 is not set, we need to coerce 32 bits types. // Note that this is calling back to Python! if (!jax_enable_x64) { const py::dtype* to_dtype = DtypeTo32BitDtype(numpy_array.dtype()); if (to_dtype) { - numpy_array = array(numpy_array, to_dtype); + numpy_array = np_array(numpy_array, *to_dtype); } } std::unique_ptr buffer = @@ -587,6 +683,7 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient, ArgSignature sig; sig.dtype = buffer->shape().element_type(); + sig.weak_type = false; sig.shape.assign(buffer->shape().dimensions().begin(), buffer->shape().dimensions().end()); arguments.signature.dynamic_args_signatures.push_back(sig); @@ -612,121 +709,149 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient, } // namespace -CacheEntry& CompiledFunction::GetCacheEntry( - const py::args& args, const py::kwargs& kwargs, - const CallSignature& signature, - absl::optional cache_miss_return) { +CacheEntry* CompiledFunction::GetCacheEntryIfPresent( + const CallSignature& signature) { auto found_iterator = executables_.find(signature); if (found_iterator != executables_.end()) { // Cache hit! - return *(found_iterator->second); + if (!found_iterator->second->compilation_complete.HasBeenNotified()) { + py::gil_scoped_release gil_release; + found_iterator->second->compilation_complete.WaitForNotification(); + } + if (found_iterator->second->compilation_error) { + throw std::invalid_argument( + found_iterator->second->compilation_error.value().error_message()); + } + return found_iterator->second.get(); } - return SetAndReturnCacheEntry(args, kwargs, signature, cache_miss_return); + return nullptr; } -CacheEntry& CompiledFunction::SetAndReturnCacheEntry( - const py::args& args, const py::kwargs& kwargs, - const CallSignature& signature, - absl::optional cache_miss_return) { + +CacheEntry* CompiledFunction::AddCacheEntry(const py::args& args, + const py::kwargs& kwargs, + const CallSignature& signature, + py::object out_and_fastpath_data) { // We need to insert the element. auto result = executables_.emplace(signature, std::make_unique()); auto it = result.first; - + CacheEntry* cache_entry = it->second.get(); // CallSignatures in the cache own their keyword argument reference. result.first->first.IncRef(); - // Cache miss? Call the Python cache miss function. 
- py::tuple executable_and_pytree; - if (cache_miss_return) { - executable_and_pytree = cache_miss_return.value(); - } else { - executable_and_pytree = cache_miss_fun_(*args, **kwargs); + py::tuple tuple = py::cast(out_and_fastpath_data); + CHECK_EQ(tuple.size(), 2); + if (tuple[1].is_none()) { + cache_entry->fall_back_to_python = true; + cache_entry->compilation_complete.Notify(); + return cache_entry; } - if (executable_and_pytree.size() != 4) { - throw std::runtime_error( - "AssertionError: The cache miss function should return 4 " - "arguments."); + + py::tuple executable_handlers_out_tree = py::cast(tuple[1]); + CHECK_EQ(executable_handlers_out_tree.size(), 3); + + auto executable = py::cast>( + executable_handlers_out_tree[0]); + std::vector handlers; + for (const auto& handler : + py::cast(executable_handlers_out_tree[1])) { + handlers.push_back(py::cast(handler)); } - it->second->executable = py::cast>( - std::move(executable_and_pytree[0])); + auto out_tree = py::cast(executable_handlers_out_tree[2]); + + cache_entry->executable = std::move(executable); int num_devices = - it->second->executable->pjrt_executable().local_devices().size(); - if (num_devices != 1) { - throw std::runtime_error(absl::StrCat( - "Running on more than a single device is not currently supported." - "The underlying PjRtExecutable has ", - num_devices)); - } - it->second->device = - it->second->executable->pjrt_executable().local_devices()[0]; - it->second->out_pytree_def = py::cast(executable_and_pytree[1]); + cache_entry->executable->pjrt_executable().local_devices().size(); + // The presence of jit(pmap) is detected from Python. + CHECK_EQ(num_devices, 1); - py::list shaped_arrays = - py::reinterpret_borrow(executable_and_pytree[2]); - py::list lazy_expressions = - py::reinterpret_borrow(executable_and_pytree[3]); + cache_entry->handlers = std::move(handlers); + cache_entry->out_pytree_def = std::move(out_tree); - it->second->out_avals.reserve(shaped_arrays.size()); - it->second->out_lazy_exprs.reserve(lazy_expressions.size()); - - int num_outputs = shaped_arrays.size(); - for (int i = 0; i < num_outputs; ++i) { - py::object shaped_array = - py::reinterpret_borrow(shaped_arrays[i]); - py::object lazy_expr = - py::reinterpret_borrow(lazy_expressions[i]); - - it->second->out_avals.push_back(shaped_array); - it->second->out_lazy_exprs.push_back(lazy_expr); - } - - return *(it->second); + cache_entry->compilation_complete.Notify(); + return cache_entry; } py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) { + if (always_fallback_to_python_) { + return py::cast(cache_miss_(*args, **kwargs))[0]; + } + // Delayed values are retrieved on the first call to `Call`. + if (!default_device_) { + // As we are calling Python code, that may release the GIL, we first hold + // mu_ before holding the GIL. + py::gil_scoped_release gil_release; + { + absl::MutexLock lock1(&mu_); + py::gil_scoped_acquire gil_aquire; + + jax_enable_x64_ = py::cast(get_jax_enable_x64_()); + jax_disable_jit_ = py::cast(get_jax_disable_jit_()); + if (!default_device_) { + py::object device_and_is_committed = get_device_(); + try { + default_pydevice_ = py::cast>( + device_and_is_committed.attr("default_device")); + } catch (const py::cast_error& e) { + // Pathways and Cloud TPU 2VM runtime. 
+ always_fallback_to_python_ = true; + return py::cast(cache_miss_(*args, **kwargs))[0]; + } + default_pyclient_ = default_pydevice_.client; + default_device_ = default_pydevice_.contents; + if (!default_device_) { // UPTC + always_fallback_to_python_ = true; + return py::cast(cache_miss_(*args, **kwargs))[0]; + } + is_committed_ = + py::cast(device_and_is_committed.attr("committed_to_device")); + } + } + } + CHECK(default_device_); if (JitIsDisabled()) { return fun_(*args, **kwargs); } ParsedArgumentsAsBuffers arguments; FlattenArguments(args, kwargs, static_argnums_, arguments); - absl::optional cache_miss_result = absl::nullopt; - if (!default_device_) { - cache_miss_result = cache_miss_fun_(*args, **kwargs); - auto executable = py::cast>( - cache_miss_result.value()[0]); - - pyclient_ = executable->client(); - default_device_ = executable->LocalDevices()[0].contents; - } - - // The C++ jit do not support Tracers arguments yet. The Python-based jit - // function will be called if any of the dynamic arguments is unsupported. - if (!ConvertArgsToBuffers(jax_enable_x64_, *pyclient_, default_device_, - arguments) + // The C++ jit does not support Tracer arguments yet. The Python-based + // jit function will be called if any of the dynamic arguments is unsupported. + if (!ConvertArgsToBuffers(jax_enable_x64_.value(), *default_pyclient_, + default_device_, is_committed_, arguments) .ok()) { - return python_f_jitted_(*args, **kwargs); + return py::cast(cache_miss_(*args, **kwargs))[0]; } - CacheEntry& cache_entry = - GetCacheEntry(args, kwargs, arguments.signature, cache_miss_result); + CacheEntry* cache_entry = GetCacheEntryIfPresent(arguments.signature); + if (!cache_entry) { + py::object out_and_fastpath_data = cache_miss_(*args, **kwargs); + cache_entry = GetCacheEntryIfPresent(arguments.signature); + if (!cache_entry) { + cache_entry = AddCacheEntry(args, kwargs, arguments.signature, + out_and_fastpath_data); + } + CHECK(cache_entry); + if (cache_entry->fall_back_to_python) { + return py::cast(out_and_fastpath_data)[0]; + } + // As we have already computed the result, we can return it. + // It's even *required*, e.g. if there are donated arguments, because + // otherwise the buffer which has already been donated would be invalid. 
+ return py::cast(out_and_fastpath_data)[0]; + } + CHECK(cache_entry); + if (cache_entry->fall_back_to_python) { + return py::cast(cache_miss_(*args, **kwargs))[0]; + } std::vector> outputs = - ValueOrThrow(cache_entry.executable->PjRtExecute(arguments.arg_buffers)); - - static const auto* xla_module = - new py::module(py::module::import("jax.interpreters.xla")); - const auto& device_array = xla_module->attr("DeviceArray"); - - const std::vector& out_avals = cache_entry.out_avals; - const std::vector& out_lazy_exprs = cache_entry.out_lazy_exprs; + ValueOrThrow(cache_entry->executable->PjRtExecute(arguments.arg_buffers)); + const std::vector& handlers = cache_entry->handlers; py::list flat_device_arrays; for (int i = 0; i < outputs.size(); ++i) { - flat_device_arrays.append(device_array( - /*aval=*/out_avals[i], /*device=*/outputs[i]->device(), - /*lazy_expr=*/out_lazy_exprs[i], - /*device_buffer=*/std::move(outputs[i]))); + flat_device_arrays.append(handlers[i](std::move(outputs[i]))); } - return cache_entry.out_pytree_def.Unflatten(flat_device_arrays); + return cache_entry->out_pytree_def.Unflatten(flat_device_arrays); } } // namespace @@ -743,17 +868,28 @@ void BuildJaxjitSubmodule(pybind11::module& m) { jitlib.def("get_disable_jit", &GetDisableJit); jitlib.def( "jit", - [](py::function fun, py::function cache_miss_fun, - py::function fallback_on_unsupported_argument, bool jax_enable_x64, - bool jax_disable_jit, + [](py::function fun, py::function cache_miss, py::function get_device, + py::function get_jax_enable_x64, py::function get_jax_disable_jit, std::vector static_argnums) -> std::unique_ptr { return std::make_unique( - std::move(fun), std::move(cache_miss_fun), - std::move(fallback_on_unsupported_argument), jax_enable_x64, - jax_disable_jit, std::move(static_argnums)); + std::move(fun), std::move(cache_miss), std::move(get_device), + std::move(get_jax_enable_x64), std::move(get_jax_disable_jit), + std::move(static_argnums)); }); // Only for testing purposes + cfun.def("_cache_size", &CompiledFunction::cache_size); + jitlib.def("_DtypeTo32BitDtype", [](const py::object obj) -> py::object { + py::dtype dtype = py::dtype::from_args(obj); + const py::dtype* res = DtypeTo32BitDtype(dtype); + if (res) { + return *res; + } else { + return py::none(); + } + }); + jitlib.def("_is_float0", &IsFloat0); + jitlib.def("_is_trivial", &HasTrivialLazyExpr); jitlib.def("_ScalarToBuffer", [](py::handle scalar, bool jax_enable_x64, std::shared_ptr client) { xla::PjRtClient* pjrt_client = client->pjrt_client(); diff --git a/tensorflow/compiler/xla/python/ops.cc b/tensorflow/compiler/xla/python/ops.cc index 3ac4709b160..04e68f9a563 100644 --- a/tensorflow/compiler/xla/python/ops.cc +++ b/tensorflow/compiler/xla/python/ops.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "pybind11/attr.h" #include "pybind11/pybind11.h" #include "tensorflow/compiler/xla/client/lib/comparators.h" +#include "tensorflow/compiler/xla/client/lib/lu_decomposition.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/qr.h" #include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h" @@ -186,6 +187,13 @@ void BuildOpsSubmodule(py::module* m) { return std::make_pair(qr.q, qr.r); }, py::arg("operand"), py::arg("full_matrices")); + ops.def( + "LU", + [](XlaOp a) -> StatusOr> { + LuDecompositionResult lu = LuDecomposition(a); + return std::make_tuple(lu.lu, lu.pivots, lu.permutation); + }, + py::arg("operand")); ops.def( "Eigh", [](XlaOp a, bool lower, int64 max_iter, @@ -283,6 +291,7 @@ void BuildOpsSubmodule(py::module* m) { ops.def("RandomGammaGrad", &RandomGammaGrad, py::arg("a"), py::arg("x")); ops.def("RegularizedIncompleteBeta", &RegularizedIncompleteBeta, py::arg("a"), py::arg("b"), py::arg("x")); + ops.def("Zeta", &Zeta, py::arg("x"), py::arg("q")); #define BINARY_OP(op) \ ops.def( \ diff --git a/tensorflow/compiler/xla/python/py_client.cc b/tensorflow/compiler/xla/python/py_client.cc index 6df11322564..07b915c640c 100644 --- a/tensorflow/compiler/xla/python/py_client.cc +++ b/tensorflow/compiler/xla/python/py_client.cc @@ -30,6 +30,8 @@ namespace xla { namespace py = pybind11; namespace pprof = tensorflow::tfprof::pprof; +PyClient::PyClient(std::unique_ptr pjrt_client) + : pjrt_client_(std::move(pjrt_client)) {} PyClient::PyClient(std::shared_ptr pjrt_client) : pjrt_client_(std::move(pjrt_client)) {} diff --git a/tensorflow/compiler/xla/python/py_client.h b/tensorflow/compiler/xla/python/py_client.h index f12a4ae4f0a..08249722d6c 100644 --- a/tensorflow/compiler/xla/python/py_client.h +++ b/tensorflow/compiler/xla/python/py_client.h @@ -88,6 +88,7 @@ ClientAndPtr WrapWithClient(std::shared_ptr client, T* contents) { // We use a wrapper class to add Python-specific functionality. 
class PyClient : public std::enable_shared_from_this { public: + explicit PyClient(std::unique_ptr pjrt_client); explicit PyClient(std::shared_ptr pjrt_client); PjRtClient* pjrt_client() const { return pjrt_client_.get(); } diff --git a/tensorflow/compiler/xla/python/tpu_driver/BUILD b/tensorflow/compiler/xla/python/tpu_driver/BUILD index 4725becdedf..bda1db6a466 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/BUILD +++ b/tensorflow/compiler/xla/python/tpu_driver/BUILD @@ -1,24 +1,24 @@ -load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library_cc") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") +load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library") load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency") load( "//tensorflow/compiler/xla/python/tpu_driver:platform/external/tools.bzl", "external_deps", "go_grpc_library", - "go_proto_library", ) licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//visibility:public"]) -tf_proto_library_cc( +tf_proto_library( name = "tpu_driver_proto", srcs = ["tpu_driver.proto"], cc_api_version = 2, protodeps = [], ) -tf_proto_library_cc( +tf_proto_library( name = "tpu_service_proto", srcs = ["tpu_service.proto"], has_services = 1, @@ -77,6 +77,7 @@ cc_library( cc_library( name = "direct_tpu_driver", srcs = ["direct_tpu_driver.cc"], + compatible_with = [], deps = [ ":tpu_driver", "@com_google_absl//absl/strings:str_format", @@ -115,10 +116,22 @@ cc_library( alwayslink = 1, ) -go_proto_library( - name = "tpu_service_go_proto", - compatible_with = ["//buildenv/target:gce"], - deps = [":tpu_service_proto"], +cc_library( + name = "pod_tpu_driver", + srcs = ["pod_tpu_driver.cc"], + deps = [ + ":grpc_tpu_driver", + ":tpu_driver", + ":tpu_driver_proto_cc", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:flat_hash_set", + "//tensorflow/compiler/xla/pjrt:semaphore", + "//tensorflow/compiler/xla/pjrt:worker_thread", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + tf_grpc_cc_dependency(), + ] + external_deps(), + alwayslink = 1, ) go_grpc_library( diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD index c460cc36f08..9d98d0cf654 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD +++ b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD @@ -1,3 +1,9 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "filegroup") + +# buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "pybind_extension") package( @@ -11,6 +17,7 @@ cc_library( hdrs = [ "tpu_client.h", ], + compatible_with = [], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -24,11 +31,12 @@ cc_library( "//tensorflow/compiler/xla/python/tpu_driver", "//tensorflow/compiler/xla/python/tpu_driver:direct_tpu_driver", "//tensorflow/compiler/xla/python/tpu_driver:grpc_tpu_driver", + "//tensorflow/compiler/xla/python/tpu_driver:pod_tpu_driver", "//tensorflow/compiler/xla/python/tpu_driver:recording_tpu_driver", "//tensorflow/compiler/xla/python/tpu_driver:tpu_driver_proto_cc", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:shaped_buffer", - "//tensorflow/core:allocator", + "//tensorflow/core/framework:allocator", "//tensorflow/core/platform:env", "//tensorflow/core/profiler/lib:traceme", "@com_google_absl//absl/memory", 
diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc index e4fb2cdfd41..0602d096aaa 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc @@ -588,7 +588,7 @@ PyTpuExecutable::ExecuteResult PyTpuExecutable::ExecuteHelper( static const absl::Duration kWarnExecutionDelay = absl::Seconds(10); // Delay before terminating a stalled execute call. -static const absl::Duration kMaxExecutionDelay = absl::Seconds(120); +static const absl::Duration kMaxExecutionDelay = absl::Minutes(60); Status WaitForExecuteEvent(tpu_driver::Event* event) { absl::optional opt_status; diff --git a/tensorflow/compiler/xla/python/tpu_driver/platform/external/tools.bzl b/tensorflow/compiler/xla/python/tpu_driver/platform/external/tools.bzl index 99b07b6c787..12a4390d317 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/platform/external/tools.bzl +++ b/tensorflow/compiler/xla/python/tpu_driver/platform/external/tools.bzl @@ -16,10 +16,6 @@ Build dependencies and utilities for the TPU driver interface. """ -def go_proto_library(**kwargs): - # A dummy macro placeholder for compatibility reason. - pass - def go_grpc_library(**kwargs): # A dummy macro placeholder for compatibility reason. pass diff --git a/tensorflow/compiler/xla/python/tpu_driver/pod_tpu_driver.cc b/tensorflow/compiler/xla/python/tpu_driver/pod_tpu_driver.cc new file mode 100644 index 00000000000..a5a6cbabb82 --- /dev/null +++ b/tensorflow/compiler/xla/python/tpu_driver/pod_tpu_driver.cc @@ -0,0 +1,977 @@ +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= + +#include "absl/container/btree_map.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_split.h" +#include "absl/synchronization/mutex.h" +#include "tensorflow/compiler/xla/pjrt/semaphore.h" +#include "tensorflow/compiler/xla/pjrt/worker_thread.h" +#include "tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.h" +#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h" +#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/protobuf/error_codes.pb.h" + +namespace tpu_driver { +namespace { + +#define CHECK_EXISTS_OR_RETURN(container, target_op_id, operation_id) \ + { \ + auto p = CheckHandleExists(container, target_op_id, operation_id); \ + if (p != nullptr) return p; \ + } + +using xla::Status; +using xla::WorkerThread; + +const char kPodTpuDriverPrefix[] = "grpc+pod://"; + +class PodTpuDriver; + +class PodEvent : public Event { + public: + explicit PodEvent(PodTpuDriver* driver, int64_t operation_id) + : driver_(driver), operation_id_(operation_id) {} + int64_t operation_id() const { return operation_id_; } + + xla::Status Await() override; + + absl::optional AwaitWithTimeout( + absl::Duration duration) override; + + void AddCallback(std::function callback) override; + + private: + PodTpuDriver* driver_; + const int64_t operation_id_; +}; + +class ErrorEvent : public PodEvent { + public: + explicit ErrorEvent(PodTpuDriver* driver, int64_t operation_id, Status status) + : PodEvent(driver, operation_id) { + status_ = status; + } + + xla::Status Await() override { return status_; } + absl::optional AwaitWithTimeout( + absl::Duration duration) override { + return status_; + } + void AddCallback(std::function callback) override { + callback(status_); + } + + private: + Status status_; +}; + +class CombinedEvent : public PodEvent { + public: + explicit CombinedEvent(PodTpuDriver* driver, int64_t operation_id, + std::vector> events) + : PodEvent(driver, operation_id), events_(events) { + for (auto& event : events_) { + event->AddCallback([this](Status s) { IncrementAndCheckComplete(s); }); + } + } + + xla::Status Await() override { + for (auto& event : events_) { + TF_RETURN_IF_ERROR(event->Await()); + } + return Status::OK(); + } + + absl::optional AwaitWithTimeout( + absl::Duration duration) override { + for (auto& event : events_) { + auto start_time = absl::Now(); + auto status = event->AwaitWithTimeout(duration); + duration -= absl::Now() - start_time; + if (status == absl::nullopt) { + return absl::nullopt; + } else { + TF_RETURN_IF_ERROR(status.value()); + } + } + return Status::OK(); + } + + void AddCallback(std::function callback) + TF_LOCKS_EXCLUDED(mu_) override { + bool all_events_completed = false; + { + absl::MutexLock l(&mu_); + all_events_completed = events_completed_ == events_.size(); + } + if (all_events_completed) { + callback(event_status_); + } else { + absl::MutexLock l(&mu_); + callbacks_.push_back(std::move(callback)); + } + } + + private: + void IncrementAndCheckComplete(Status s) TF_LOCKS_EXCLUDED(mu_) { + std::vector> callbacks; + { + absl::MutexLock l(&mu_); + + event_status_ = s; + events_completed_++; + if (events_completed_ == events_.size()) { + // Copy callbacks to a temporary to be invoked outside the mutex. 
+ callbacks.assign(callbacks_.begin(), callbacks_.end()); + callbacks_.clear(); + } else { + return; + } + } + + for (const auto& callback : callbacks) { + callback(event_status_); + } + } + + absl::Mutex mu_; + std::vector> events_; + std::vector> callbacks_ ABSL_GUARDED_BY(mu_); + int64_t events_completed_ ABSL_GUARDED_BY(mu_) = 0; + Status event_status_; +}; + +class PodBufferHandle : public BufferHandle { + public: + explicit PodBufferHandle(PodTpuDriver* driver, int64_t operation_id, + int64_t size_in_bytes, + absl::optional shape, + int64_t core_id) + : driver_(driver), + operation_id_(operation_id), + size_in_bytes_(size_in_bytes), + shape_(shape), + event_(std::make_shared(driver_, operation_id_)), + core_id_(core_id) {} + + std::shared_ptr OnReady() override { return event_; } + int64_t size_in_bytes() override { return size_in_bytes_; } + absl::optional shape() override { return shape_; } + + int64_t operation_id() const { return operation_id_; } + int64_t core_id() const { return core_id_; } + + private: + PodTpuDriver* driver_; + const int64_t operation_id_; + const int64_t size_in_bytes_; + const absl::optional shape_; + std::shared_ptr event_; + const int64_t core_id_; +}; + +class PodCompiledProgramHandle : public CompiledProgramHandle { + public: + explicit PodCompiledProgramHandle(PodTpuDriver* driver, int64_t operation_id) + : driver_(driver), + operation_id_(operation_id), + event_(std::make_shared(driver_, operation_id_)) {} + + std::shared_ptr OnReady() override { return event_; } + + xla::Status program_shape(xla::ProgramShapeProto* program_shape) override; + + int64_t operation_id() const { return operation_id_; } + + private: + PodTpuDriver* driver_; + const int64_t operation_id_; + std::shared_ptr event_; +}; + +class PodLoadedProgramHandle : public LoadedProgramHandle { + public: + explicit PodLoadedProgramHandle(PodTpuDriver* driver, int64_t operation_id, + int64_t core_id) + : driver_(driver), + operation_id_(operation_id), + core_id_(core_id), + event_(std::make_shared(driver_, operation_id_)) {} + + std::shared_ptr OnReady() override { return event_; } + + int64_t operation_id() const { return operation_id_; } + int64_t core_id() const { return core_id_; } + + private: + PodTpuDriver* driver_; + const int64_t operation_id_; + const int64_t core_id_; + std::shared_ptr event_; +}; + +struct EventInFlight { + EventInFlight() + : underlying_event(nullptr), + create_fn(nullptr), + incomplete_deps(), + callbacks() {} + + std::shared_ptr underlying_event; + std::function(void)> create_fn; + + absl::flat_hash_set incomplete_deps; + std::vector> callbacks; +}; + +class PodTpuDriver : public TpuDriver { + public: + explicit PodTpuDriver(const TpuDriverConfig& config, + std::shared_ptr<::grpc::ChannelCredentials> creds) + : config_(config), + creds_(creds), + event_thread_(tensorflow::Env::Default(), "grpc_pod_event_thread") { + std::vector workers = absl::StrSplit( + absl::StripPrefix(config.worker(), kPodTpuDriverPrefix), ','); + + int worker_count = 0; + + // Flag for environments where local core # == all cores in TPU system #, + // which means that we are connecting to separate TPU systems or we are in + // a test environment. 
+ bool in_local_core_environment = false; + + for (const auto& worker : workers) { + TpuDriverConfig worker_config(config_); + *(worker_config.mutable_worker()) = absl::StrCat("grpc://", worker); + auto tpu_driver = + CreateGrpcTpuDriver(worker_config, creds_).ConsumeValueOrDie(); + + SystemInfo driver_info; + tpu_driver->QuerySystemInfo(&driver_info); + + if (driver_info.core_count() == driver_info.local_core_size()) { + drivers_.insert({worker_count, std::move(tpu_driver)}); + in_local_core_environment = true; + } else { + drivers_.insert({driver_info.host_id(), std::move(tpu_driver)}); + } + + worker_count++; + } + + absl::flat_hash_set> processed_chips; + + for (int driver_num = 0; driver_num < workers.size(); ++driver_num) { + SystemInfo driver_info; + drivers_[driver_num]->QuerySystemInfo(&driver_info); + + for (const auto& tpu_chip : driver_info.tpu_chip()) { + std::tuple coord{tpu_chip.chip_coord().x(), + tpu_chip.chip_coord().y(), + tpu_chip.chip_coord().z()}; + // We only want to add chips that we have not seen before if we are in a + // TPU pod slice, or we are only seeing local cores (e.g. we are + // connected to individual TPUs or we are in a test environment). + if (!processed_chips.contains(coord) || + driver_info.core_count() == driver_info.local_core_size()) { + *(pod_info_.add_tpu_chip()) = tpu_chip; + processed_chips.insert(coord); + } + } + + *(pod_info_.mutable_cpu()) = driver_info.cpu(); + } + + // Process all the unique chips that we have seen. + int core_count = 0; + for (auto& tpu_chip : *pod_info_.mutable_tpu_chip()) { + for (auto& tpu_core : *tpu_chip.mutable_core()) { + int current_core = tpu_core.id(); + if (in_local_core_environment) { + current_core = core_count; + } + + core_to_driver_.insert( + {current_core, drivers_[tpu_chip.host_id()].get()}); + core_to_driver_id_.insert({current_core, tpu_chip.host_id()}); + core_to_driver_core_.insert({current_core, tpu_core.id()}); + + tpu_core.set_id(current_core); + tpu_core.set_core_on_host_index(current_core); + *(pod_info_.add_local_core()) = tpu_core; + + core_count++; + } + + // We are setting host_id to zero because we want this to look like one + // host with many cores from the perspective of tpu_client.cc. + tpu_chip.set_host_id(0); + } + + pod_info_.set_chip_count(pod_info_.tpu_chip_size()); + pod_info_.set_core_count(pod_info_.local_core_size()); + + // We want this to look like one host with many TPU chips/cores connected. + pod_info_.set_host_count(1); + pod_info_.set_host_id(0); + } + + ~PodTpuDriver() override { + // TODO(frankchn): Unload all handles, and wait for all events to finish. 
+ } + + void QuerySystemInfo(SystemInfo* system_info) override { + *system_info = pod_info_; + } + + xla::Status Reset() override { + for (auto& driver : drivers_) { + TF_RETURN_IF_ERROR(driver.second->Reset()); + } + return xla::Status::OK(); + } + + std::unique_ptr Allocate( + int32_t core_id, MemoryRegion region, int64_t num_bytes, + absl::Span wait_for) override { + int64_t operation_id = GetOperationId(); + auto deps = GetDependencyOperationIds(wait_for); + + ScheduleRequest( + operation_id, + [this, core_id, region, num_bytes, + operation_id]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + underlying_buffers_.insert( + {operation_id, + core_to_driver_[core_id]->Allocate(core_to_driver_core_[core_id], + region, num_bytes, {})}); + return underlying_buffers_[operation_id]->OnReady(); + }, + deps); + + return absl::make_unique(this, operation_id, num_bytes, + absl::nullopt, core_id); + } + + std::unique_ptr Allocate( + int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape, + absl::Span wait_for) override { + int64_t operation_id = GetOperationId(); + auto deps = GetDependencyOperationIds(wait_for); + + ScheduleRequest( + operation_id, + [this, core_id, region, shape, + operation_id]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + underlying_buffers_.insert( + {operation_id, + core_to_driver_[core_id]->Allocate(core_to_driver_core_[core_id], + region, shape, {})}); + return underlying_buffers_[operation_id]->OnReady(); + }, + deps); + + return absl::make_unique( + this, operation_id, ComputeBytesFromShape(shape), shape, core_id); + } + + std::unique_ptr AllocateTuple( + int32_t core_id, MemoryRegion region, + absl::Span children, + absl::Span wait_for) override { + int64_t operation_id = GetOperationId(); + auto deps = GetDependencyOperationIds(wait_for); + + std::vector children_ids; + for (int i = 0; i < children.size(); ++i) { + auto child_op_id = + static_cast(children[i])->operation_id(); + deps.insert(child_op_id); + children_ids.push_back(child_op_id); + } + + ScheduleRequest( + operation_id, + [this, core_id, region, children_ids, + operation_id]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) + -> std::shared_ptr { + std::vector child_buffers; + child_buffers.reserve(children_ids.size()); + for (int i = 0; i < children_ids.size(); ++i) { + CHECK_EXISTS_OR_RETURN(underlying_buffers_, children_ids[i], + operation_id); + child_buffers.push_back(underlying_buffers_[children_ids[i]].get()); + } + + underlying_buffers_.insert( + {operation_id, + core_to_driver_[core_id]->AllocateTuple( + core_to_driver_core_[core_id], region, child_buffers, {})}); + return underlying_buffers_[operation_id]->OnReady(); + }, + deps); + + return absl::make_unique(this, operation_id, 0, + absl::nullopt, core_id); + } + + std::shared_ptr Deallocate( + std::unique_ptr handle, + absl::Span wait_for) override { + int64_t operation_id = GetOperationId(); + auto deps = GetDependencyOperationIds(wait_for); + deps.insert(static_cast(handle.get())->operation_id()); + + auto op_id = static_cast(handle.get())->operation_id(); + auto core_id = static_cast(handle.get())->core_id(); + + ScheduleRequest( + operation_id, + [this, operation_id, op_id, + core_id]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr { + CHECK_EXISTS_OR_RETURN(underlying_buffers_, op_id, operation_id); + + auto buf_iter = underlying_buffers_.find(op_id); + auto underlying_hn = std::move(buf_iter->second); + underlying_buffers_.erase(buf_iter); + + return core_to_driver_[core_id]->Deallocate(std::move(underlying_hn), + {}); + }, + deps); + + return 
std::make_shared(this, operation_id); + } + + std::shared_ptr TransferToDevice( + const void* src, BufferHandle* dst, + absl::Span wait_for) override { + int64_t operation_id = GetOperationId(); + auto deps = GetDependencyOperationIds(wait_for); + deps.insert(static_cast(dst)->operation_id()); + + auto op_id = static_cast(dst)->operation_id(); + auto core_id = static_cast(dst)->core_id(); + + ScheduleRequest( + operation_id, + [this, src, operation_id, op_id, + core_id]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr { + CHECK_EXISTS_OR_RETURN(underlying_buffers_, op_id, operation_id); + + auto buf_iter = underlying_buffers_.find(op_id); + return core_to_driver_[core_id]->TransferToDevice( + src, buf_iter->second.get(), {}); + }, + deps); + + return std::make_shared(this, operation_id); + } + + std::shared_ptr TransferFromDevice( + const BufferHandle* src, void* dst, + absl::Span wait_for) override { + int64_t operation_id = GetOperationId(); + auto deps = GetDependencyOperationIds(wait_for); + deps.insert(static_cast(src)->operation_id()); + + auto op_id = static_cast(src)->operation_id(); + auto core_id = static_cast(src)->core_id(); + + ScheduleRequest( + operation_id, + [this, dst, operation_id, op_id, + core_id]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr { + CHECK_EXISTS_OR_RETURN(underlying_buffers_, op_id, operation_id); + auto buf_iter = underlying_buffers_.find(op_id); + return core_to_driver_[core_id]->TransferFromDevice( + buf_iter->second.get(), dst, {}); + }, + deps); + + return std::make_shared(this, operation_id); + } + + std::shared_ptr TransferFromDeviceToDevice( + const BufferHandle* src, BufferHandle* dst, + absl::Span wait_for) override { + auto src_core_id = static_cast(src)->core_id(); + auto dst_core_id = static_cast(dst)->core_id(); + + auto src_driver_id = core_to_driver_id_[src_core_id]; + auto dst_driver_id = core_to_driver_id_[dst_core_id]; + + if (src_driver_id == dst_driver_id) { + // They are in the same host, we can schedule it normally + int64_t operation_id = GetOperationId(); + auto deps = GetDependencyOperationIds(wait_for); + deps.insert(static_cast(src)->operation_id()); + deps.insert(static_cast(dst)->operation_id()); + + auto src_op_id = static_cast(src)->operation_id(); + auto dst_op_id = static_cast(dst)->operation_id(); + + ScheduleRequest( + operation_id, + [this, operation_id, src_op_id, dst_op_id, dst_core_id]() + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr { + CHECK_EXISTS_OR_RETURN(underlying_buffers_, src_op_id, + operation_id); + CHECK_EXISTS_OR_RETURN(underlying_buffers_, dst_op_id, + operation_id); + + auto src_iter = underlying_buffers_.find(src_op_id); + auto dst_iter = underlying_buffers_.find(dst_op_id); + return core_to_driver_[dst_core_id]->TransferFromDeviceToDevice( + src_iter->second.get(), dst_iter->second.get(), {}); + }, + deps); + return std::make_shared(this, operation_id); + } else { + // src and dst are on different hosts, we have to bounce through us. 
+ auto dst_size = dst->size_in_bytes(); + char* host_buf = new char[dst_size]; + + auto src_event = TransferFromDevice(src, host_buf, wait_for); + auto dst_event = TransferToDevice(host_buf, dst, {src_event.get()}); + dst_event->AddCallback( + [src_event, host_buf](xla::Status status) { delete[] host_buf; }); + return dst_event; + } + } + + std::unique_ptr CompileProgram( + const xla::HloProto& source, int32_t num_replicas, + absl::Span wait_for) override { + int64_t operation_id = GetOperationId(); + auto deps = GetDependencyOperationIds(wait_for); + + ScheduleRequest( + operation_id, + [this, operation_id, source, + num_replicas]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + auto cph_iterator = + underlying_cph_ + .insert( + {operation_id, + std::vector>()}) + .first; + + std::vector> collected_events; + for (int i = 0; i < drivers_.size(); ++i) { + auto current_cph = + drivers_[i]->CompileProgram(source, num_replicas, {}); + cph_iterator->second.push_back(std::move(current_cph)); + collected_events.push_back(cph_iterator->second[i]->OnReady()); + } + return std::make_shared(this, operation_id, + collected_events); + }, + deps); + + return absl::make_unique(this, operation_id); + } + + std::unique_ptr LoadProgram( + int32_t core_id, const CompiledProgramHandle* handle, + absl::Span wait_for) override { + int64_t operation_id = GetOperationId(); + auto deps = GetDependencyOperationIds(wait_for); + deps.insert( + static_cast(handle)->operation_id()); + auto cph_op_id = + static_cast(handle)->operation_id(); + + ScheduleRequest( + operation_id, + [this, operation_id, cph_op_id, + core_id]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr { + CHECK_EXISTS_OR_RETURN(underlying_cph_, cph_op_id, operation_id); + auto cph_iter = underlying_cph_.find(cph_op_id); + + underlying_lph_.insert( + {operation_id, + core_to_driver_[core_id]->LoadProgram( + core_to_driver_core_[core_id], + cph_iter->second[core_to_driver_id_[core_id]].get(), {})}); + + return underlying_lph_[operation_id]->OnReady(); + }, + deps); + + return absl::make_unique(this, operation_id, + core_id); + } + + std::shared_ptr UnloadProgram( + std::unique_ptr handle, + absl::Span wait_for) override { + int64_t operation_id = GetOperationId(); + auto deps = GetDependencyOperationIds(wait_for); + deps.insert( + static_cast(handle.get())->operation_id()); + auto op_id = + static_cast(handle.get())->operation_id(); + auto core_id = + static_cast(handle.get())->core_id(); + + ScheduleRequest( + operation_id, + [this, operation_id, op_id, + core_id]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr { + CHECK_EXISTS_OR_RETURN(underlying_lph_, op_id, operation_id); + auto lph_iter = underlying_lph_.find(op_id); + auto event = core_to_driver_[core_id]->UnloadProgram( + std::move(lph_iter->second), {}); + underlying_lph_.erase(lph_iter); + + return event; + }, + deps); + + return std::make_shared(this, operation_id); + } + + std::shared_ptr ExecuteProgram( + LoadedProgramHandle* program, absl::Span inputs, + absl::Span outputs, + const xla::DeviceAssignmentProto& device_assignment, + absl::Span wait_for) override { + int64_t operation_id = GetOperationId(); + + auto deps = GetDependencyOperationIds(wait_for); + deps.insert(static_cast(program)->operation_id()); + + auto op_id = static_cast(program)->operation_id(); + auto core_id = static_cast(program)->core_id(); + + std::vector input_op_ids; + std::vector output_op_ids; + + for (auto* input : inputs) { + auto input_dep = + static_cast(input)->operation_id(); + 
input_op_ids.push_back(input_dep); + deps.insert(input_dep); + } + for (auto* output : outputs) { + auto output_dep = + static_cast(output)->operation_id(); + output_op_ids.push_back(output_dep); + deps.insert(output_dep); + } + + ScheduleRequest( + operation_id, + [this, operation_id, core_id, op_id, input_op_ids, output_op_ids, + device_assignment]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) + -> std::shared_ptr { + std::vector underlying_inputs; + std::vector underlying_outputs; + + underlying_inputs.reserve(input_op_ids.size()); + for (auto input_op_id : input_op_ids) { + CHECK_EXISTS_OR_RETURN(underlying_buffers_, input_op_id, + operation_id); + underlying_inputs.push_back(underlying_buffers_[input_op_id].get()); + } + underlying_outputs.reserve(output_op_ids.size()); + for (auto output_op_id : output_op_ids) { + CHECK_EXISTS_OR_RETURN(underlying_buffers_, output_op_id, + operation_id); + underlying_outputs.push_back( + underlying_buffers_[output_op_id].get()); + } + + CHECK_EXISTS_OR_RETURN(underlying_lph_, op_id, operation_id); + LoadedProgramHandle* handle = underlying_lph_[op_id].get(); + return core_to_driver_[core_id]->ExecuteProgram( + handle, underlying_inputs, underlying_outputs, device_assignment, + {}); + }, + deps); + + return std::make_shared(this, operation_id); + } + + std::unique_ptr GetLinearizer() override { + return drivers_[0]->GetLinearizer(); + } + + // Helper methods for Event scheduling + + absl::optional WaitForEvent(int64_t event_id, absl::Duration duration) + TF_LOCKS_EXCLUDED(mu_) { + std::shared_ptr underlying_event; + + { + absl::MutexLock l(&mu_); + auto event = events_.find(event_id); + + if (event == events_.end()) { + auto event_status = abnormal_event_status_.find(event_id); + if (event_status == abnormal_event_status_.end()) { + return Status::OK(); + } else { + return event_status->second; + } + } + + auto done = [this, event_id]() { + mu_.AssertHeld(); + // The event was either completed and erased from the map or we have + // an underlying event available to us. + return events_.count(event_id) == 0 || + (events_[event_id]->underlying_event != nullptr && + events_[event_id]->underlying_event.use_count() != 0); + }; + + auto status = mu_.AwaitWithTimeout(absl::Condition(&done), duration); + if (!status) { + return absl::nullopt; + } + + if (events_.count(event_id) > 0) { + underlying_event = events_[event_id]->underlying_event; + } else { + underlying_event = nullptr; + } + } + + // Wait for the underlying event without holding on to the event_lock_, or + // else incoming events will not be processed. 
+ if (underlying_event != nullptr) { + return underlying_event->AwaitWithTimeout(duration); + } else { + absl::MutexLock l(&mu_); + auto event_status = abnormal_event_status_.find(event_id); + if (event_status == abnormal_event_status_.end()) { + return Status::OK(); + } else { + return event_status->second; + } + } + } + + void AddCallbackForEvent(int64_t event_id, std::function fn) + TF_LOCKS_EXCLUDED(mu_) { + absl::MutexLock l(&mu_); + auto event = events_.find(event_id); + + if (event == events_.end()) { + auto event_status = abnormal_event_status_.find(event_id); + if (event_status == abnormal_event_status_.end()) { + fn(Status::OK()); + } else { + fn(event_status->second); + } + } else { + if (event->second->underlying_event != nullptr && + event->second->underlying_event.use_count() != 0) { + event->second->underlying_event->AddCallback(fn); + } else { + event->second->callbacks.push_back(std::move(fn)); + } + } + } + + xla::Status GetCompiledProgramShape(int64_t op_id, + xla::ProgramShapeProto* program_shape) + TF_LOCKS_EXCLUDED(mu_) { + absl::MutexLock l(&mu_); + + auto done = [this, op_id]() { + mu_.AssertHeld(); + return underlying_cph_.contains(op_id); + }; + mu_.Await(absl::Condition(&done)); + + return underlying_cph_[op_id][0]->program_shape(program_shape); + } + + private: + const TpuDriverConfig& config_; + std::shared_ptr<::grpc::ChannelCredentials> creds_; + + absl::flat_hash_map> drivers_; + absl::flat_hash_map core_to_driver_id_; + absl::flat_hash_map core_to_driver_; + absl::flat_hash_map core_to_driver_core_; + SystemInfo pod_info_; + + absl::Mutex mu_; + + absl::flat_hash_map> + underlying_buffers_ ABSL_GUARDED_BY(mu_); + absl::flat_hash_map>> + underlying_cph_ ABSL_GUARDED_BY(mu_); + absl::flat_hash_map> + underlying_lph_ ABSL_GUARDED_BY(mu_); + + absl::btree_map> events_ + ABSL_GUARDED_BY(mu_); + absl::flat_hash_map abnormal_event_status_ + ABSL_GUARDED_BY(mu_); + + std::atomic operation_id_counter_{0}; + + WorkerThread event_thread_; + + int64_t GetOperationId() { return operation_id_counter_++; } + + absl::flat_hash_set GetDependencyOperationIds( + absl::Span wait_for) { + absl::flat_hash_set deps; + for (auto* event : wait_for) { + deps.insert(static_cast(event)->operation_id()); + } + return deps; + } + + // EventCompleted is executed on the event_thread_ worker thread. We want + // to propagate the fact that the event is completed to any subsequent events + // that might depend on this event. + void EventCompleted(int64_t event_id, Status status) TF_LOCKS_EXCLUDED(mu_) { + absl::MutexLock l(&mu_); + + absl::btree_map>::iterator + curr_event; + if (!status.ok()) abnormal_event_status_.insert({event_id, status}); + curr_event = events_.find(event_id); + + DCHECK(curr_event->second->callbacks.empty()); + DCHECK(curr_event->second->incomplete_deps.empty()); + + for (auto& event : events_) { + event.second->incomplete_deps.erase(event_id); + // The if statement conditions on both + // - all previous events have completed (incomplete_deps.empty()) + // - the op creating this event has not been called yet + // (event.second.create_fn != nullptr) + // We call the create_fn that creates the event and adds any relevant + // callbacks to the actual event, before setting create_fn to nullptr + // to indicate that it has already been called + if (event.second->incomplete_deps.empty() && + event.second->create_fn != nullptr) { + // We were the last unfilled dependency, all other dependencies are + // filled. We can now fire the create function. 
+ event.second->underlying_event = event.second->create_fn(); + for (auto& fn : event.second->callbacks) { + event.second->underlying_event->AddCallback(std::move(fn)); + } + event.second->callbacks.clear(); + event.second->create_fn = nullptr; + } + } + + // We erase the current event to signal that it has finished. + events_.erase(curr_event); + } + + void ScheduleRequest(int64_t operation_id, + std::function(void)> fn, + const absl::flat_hash_set& deps) + TF_LOCKS_EXCLUDED(mu_) { + absl::MutexLock l(&mu_); + absl::btree_map>::iterator event; + absl::flat_hash_set incomplete_deps; + + event = events_.insert({operation_id, absl::make_unique()}) + .first; + for (const auto& dep : deps) { + if (events_.count(dep) > 0) incomplete_deps.insert(dep); + } + + if (incomplete_deps.empty()) { + // All dependencies have been fulfilled, we execute the request + // immediately and add a callback to inform our event fulfilled thread + // when it is done. + event->second->create_fn = nullptr; + event->second->underlying_event = fn(); + event->second->underlying_event->AddCallback( + [this, operation_id](Status status) { + event_thread_.Schedule([this, operation_id, status]() { + EventCompleted(operation_id, status); + }); + }); + } else { + // There are some dependencies that are not yet fulfilled. We attach + // the request to the event, and will execute it in the EventFulfilled + // worker thread when all its dependencies are fulfilled. + event->second->create_fn = std::move(fn); + event->second->incomplete_deps = std::move(incomplete_deps); + event->second->callbacks.push_back([this, operation_id](Status status) { + event_thread_.Schedule([this, operation_id, status]() { + EventCompleted(operation_id, status); + }); + }); + } + } + + template + std::shared_ptr CheckHandleExists( + absl::flat_hash_map& container, int64_t target_op_id, + int64_t operation_id) { + if (container.count(target_op_id) == 0) { + return std::make_shared( + this, operation_id, + tensorflow::errors::InvalidArgument("Handle ", target_op_id, + " does not exist.")); + } + return nullptr; + } +}; + +xla::Status PodEvent::Await() { + return driver_->WaitForEvent(operation_id_, absl::InfiniteDuration()).value(); +} + +absl::optional PodEvent::AwaitWithTimeout( + absl::Duration duration) { + return driver_->WaitForEvent(operation_id_, duration); +} + +void PodEvent::AddCallback(std::function callback) { + driver_->AddCallbackForEvent(operation_id_, std::move(callback)); +} + +xla::StatusOr> CreatePodTpuDriver( + const TpuDriverConfig& config, + std::shared_ptr<::grpc::ChannelCredentials> creds) { + return std::unique_ptr(new PodTpuDriver(config, creds)); +} + +xla::Status PodCompiledProgramHandle::program_shape( + xla::ProgramShapeProto* program_shape) { + return driver_->GetCompiledProgramShape(operation_id(), program_shape); +} + +} // namespace + +REGISTER_TPU_DRIVER(kPodTpuDriverPrefix, + [](const TpuDriverConfig& config) + -> xla::StatusOr> { + return CreatePodTpuDriver( + config, + ::grpc::InsecureChannelCredentials()); // NOLINT + }); + +} // namespace tpu_driver diff --git a/tensorflow/compiler/xla/python/tpu_driver/recording_tpu_driver.cc b/tensorflow/compiler/xla/python/tpu_driver/recording_tpu_driver.cc index da51380c104..49a19cf9e7a 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/recording_tpu_driver.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/recording_tpu_driver.cc @@ -127,8 +127,11 @@ class RecordingLoadedProgramHandle : public LoadedProgramHandle { class RecordingTpuDriver : public TpuDriver { public: 
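
Aside on the REGISTER_TPU_DRIVER(kPodTpuDriverPrefix, ...) registration above: drivers are keyed by a URL-style worker prefix and constructed through a factory lookup. A rough sketch of such a prefix-keyed registry, with invented names (MiniDriver, DriverRegistry) standing in for the real tpu_driver machinery:

    #include <functional>
    #include <map>
    #include <memory>
    #include <string>

    struct MiniDriver {  // Stand-in for TpuDriver.
      virtual ~MiniDriver() = default;
    };

    using DriverFactory =
        std::function<std::unique_ptr<MiniDriver>(const std::string& worker)>;

    // Maps URL-style prefixes ("grpc://", "record://", ...) to factories.
    class DriverRegistry {
     public:
      static DriverRegistry& Get() {
        static DriverRegistry registry;
        return registry;
      }

      void Register(const std::string& prefix, DriverFactory factory) {
        factories_[prefix] = std::move(factory);
      }

      // Picks the factory whose prefix matches the start of `worker`.
      std::unique_ptr<MiniDriver> Create(const std::string& worker) const {
        for (const auto& [prefix, factory] : factories_) {
          if (worker.rfind(prefix, 0) == 0) return factory(worker);
        }
        return nullptr;  // No registered driver understands this address.
      }

     private:
      std::map<std::string, DriverFactory> factories_;
    };

    // Usage sketch:
    //   DriverRegistry::Get().Register("pod://", [](const std::string& w) {
    //     return std::unique_ptr<MiniDriver>(new MiniDriver());
    //   });
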
explicit RecordingTpuDriver(std::unique_ptr driver, - const std::string recording_path) - : driver_(std::move(driver)), recording_path_(recording_path) { + const std::string recording_path, + const bool flush) + : driver_(std::move(driver)), + recording_path_(recording_path), + flush_(flush) { auto file_status = tensorflow::Env::Default()->NewAppendableFile( recording_path_, &log_file_); if (!file_status.ok()) { @@ -466,6 +469,7 @@ class RecordingTpuDriver : public TpuDriver { private: std::unique_ptr driver_; const std::string recording_path_; + const bool flush_; std::unique_ptr log_file_; @@ -499,6 +503,22 @@ class RecordingTpuDriver : public TpuDriver { "corrupt. Error: " << data_status.ToString(); } + + if (flush_) { + auto flush_status = log_file_->Flush(); + if (!flush_status.ok()) { + LOG(WARNING) << "Unable to flush data to log file. File possibly " + "corrupt. Error: " + << flush_status.ToString(); + } + + auto sync_status = log_file_->Sync(); + if (!sync_status.ok()) { + LOG(WARNING) << "Unable to sync log file. File possibly " + "corrupt. Error: " + << sync_status.ToString(); + } + } } } @@ -521,6 +541,7 @@ xla::StatusOr> RegisterRecordingTpuDriver( std::string file; std::string worker; + bool flush = false; for (const auto& config : configs) { std::vector kv = @@ -531,6 +552,11 @@ xla::StatusOr> RegisterRecordingTpuDriver( if (kv[0] == "worker") { worker = kv[1]; } + if (kv[0] == "flush") { + if (kv[1] == "true" || kv[1] == "1") { + flush = true; + } + } } TpuDriverConfig worker_config; @@ -541,7 +567,7 @@ xla::StatusOr> RegisterRecordingTpuDriver( auto driver = driver_status.ConsumeValueOrDie(); return std::unique_ptr( - new RecordingTpuDriver(std::move(driver), file)); + new RecordingTpuDriver(std::move(driver), file, flush)); } // To record a sequence of operations, set the worker configuration string to diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index d5977f4f0cf..f5c1c2d5fa8 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -42,6 +42,7 @@ limitations under the License. #include "tensorflow/compiler/xla/pjrt/interpreter_device.h" #include "tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h" #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/compiler/xla/pjrt/tpu_client.h" #include "tensorflow/compiler/xla/python/bfloat16.h" #include "tensorflow/compiler/xla/python/dlpack.h" #include "tensorflow/compiler/xla/python/jax_jit.h" @@ -75,7 +76,6 @@ namespace { namespace py = pybind11; - struct Uniquer { absl::Mutex mu; NameUniquer name_uniquer TF_GUARDED_BY(mu); @@ -171,13 +171,13 @@ class TraceMeWrapper : public tensorflow::profiler::TraceMeWrapper { void BuildProfilerSubmodule(py::module* m) { py::module profiler = m->def_submodule("profiler", "TensorFlow profiler integration"); - py::class_> + py::class_> profiler_server_class(profiler, "ProfilerServer"); profiler.def( "start_server", - [](int port) -> std::unique_ptr { - auto server = absl::make_unique(); + [](int port) -> std::unique_ptr { + auto server = absl::make_unique(); server->StartProfilerServer(port); return server; }, @@ -206,6 +206,23 @@ bool IsOptimizedBuild() { #endif // NDEBUG } +// Safe version of ShapeUtil::MakeShapeWithLayout that fails gracefully on +// invalid input. 
+StatusOr MakeShapeWithLayout( + PrimitiveType element_type, absl::Span dims, + absl::optional> minor_to_major) { + TF_ASSIGN_OR_RETURN(Shape shape, + ShapeUtil::MakeValidatedShape(element_type, dims)); + if (minor_to_major) { + *shape.mutable_layout() = LayoutUtil::MakeLayout(*minor_to_major); + TF_RETURN_IF_ERROR( + LayoutUtil::ValidateLayoutForShape(shape.layout(), shape)); + } else { + shape.clear_layout(); + } + return shape; +} + } // namespace PYBIND11_MODULE(xla_extension, m) { @@ -262,15 +279,13 @@ PYBIND11_MODULE(xla_extension, m) { .def_static( "array_shape", [](PrimitiveType type, py::object dims_seq, - absl::optional layout_seq) -> Shape { + absl::optional layout_seq) -> StatusOr { std::vector dims = IntSequenceToVector(dims_seq); if (layout_seq) { std::vector layout = IntSequenceToVector(*layout_seq); - return ShapeUtil::MakeShapeWithLayout(type, dims, layout); + return MakeShapeWithLayout(type, dims, layout); } else { - Shape shape = ShapeUtil::MakeShape(type, dims); - shape.clear_layout(); - return shape; + return MakeShapeWithLayout(type, dims, absl::nullopt); } }, "Constructs an array shape.", py::arg("type"), py::arg("dims"), @@ -278,16 +293,14 @@ PYBIND11_MODULE(xla_extension, m) { .def_static( "array_shape", [](py::dtype dtype, py::object dims_seq, - absl::optional layout_seq) -> Shape { + absl::optional layout_seq) -> StatusOr { PrimitiveType type = ValueOrThrow(DtypeToPrimitiveType(dtype)); std::vector dims = IntSequenceToVector(dims_seq); if (layout_seq) { std::vector layout = IntSequenceToVector(*layout_seq); - return ShapeUtil::MakeShapeWithLayout(type, dims, layout); + return MakeShapeWithLayout(type, dims, layout); } else { - Shape shape = ShapeUtil::MakeShape(type, dims); - shape.clear_layout(); - return shape; + return MakeShapeWithLayout(type, dims, absl::nullopt); } }, "Constructs an array shape.", py::arg("type"), py::arg("dims"), @@ -430,8 +443,13 @@ PYBIND11_MODULE(xla_extension, m) { }) .def_property( "device_assignment", - [](const CompileOptions& options) { - return options.executable_build_options.device_assignment(); + [](const CompileOptions& options) + -> absl::optional { + return options.executable_build_options.has_device_assignment() + ? 
absl::optional( + options.executable_build_options + .device_assignment()) + : absl::nullopt; }, [](CompileOptions& options, const DeviceAssignment& device_assignment) { @@ -466,32 +484,31 @@ PYBIND11_MODULE(xla_extension, m) { return local_device->client()->TransferToInfeedLocal( literal, local_device->device_ordinal()); }) - .def( - "transfer_from_outfeed", - [](const PjRtDevice& device, - const Shape& shape) -> StatusOr { - GlobalPyRefManager()->CollectGarbage(); - std::shared_ptr literal_shared; - { - py::gil_scoped_release gil_release; - TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, - device.GetLocalDeviceState()); - Shape shape_with_layout = shape; - ShapeUtil::ForEachMutableSubshape( - &shape_with_layout, [](Shape* subshape, const ShapeIndex&) { - if (!subshape->has_layout()) { - LayoutUtil::SetToDefaultLayout(subshape); - } - }); - TF_ASSIGN_OR_RETURN( - Literal literal, - local_device->client()->TransferFromOutfeedLocal( - shape_with_layout, local_device->device_ordinal())); + .def("transfer_from_outfeed", + [](const PjRtDevice& device, + const Shape& shape) -> StatusOr { + GlobalPyRefManager()->CollectGarbage(); + std::shared_ptr literal_shared; + { + py::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, + device.GetLocalDeviceState()); + Shape shape_with_layout = shape; + ShapeUtil::ForEachMutableSubshape( + &shape_with_layout, [](Shape* subshape, const ShapeIndex&) { + if (!subshape->has_layout()) { + LayoutUtil::SetToDefaultLayout(subshape); + } + }); + TF_ASSIGN_OR_RETURN( + Literal literal, + local_device->client()->TransferFromOutfeedLocal( + shape_with_layout, local_device->device_ordinal())); - literal_shared = std::make_shared(std::move(literal)); - } - return LiteralToPython(std::move(literal_shared)); - }); + literal_shared = std::make_shared(std::move(literal)); + } + return LiteralToPython(std::move(literal_shared)); + }); py::class_>(m, "CpuDevice") .def("__repr__", [](const CpuDevice& device) { @@ -553,13 +570,13 @@ PYBIND11_MODULE(xla_extension, m) { m.def( "get_cpu_client", [](bool asynchronous) -> StatusOr> { - TF_ASSIGN_OR_RETURN(std::shared_ptr client, + TF_ASSIGN_OR_RETURN(std::unique_ptr client, GetCpuClient(asynchronous)); return std::make_shared(std::move(client)); }, py::arg("asynchronous") = true); m.def("get_interpreter_client", []() -> StatusOr> { - TF_ASSIGN_OR_RETURN(std::shared_ptr client, + TF_ASSIGN_OR_RETURN(std::unique_ptr client, GetInterpreterClient()); return std::make_shared(std::move(client)); }); @@ -569,7 +586,7 @@ PYBIND11_MODULE(xla_extension, m) { std::shared_ptr distributed_client, int node_id) -> StatusOr> { TF_ASSIGN_OR_RETURN( - std::shared_ptr client, + std::unique_ptr client, GetNvidiaGpuClient(asynchronous, allocator_config, std::move(distributed_client), node_id)); return std::make_shared(std::move(client)); @@ -577,6 +594,14 @@ PYBIND11_MODULE(xla_extension, m) { py::arg("asynchronous") = true, py::arg("allocator_config") = GpuAllocatorConfig(), py::arg("distributed_client") = nullptr, py::arg("node_id") = 0); + m.def( + "get_tpu_client", + [](bool asynchronous) -> StatusOr> { + TF_ASSIGN_OR_RETURN(std::shared_ptr client, + GetTpuClient(asynchronous)); + return std::make_shared(std::move(client)); + }, + py::arg("asynchronous") = true); py::class_(m, "Frame") .def_readonly("file_name", &Traceback::Frame::file_name) @@ -820,6 +845,14 @@ PYBIND11_MODULE(xla_extension, m) { hlo_module.config().debug_options(), RenderedGraphFormat::kDot); }); + m.def( + "hlo_module_cost_analysis", + 
[](PyClient* client, + const HloModule& module) -> StatusOr> { + auto analysis = client->pjrt_client()->GetHloCostAnalysis(); + TF_RETURN_IF_ERROR(module.entry_computation()->Accept(analysis.get())); + return analysis->properties(); + }); py::class_ xla_op_class(m, "XlaOp"); diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 38c55c6fe5d..3de0ffcc2f8 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -90,11 +90,16 @@ def _gpu_backend_factory(distributed_client=None, node_id=0): node_id=node_id) +def _tpu_backend_factory(): + return _xla.get_tpu_client(asynchronous=True) + + # Backend factories, keyed by user-visible name, in increasing priority order. _local_backend_factories = collections.OrderedDict([ ('interpreter', _interpreter_backend_factory), ('cpu', _cpu_backend_factory), ('gpu', _gpu_backend_factory), + ('tpu', _tpu_backend_factory), ]) @@ -113,16 +118,17 @@ def _get_local_backends(): _local_backends = collections.OrderedDict() for name, factory in _local_backend_factories.items(): - logging.vlog(2, "Initializing backend '%s'" % name) + logging.vlog(1, "Initializing backend '%s'" % name) try: backend = factory() - except RuntimeError: + except RuntimeError as err: if name == 'cpu': # We always expect CPU to initialize successfully. raise else: # If the backend isn't built into the binary, or if it has no devices, # we expect a RuntimeError. + logging.vlog(1, "Error initializing backend '%s': %s" % (name, err)) continue _local_backends[name] = backend return _local_backends @@ -144,7 +150,8 @@ def get_local_backend(name=None): try: return backends[name] except KeyError: - raise RuntimeError('Unknown backend {}'.format(name)) + raise RuntimeError( + 'Unknown backend %s. Available: %s' % (name, list(backends.keys()))) return list(backends.values())[-1] @@ -191,8 +198,8 @@ XLA_ELEMENT_TYPE_TO_DTYPE = { PrimitiveType.F64: np.dtype('float64'), PrimitiveType.C64: np.dtype('complex64'), PrimitiveType.C128: np.dtype('complex128'), - PrimitiveType.TUPLE: np.dtype(np.object), - PrimitiveType.TOKEN: np.dtype(np.object), + PrimitiveType.TUPLE: np.dtype(np.object_), + PrimitiveType.TOKEN: np.dtype(np.object_), } # Note the conversion on the key. Numpy has a known issue wherein dtype hashing diff --git a/tensorflow/compiler/xla/python/xla_client_backend_independent_test.py b/tensorflow/compiler/xla/python/xla_client_backend_independent_test.py new file mode 100644 index 00000000000..180bb040cc4 --- /dev/null +++ b/tensorflow/compiler/xla/python/xla_client_backend_independent_test.py @@ -0,0 +1,147 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Backend-independent tests for the Python XLA client.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest + +from absl.testing import absltest +import numpy as np + +from tensorflow.compiler.xla.python import xla_client + +# pylint: disable=g-import-not-at-top +try: + import portpicker +except ImportError: + portpicker = None +# pylint: enable=g-import-not-at-top + +ops = xla_client.ops + + +class ShapeTest(absltest.TestCase): + + def testInvalidShapes(self): + with self.assertRaisesRegex(RuntimeError, + "shape's dimensions must not be < 0.*"): + xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [-2, 4]) + + with self.assertRaisesRegex( + RuntimeError, "layout minor_to_major field contains 1 element.*"): + xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [2, 4], [3]) + + with self.assertRaisesRegex( + RuntimeError, "layout minor_to_major field has out-of-bounds value.*"): + xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [2, 4], + [1, -1]) + + +class ComputationPrinting(absltest.TestCase): + + def ExampleComputation(self): + builder = xla_client.XlaBuilder("acomputation") + p0 = ops.Parameter(builder, 0, xla_client.shape_from_pyval(np.float32(0))) + p1 = ops.Parameter(builder, 1, + xla_client.shape_from_pyval(np.zeros((4,), np.float32))) + x = ops.Mul(p0, p1) + ops.Add(x, x) + return builder.build() + + def testComputationToHloText(self): + computation = self.ExampleComputation() + hlo_text = computation.as_hlo_text() + self.assertTrue(hlo_text.startswith("HloModule acomputation")) + + def testComputationToHloGraph(self): + computation = self.ExampleComputation() + hlo_dot_graph = computation.as_hlo_dot_graph() + self.assertTrue(hlo_dot_graph.startswith("digraph ")) + + def testHloModuleToHloText(self): + computation = self.ExampleComputation() + hlo_text = computation.as_hlo_module().to_string() + self.assertTrue(hlo_text.startswith("HloModule acomputation")) + + def testHloModuleToHloGraph(self): + computation = self.ExampleComputation() + hlo_dot_graph = xla_client._xla.hlo_module_to_dot_graph( + computation.as_hlo_module()) + self.assertTrue(hlo_dot_graph.startswith("digraph ")) + + +class ComputationHashTest(absltest.TestCase): + + def testHash(self): + builder0 = xla_client.XlaBuilder("computation0") + p0 = ops.Parameter(builder0, 0, xla_client.shape_from_pyval(np.float32(0))) + p1 = ops.Parameter(builder0, 1, + xla_client.shape_from_pyval(np.zeros((4,), np.float32))) + ops.Mul(p0, p1) + computation0 = builder0.build() + + builder1 = xla_client.XlaBuilder("computation1") + p0 = ops.Parameter(builder1, 0, xla_client.shape_from_pyval(np.float32(0))) + p1 = ops.Parameter(builder1, 1, + xla_client.shape_from_pyval(np.zeros((4,), np.float32))) + ops.Mul(p0, p1) + computation1 = builder1.build() + + self.assertEqual(computation0.hash(), computation1.hash()) + + +class AliasTest(absltest.TestCase): + + def testSetUpAlias(self): + c = xla_client.XlaBuilder(self.id()) + p1 = ops.Parameter( + c, 0, + xla_client.shape_from_pyval(np.array( + 1.0, np.float32)).with_major_to_minor_layout_if_absent()) + p2 = ops.Parameter( + c, 1, + xla_client.shape_from_pyval(np.array( + 1.0, np.float32)).with_major_to_minor_layout_if_absent()) + out = ops.Add(p1, p2) + c.setup_alias([], 0, []) + c.build(out) + + +class ProfilerTest(absltest.TestCase): + + def testTraceMe(self): + # TODO(phawkins): These tests just 
check that the TraceMe context manager + # acts like a context manager and doesn't explode. Ideally we'd check that + # the profiler saw the traceme too. + with xla_client.profiler.TraceMe("test1"): + pass + with xla_client.profiler.TraceMe("test2", foo=123): + pass + with self.assertRaises(ValueError): + with xla_client.profiler.TraceMe("test3"): + raise ValueError("test") + + @unittest.skipIf(portpicker is None, "Test requires portpicker") + def testStartServer(self): + port = portpicker.pick_unused_port() + server = xla_client.profiler.start_server(port) + del server + + +if __name__ == "__main__": + absltest.main() diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 49c57a27ac0..1f8befd79d3 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -1,4 +1,3 @@ -# Lint as: python3 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the Python extension-based XLA client.""" +"""Backend-dependent tests for the Python XLA client.""" from __future__ import absolute_import from __future__ import division @@ -37,12 +36,6 @@ try: except ImportError: custom_call_for_test = None -try: - import portpicker -except ImportError: - portpicker = None -# pylint: enable=g-import-not-at-top - bfloat16 = xla_client.bfloat16 ops = xla_client.ops @@ -105,7 +98,7 @@ def TestFactory(xla_backend, cloud_tpu=False): c, arguments=(), expected=None, - rtol=1e-7, + rtol=1e-4, atol=0): self._ExecuteAndAssertWith( functools.partial(np.testing.assert_allclose, rtol=rtol, atol=atol), @@ -142,27 +135,6 @@ def TestFactory(xla_backend, cloud_tpu=False): ops.Add(x, x) return builder.build() - def testComputationToHloText(self): - computation = self.ExampleComputation() - hlo_text = computation.as_hlo_text() - self.assertTrue(hlo_text.startswith("HloModule acomputation")) - - def testComputationToHloGraph(self): - computation = self.ExampleComputation() - hlo_dot_graph = computation.as_hlo_dot_graph() - self.assertTrue(hlo_dot_graph.startswith("digraph ")) - - def testHloModuleToHloText(self): - computation = self.ExampleComputation() - hlo_text = computation.as_hlo_module().to_string() - self.assertTrue(hlo_text.startswith("HloModule acomputation")) - - def testHloModuleToHloGraph(self): - computation = self.ExampleComputation() - hlo_dot_graph = xla_client._xla.hlo_module_to_dot_graph( - computation.as_hlo_module()) - self.assertTrue(hlo_dot_graph.startswith("digraph ")) - @unittest.skipIf(cloud_tpu, "not implemented") def testCompiledHloModuleToHloText(self): computation = self.ExampleComputation() @@ -173,31 +145,15 @@ def TestFactory(xla_backend, cloud_tpu=False): self.assertTrue(hlo_text.startswith("HloModule acomputation")) self.assertIn("fusion", hlo_text) + @unittest.skipIf(cloud_tpu, "not implemented") + def testFlopEstimate(self): + computation = self.ExampleComputation() + properties = xla_client._xla.hlo_module_cost_analysis( + self.backend, computation.as_hlo_module()) + self.assertEqual(properties["flops"], 8.0) + tests.append(ComputationPrinting) - class ComputationHashTest(absltest.TestCase): - - def testHash(self): - builder0 = xla_client.XlaBuilder("computation0") - p0 = 
ops.Parameter(builder0, 0, - xla_client.shape_from_pyval(np.float32(0))) - p1 = ops.Parameter( - builder0, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32))) - ops.Mul(p0, p1) - computation0 = builder0.build() - - builder1 = xla_client.XlaBuilder("computation1") - p0 = ops.Parameter(builder1, 0, - xla_client.shape_from_pyval(np.float32(0))) - p1 = ops.Parameter( - builder1, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32))) - ops.Mul(p0, p1) - computation1 = builder1.build() - - self.assertEqual(computation0.hash(), computation1.hash()) - - tests.append(ComputationHashTest) - class ComputationsWithConstantsTest(ComputationTest): """Tests focusing on Constant ops.""" @@ -1894,24 +1850,6 @@ def TestFactory(xla_backend, cloud_tpu=False): tests.append(SetShardingTest) - class AliasTest(ComputationTest): - - def testSetUpAlias(self): - c = self._NewComputation() - p1 = ops.Parameter( - c, 0, - xla_client.shape_from_pyval( - NumpyArrayF32(1.0)).with_major_to_minor_layout_if_absent()) - p2 = ops.Parameter( - c, 1, - xla_client.shape_from_pyval( - NumpyArrayF32(1.0)).with_major_to_minor_layout_if_absent()) - out = ops.Add(p1, p2) - c.setup_alias([], 0, []) - c = c.build(out) - - tests.append(AliasTest) - testcase_shapes = [ (), (1,), @@ -2015,28 +1953,6 @@ def TestFactory(xla_backend, cloud_tpu=False): tests.append(BufferProtocolTest) - class ProfilerTest(absltest.TestCase): - - def testTraceMe(self): - # TODO(phawkins): These tests just check that the TraceMe context manager - # acts like a context manager and doesn't explode. Ideally we'd check that - # the profiler saw the traceme too. - with xla_client.profiler.TraceMe("test1"): - pass - with xla_client.profiler.TraceMe("test2", foo=123): - pass - with self.assertRaises(ValueError): - with xla_client.profiler.TraceMe("test3"): - raise ValueError("test") - - @unittest.skipIf(portpicker is None, "Test requires portpicker") - def testStartServer(self): - port = portpicker.pick_unused_port() - server = xla_client.profiler.start_server(port) - del server - - tests.append(ProfilerTest) - class TracebackTest(absltest.TestCase): def setUp(self): diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD index 6e345b06e43..15022d1a879 100644 --- a/tensorflow/compiler/xla/rpc/BUILD +++ b/tensorflow/compiler/xla/rpc/BUILD @@ -1,12 +1,14 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency") load( "//tensorflow:tensorflow.bzl", + "if_libtpu", "tf_cc_binary", "tf_cc_test", ) load( "//tensorflow/core/platform:build_config.bzl", - "tf_proto_library_cc", + "tf_proto_library", ) load( "//tensorflow/compiler/xla:xla.bzl", @@ -18,7 +20,7 @@ package( licenses = ["notice"], # Apache 2.0 ) -tf_proto_library_cc( +tf_proto_library( name = "xla_service_proto", srcs = ["xla_service.proto"], has_services = 1, @@ -50,13 +52,15 @@ cc_library( srcs = ["grpc_service_main.cc"], deps = [ ":grpc_service", - "//tensorflow/compiler/xla/service:cpu_plugin", + "@com_google_absl//absl/strings:str_format", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", - "@com_google_absl//absl/strings:str_format", tf_grpc_cc_dependency(), - ], + ] + if_libtpu( + if_false = ["//tensorflow/compiler/xla/service:cpu_plugin"], + if_true = [], + ), ) tf_cc_binary( diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index dd16bd32dd1..491d1d67877 100644 --- 
a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1,13 +1,22 @@ # Description: # XLA service implementation. +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") load( "//tensorflow/core/platform:build_config.bzl", - "tf_proto_library_cc", - "tf_proto_library_py", + "tf_proto_library", ) load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "filegroup") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "internal_hlo_deps") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "internal_cuda_deps") load( "//tensorflow/core/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", @@ -30,7 +39,7 @@ package_group( packages = ["//learning/brain/experimental/tf_runtime/..."], ) -tf_proto_library_cc( +tf_proto_library( name = "hlo_proto", srcs = ["hlo.proto"], cc_api_version = 2, @@ -38,20 +47,13 @@ tf_proto_library_cc( visibility = ["//visibility:public"], ) -tf_proto_library_py( - name = "hlo_proto", # bzl adds a _py suffix only to the OSS target. - srcs = ["hlo.proto"], - visibility = ["//visibility:public"], - deps = ["//tensorflow/compiler/xla:xla_data_proto_py"], -) - -tf_proto_library_cc( +tf_proto_library( name = "hlo_profile_printer_data", srcs = ["hlo_profile_printer_data.proto"], cc_api_version = 2, ) -tf_proto_library_cc( +tf_proto_library( name = "hlo_execution_profile_data", srcs = ["hlo_execution_profile_data.proto"], cc_api_version = 2, @@ -83,6 +85,7 @@ cc_library( deps = [ ":bfloat16_support", ":hlo", + ":hlo_dataflow_analysis", ":hlo_pass", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:xla_data_proto_cc", @@ -201,7 +204,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_proto_cc", "//tensorflow/core:lib", - "//tensorflow/core:regexp_internal", + "//tensorflow/core/platform:regexp", "@com_google_absl//absl/strings", ], ) @@ -446,9 +449,9 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto_cc", - "//tensorflow/core:human_readable_json", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/platform:human_readable_json", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -477,6 +480,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/types:optional", ], ) @@ -876,7 +880,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "@com_google_absl//absl/strings", ], ) @@ -896,7 +900,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor:device_memory_allocator", "//third_party/eigen3", "@com_google_absl//absl/container:flat_hash_map", @@ -946,7 +950,7 @@ cc_library( "//tensorflow/compiler/xla:xla_proto_cc", "//tensorflow/core:lib", "//tensorflow/core:ptr_util", - 
"//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor:device_memory_allocator", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -982,7 +986,7 @@ cc_library( "//tensorflow/compiler/xla/client:executable_build_options", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor:device_memory_allocator", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -1010,7 +1014,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "@com_google_absl//absl/strings", ], ) @@ -1021,7 +1025,7 @@ cc_library( ":service", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/compiler/xla/service/cpu:cpu_transfer_manager", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", ], ) @@ -1054,14 +1058,14 @@ cc_library( ":service", "//tensorflow/compiler/xla/service/gpu:gpu_compiler", "//tensorflow/compiler/xla/service/gpu:gpu_transfer_manager", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", ] + if_cuda_is_configured([ "//tensorflow/compiler/xla/service/gpu:nvptx_compiler", "//tensorflow/core/platform/default/build_config:stream_executor_cuda", ]) + if_rocm_is_configured([ "//tensorflow/compiler/xla/service/gpu:amdgpu_compiler", "//tensorflow/core/platform/default/build_config:stream_executor_rocm", - ]), + ]) + internal_cuda_deps(), ) cc_library( @@ -1069,10 +1073,10 @@ cc_library( deps = [ ":service", "//tensorflow/compiler/xla/service/gpu:gpu_transfer_manager", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", ] + if_cuda_is_configured([ "//tensorflow/compiler/xla/service/mlir_gpu:mlir_compiler_impl", - ]), + ]) + internal_cuda_deps(), ) cc_library( @@ -1082,7 +1086,7 @@ cc_library( "//tensorflow/compiler/xla/service/interpreter:compiler", "//tensorflow/compiler/xla/service/interpreter:interpreter_transfer_manager", "//tensorflow/compiler/xla/service/interpreter:platform", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", ], ) @@ -1100,7 +1104,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor:device_memory_allocator", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", @@ -1122,8 +1126,8 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:ptr_util", - "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor:device_memory_allocator", "@com_google_absl//absl/memory", ], @@ -1146,6 +1150,10 @@ cc_library( ":maybe_owning_device_memory", ":shaped_buffer", ":stream_pool", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", 
"//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:shape_tree", @@ -1156,15 +1164,11 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor", "//tensorflow/stream_executor:device_description", "//tensorflow/stream_executor:device_memory_allocator", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - "@com_google_absl//absl/types:variant", - ], + ] + internal_hlo_deps(), ) cc_library( @@ -1184,7 +1188,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "@com_google_absl//absl/types:span", ], ) @@ -1217,7 +1221,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor:device_memory", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -1258,7 +1262,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "@com_google_absl//absl/memory", ], ) @@ -1684,6 +1688,7 @@ cc_library( hdrs = ["multi_output_fusion.h"], deps = [ ":hlo", + ":hlo_dataflow_analysis", ":hlo_dce", ":hlo_pass", ":hlo_reachability", @@ -1703,7 +1708,6 @@ cc_library( srcs = ["hlo_creation_utils.cc"], hdrs = [ "hlo_creation_utils.h", - "//tensorflow/compiler/xla:literal_util", ], deps = [ ":hlo", @@ -1938,6 +1942,29 @@ cc_library( ], ) +cc_library( + name = "qr_expander", + srcs = ["qr_expander.cc"], + hdrs = ["qr_expander.h"], + deps = [ + ":op_expander_pass", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:loops", + "//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/compiler/xla/client/lib:matrix", + "//tensorflow/compiler/xla/client/lib:slicing", + "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + ], +) + cc_library( name = "convolution_4d_expander", srcs = ["convolution_4d_expander.cc"], @@ -2340,10 +2367,13 @@ cc_library( ":call_inliner", ":hlo", ":hlo_casting_utils", + ":hlo_cse", ":hlo_dce", ":hlo_pass", ":hlo_pass_pipeline", + ":hlo_verifier", ":tuple_simplifier", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -2412,6 +2442,42 @@ tf_cc_test( ], ) +cc_library( + name = "space_to_batch_converter", + srcs = ["space_to_batch_converter.cc"], + hdrs = ["space_to_batch_converter.h"], + deps = [ + ":hlo", + ":hlo_creation_utils", + ":hlo_pass", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + 
"//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "space_to_batch_converter_test", + size = "small", + srcs = ["space_to_batch_converter_test.cc"], + deps = [ + ":hlo", + ":hlo_matchers", + ":space_to_batch_converter", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + ], +) + cc_library( name = "while_loop_analysis", srcs = ["while_loop_analysis.cc"], @@ -2789,7 +2855,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -2824,7 +2890,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", ], alwayslink = True, # Contains per-platform transfer manager registration ) @@ -2887,7 +2953,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/memory", ], @@ -3425,6 +3491,8 @@ cc_library( hdrs = ["memory_space_assignment_utils.h"], deps = [ ":heap_simulator", + ":hlo", + ":hlo_casting_utils", ], ) @@ -4156,7 +4224,7 @@ cc_library( "//tensorflow/compiler/xla:xla_proto_cc", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:regexp_internal", + "//tensorflow/core/platform:regexp", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -4259,7 +4327,7 @@ cc_library( deps = [ "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "@com_google_absl//absl/memory", ], ) @@ -4271,7 +4339,7 @@ tf_cc_test( ":stream_pool", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", ], ) @@ -4328,7 +4396,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//third_party/eigen3", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", @@ -4677,7 +4745,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/core:lib", - "//tensorflow/core:regexp_internal", + "//tensorflow/core/platform:regexp", "@com_google_absl//absl/base", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 
214cbfa93a7..76b0236fcdd 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -913,7 +913,7 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { (Match(lhs, m::Multiply(m::Op(&c), m::Op(&a))) && Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b))))) && (ShapeUtil::ElementIsIntegral(add->shape()) || - IsAllFpConstantPowerOf2(c))) { + options_.enable_floats_are_real() || IsAllFpConstantPowerOf2(c))) { return ReplaceWithNewInstruction( add, HloInstruction::CreateBinary( add->shape(), HloOpcode::kMultiply, @@ -1300,7 +1300,15 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( auto replacement = computation_->AddInstruction(concatenate->CloneWithNewOperands( concatenate->shape(), new_operands)); - ReplaceInstructionIfSameShape(concatenate, replacement); + + // Recurse to handle multiple disjoint sequence of inputs. The + // logic above merge only 1 sequential series of + // inputs. Otherwise, it can lead to the FixPass optimization + // hitting its threshold. + if (ReplaceInstructionIfSameShape(concatenate, replacement)) { + return HandleConcatenate(replacement); + } + return Status::OK(); } } @@ -2702,6 +2710,17 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { return Status::OK(); } + { + HloInstruction* abs_operand; + if (lhs == rhs && Match(lhs, m::Abs(m::Op(&abs_operand))) && + !ShapeUtil::ElementIsComplex(abs_operand->shape())) { + TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(0, abs_operand)); + TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(1, abs_operand)); + changed_ = true; + return Status::OK(); + } + } + { HloInstruction *convert_operand, *operand; // Mul(Convert(Pred), operand) => select(pred, operand, 0) @@ -3303,6 +3322,9 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { // padding with a pad with non-negative padding followed by a slice. bool all_zero = true; bool has_negative = false; + // Used to possibly split off the unchanged padding dimensions. + std::vector padding_dimensions; + int64 dimension_index = 0; for (auto& padding_dimension : pad->padding_config().dimensions()) { if (padding_dimension.edge_padding_low() < 0 || padding_dimension.edge_padding_high() < 0) { @@ -3311,12 +3333,93 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { if (padding_dimension.edge_padding_low() != 0 || padding_dimension.edge_padding_high() != 0) { all_zero = false; + padding_dimensions.push_back(dimension_index); + } else if (padding_dimension.interior_padding()) { + padding_dimensions.push_back(dimension_index); } + dimension_index++; } if (all_zero) { - ReplaceInstructionIfSameShape(pad, pad->mutable_operand(0)); - return Status::OK(); + if (ReplaceInstructionIfSameShape(pad, pad->mutable_operand(0))) { + return Status::OK(); + } + } + + // The context of this optimization can be found at b/163617402 + // It tries to capture the case of pad(broadcast(x)), where + // x->shape().dimensions(), or broadcast(x)->dimensions(), is + // a subset of the padded dimensions in pad->config(), + // and the padded dimensions in pad->config() is in turn a strict + // subset of broadcast->shape().dimensions(). The combined op can be + // rewritten to broadcast2(pad(broadcast1(x))), where broadcast1 extends + // x with dimensions that need to be padded, and broadcast2 extends + // the result of padding to full dimensions. 
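
Concretely, the rewrite described in the comment above matches the BroadcastAndPadReorder test added later in this change: when the pad only touches dimensions that the broadcast itself created, the pad can be pushed onto the small pre-broadcast shape. Illustrative before/after HLO (shapes taken from that test; the intermediate names are made up):

    // Before:
    b2 = pred[32,1,768]   broadcast(pred[] c1), dimensions={}
    p4 = pred[4096,1,768] pad(b2, pred[] c3), padding=0_4064x0_0x0_0

    // After (the pad now runs on 32 elements instead of 32*768):
    b1 = pred[32]         broadcast(pred[] c1), dimensions={}
    p  = pred[4096]       pad(b1, pred[] c3), padding=0_4064
    b3 = pred[4096,1,768] broadcast(p), dimensions={0}
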
+ // TODO(qyi): for future extensions: The condition for broadcast(x) + // ->dimensions() to be a subset of padded dimensions in pad->config() + // does not have to be strictly required, but it makes the calculation + // for optimization easier, so it is required by the current implementation. + // Only the second condition between the padded dimensions and the + // dimensions of the final shape have to be enforced for the optimization + // to make sense. If needed to remove the first constraint, the shape + // calculations across the implementation need to be re-adjusted. + auto pad_dims = padding_dimensions.size(); + if (pad_dims < dimension_index && + pad->operand(0)->opcode() == HloOpcode::kBroadcast && + pad->operand(0)->user_count() == 1 && + pad->operand(0)->operand(0)->shape().rank() <= pad_dims) { + // Check broadcast operand dimensions is a subset of pading_dimensions. + // If not, skip the optimization. + bool opt_is_valid = true; + std::vector broadcast_dimensions; + HloBroadcastInstruction* broadcast = + static_cast(pad->mutable_operand(0)); + for (auto broadcast_index : broadcast->dimensions()) { + bool found = false; + for (int i = 0; i < pad_dims; ++i) { + if (broadcast_index == padding_dimensions[i]) { + broadcast_dimensions.push_back(i); + found = true; + break; + } + } + if (!found) { + opt_is_valid = false; + break; + } + } + if (opt_is_valid) { + auto pad_shape = pad->shape(); + auto broadcast_shape = broadcast->shape(); + auto pad_shape1 = pad_shape; + auto broadcast_shape1 = broadcast_shape; + PaddingConfig pad_config; + for (int i = padding_dimensions.size() - 1; i >= 0; --i) { + int64 j = padding_dimensions[i]; + while (--dimension_index > j) { + broadcast_shape1.DeleteDimension(dimension_index); + pad_shape1.DeleteDimension(dimension_index); + } + } + while (--dimension_index >= 0) { + broadcast_shape1.DeleteDimension(dimension_index); + pad_shape1.DeleteDimension(dimension_index); + } + for (auto dimension_to_pad : padding_dimensions) { + auto dimension = pad_config.add_dimensions(); + *dimension = pad->padding_config().dimensions(dimension_to_pad); + } + *broadcast->mutable_shape() = broadcast_shape1; + *broadcast->mutable_dimensions() = broadcast_dimensions; + simplifier_->UpdateLayout(broadcast->mutable_shape()); + auto pad2 = + computation_->AddInstruction(pad->CloneWithNewShape(pad_shape1)); + *pad2->mutable_padding_config() = pad_config; + simplifier_->UpdateLayout(pad2->mutable_shape()); + auto broadcast2 = computation_->AddInstruction( + HloInstruction::CreateBroadcast(pad_shape, pad2, padding_dimensions)); + return ReplaceInstruction(pad, broadcast2); + } } if (has_negative) { @@ -3351,7 +3454,8 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { pad->shape(), nonzero_pad->mutable_shape())); simplifier_->UpdateLayout(nonzero_pad->mutable_shape()); - // Second, construct the slice instruction to perform the negative padding. + // Second, construct the slice instruction to perform the negative + // padding. 
std::vector start_indices; std::vector end_indices; std::vector strides; @@ -4012,8 +4116,10 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { new_limits[i] -= low; } if (slice_in_padding) { - return ReplaceInstruction( - slice, MakeBroadcastHlo(pad->mutable_operand(1), {}, slice->shape())); + HloInstruction* broadcast = + MakeBroadcastHlo(pad->mutable_operand(1), {}, slice->shape()); + *(broadcast->mutable_shape()) = slice->shape(); + return ReplaceInstruction(slice, broadcast); } if (slice_undoes_pad && ReplaceInstructionIfSameShape(slice, pad_operand)) { return Status::OK(); @@ -4022,6 +4128,7 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { TF_ASSIGN_OR_RETURN(HloInstruction * new_slice, MakeSliceHlo(pad_operand, new_starts, new_limits, slice->slice_strides())); + *(new_slice->mutable_shape()) = slice->shape(); return ReplaceInstruction(slice, new_slice); } } @@ -4085,9 +4192,18 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { VLOG(3) << "Sink broadcast through slice"; VLOG(3) << "Original slice: " << slice->ToString(); VLOG(3) << "Original broadcast: " << broadcast->ToString(); - TF_ASSIGN_OR_RETURN(auto new_slice, - MakeSliceHlo(broadcast_operand, new_slice_starts, - new_slice_limits, new_slice_strides)); + auto new_slice_shape = broadcast_operand->shape(); + for (int64 i = 0; i < broadcast_operand->shape().rank(); ++i) { + int64 size_i = (new_slice_limits[i] - new_slice_starts[i] + + new_slice_strides[i] - 1) / + new_slice_strides[i]; + new_slice_shape.set_dimensions(i, size_i); + } + simplifier_->UpdateLayout(&new_slice_shape); + HloComputation* computation = broadcast_operand->parent(); + auto new_slice = computation->AddInstruction(HloInstruction::CreateSlice( + new_slice_shape, broadcast_operand, new_slice_starts, new_slice_limits, + new_slice_strides)); auto new_broadcast = HloInstruction::CreateBroadcast( slice->shape(), new_slice, broadcast->dimensions()); VLOG(3) << "New slice: " << slice->ToString(); @@ -4187,9 +4303,15 @@ Status AlgebraicSimplifierVisitor::HandleDynamicSlice( VLOG(3) << "Original broadcast: " << operand->ToString(); HloInstruction* new_dynamic_slice = broadcast_operand; if (!new_slice_sizes.empty()) { - TF_ASSIGN_OR_RETURN( - new_dynamic_slice, - MakeDynamicSliceHlo(broadcast_operand, new_indices, new_slice_sizes)); + auto new_ds_shape = broadcast_operand->shape(); + for (int64 i = 0; i < broadcast_operand->shape().rank(); ++i) { + new_ds_shape.set_dimensions(i, new_slice_sizes[i]); + } + simplifier_->UpdateLayout(&new_ds_shape); + HloComputation* computation = broadcast_operand->parent(); + new_dynamic_slice = + computation->AddInstruction(HloInstruction::CreateDynamicSlice( + new_ds_shape, broadcast_operand, new_indices, new_slice_sizes)); } auto new_broadcast = HloInstruction::CreateBroadcast( dynamic_slice->shape(), new_dynamic_slice, operand->dimensions()); @@ -5167,10 +5289,10 @@ StatusOr AlgebraicSimplifierVisitor::SwapConvOperands( if (!reverse_dimensions.empty()) { TF_ASSIGN_OR_RETURN(kernel, MakeReverseHlo(kernel, reverse_dimensions)); } - TF_ASSIGN_OR_RETURN( - HloInstruction * new_convolution, - MakeConvolveHlo(kernel, input, /*feature_group_count=*/1, swapped_window, - swapped_dnums, precision_config)); + TF_ASSIGN_OR_RETURN(HloInstruction * new_convolution, + MakeConvolveHlo(kernel, input, /*feature_group_count=*/1, + /*batch_group_count=*/1, swapped_window, + swapped_dnums, precision_config)); convolution->SetupDerivedInstruction(new_convolution); 
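
The new slice shape computed above uses the standard ceiling-division rule for strided slices, size_i = (limit - start + stride - 1) / stride. A tiny self-contained check of that formula (SliceOutputDim is an invented helper, not an XLA API):

    #include <cassert>
    #include <cstdint>

    // Number of elements produced by a strided slice over [start, limit) with
    // the given stride: ceil((limit - start) / stride).
    int64_t SliceOutputDim(int64_t start, int64_t limit, int64_t stride) {
      return (limit - start + stride - 1) / stride;
    }

    int main() {
      assert(SliceOutputDim(0, 10, 1) == 10);    // Keeps every element.
      assert(SliceOutputDim(2, 10, 3) == 3);     // Elements 2, 5, 8.
      assert(SliceOutputDim(50, 100, 1) == 50);  // Matches the {50, ...} shapes
                                                 // used in the tests below.
      return 0;
    }
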
TF_RETURN_IF_ERROR(ReplaceInstruction(convolution, new_convolution)); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index 9f2a3404116..cabecec4eb8 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -97,6 +97,14 @@ class AlgebraicSimplifierOptions { return enable_scalar_multiply_reduction_; } + // Also the algebraic simplifer to treat floating point values like real + // numbers. + void set_enable_floats_are_real(bool enable_floats_are_real) { + enable_floats_are_real_ = enable_floats_are_real; + } + + bool enable_floats_are_real() const { return enable_floats_are_real_; } + // If enable_window_reduce_replacement is true, the kReduceWindow instruction // can be optimized by replacement with simpler operations. void set_enable_window_reduce_to_reduce_replacement( @@ -158,6 +166,7 @@ class AlgebraicSimplifierOptions { bool enable_conv_simplification_{true}; bool enable_conv_operand_swap_{true}; bool enable_scalar_multiply_reduction_{false}; + bool enable_floats_are_real_{false}; bool enable_window_reduce_to_reduce_replacement_{true}; bool enable_reduce_of_reshape_{true}; bool replace_transpose_with_bitcast_{true}; diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 70147f6ecad..c4f3ea4087b 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -117,6 +117,22 @@ TEST_F(AlgebraicSimplifierTest, FactorFpAddition) { m::ConstantScalar(0.125)))); } +// (Abs(A)) * (Abs(A)) => (A*A) +TEST_F(AlgebraicSimplifierTest, SquareOfAbs) { + const char* kModuleStr = R"( + HloModule m + test { + p = f32[] parameter(0) + a = f32[] abs(p) + ROOT z = f32[] multiply(a, a) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0)))); +} + // (A*C1) * (B*C2) => (A*B)*(C1*C2) TEST_F(AlgebraicSimplifierTest, MultiplyChain) { const char* kModuleStr = R"( @@ -2319,7 +2335,7 @@ TEST_F(AlgebraicSimplifierTest, ConcatenateOfBroadcastBecomesPad) { TEST_F(AlgebraicSimplifierTest, SimplifyConcatenateOfSlices) { auto m = CreateNewVerifiedModule(); Shape r2f32 = ShapeUtil::MakeShape(F32, {100, 99}); - Shape concat_shape = ShapeUtil::MakeShape(F32, {50, 80}); + Shape concat_shape = ShapeUtil::MakeShape(F32, {50, 90}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r2f32, "param0")); @@ -2366,10 +2382,15 @@ TEST_F(AlgebraicSimplifierTest, SimplifyConcatenateOfSlices) { HloInstruction* slice7 = builder.AddInstruction(HloInstruction::CreateSlice( ShapeUtil::MakeShape(F32, {50, 10}), param1, /*start_indices=*/{50, 79}, /*limit_indices=*/{100, 89}, /*strides=*/{1, 1})); + // Can merge 'slice7' and 'slice8'. 
+ HloInstruction* slice8 = builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {50, 10}), param1, /*start_indices=*/{50, 89}, + /*limit_indices=*/{100, 99}, /*strides=*/{1, 1})); builder.AddInstruction(HloInstruction::CreateConcatenate( concat_shape, - {slice0, slice1, slice2, slice3, slice4, slice5, slice6, slice7}, 1)); + {slice0, slice1, slice2, slice3, slice4, slice5, slice6, slice7, slice8}, + 1)); auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(default_options_); @@ -2384,6 +2405,12 @@ TEST_F(AlgebraicSimplifierTest, SimplifyConcatenateOfSlices) { ShapeUtil::Equal(computation->root_instruction()->operand(3)->shape(), ShapeUtil::MakeShape(F32, {50, 30}))); EXPECT_EQ(computation->root_instruction()->operand(3)->slice_starts(1), 40); + + // The operand 6 should be merge of 'slice7' and 'slice8', so its + // shape should have dimensions {50, 20} + EXPECT_TRUE( + ShapeUtil::Equal(computation->root_instruction()->operand(5)->shape(), + ShapeUtil::MakeShape(F32, {50, 20}))); } // Test that a simplification which changes layouts is not performed if layout @@ -6955,5 +6982,57 @@ TEST_F(AlgebraicSimplifierTest, UnaryVariadicReduce) { GmockMatch(m::Add(m::Parameter(0), m::Parameter(1)))); } +TEST_F(AlgebraicSimplifierTest, BroadcastAndPadReorder) { + const char* kModuleStr = R"( + HloModule m + test { + c1 = pred[] constant(true) + b2 = pred[32,1,768]{2,1,0} broadcast(pred[] c1), dimensions={} + c3 = pred[] constant(false) + ROOT p4 = pred[4096,1,768]{2,1,0} pad(pred[32,1,768]{2,1,0} b2, pred[] c3), padding=0_4064x0_0x0_0 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Broadcast( + m::Pad(m::Broadcast(m::Constant()), m::Constant())))); +} + +TEST_F(AlgebraicSimplifierTest, BroadcastAndPadReorderWithUse) { + const char* kModuleStr = R"( + HloModule m + test { + c1 = pred[] constant(true) + b2 = pred[1,768,32]{2,1,0} broadcast(pred[] c1), dimensions={} + c3 = pred[] constant(false) + p4 = pred[1,768,4096]{2,1,0} pad(pred[1,768,32]{2,1,0} b2, pred[] c3), padding=0_0x0_0x0_4064 + ROOT p5 = (pred[1,768,4096]{2,1,0}) tuple(pred[1,768,4096]{2,1,0} p4) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Tuple(m::Broadcast( + m::Pad(m::Broadcast(m::Constant()), m::Constant()))))); +} + +TEST_F(AlgebraicSimplifierTest, BroadcastAndPadReorderWithNonScalar) { + const char* kModuleStr = R"( + HloModule m + test { + c1 = pred[32] parameter(0) + b2 = pred[1,768,32]{2,1,0} broadcast(pred[32] c1), dimensions={2} + c3 = pred[] constant(false) + p4 = pred[1,768,4096]{2,1,0} pad(pred[1,768,32]{2,1,0} b2, pred[] c3), padding=0_0x0_0x0_4064 + ROOT p5 = (pred[1,768,4096]{2,1,0}) tuple(pred[1,768,4096]{2,1,0} p4) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Tuple(m::Broadcast( + m::Pad(m::Broadcast(m::Parameter()), m::Constant()))))); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/all_reduce_simplifier.cc 
b/tensorflow/compiler/xla/service/all_reduce_simplifier.cc index 541006f04d5..18a0fdc1a70 100644 --- a/tensorflow/compiler/xla/service/all_reduce_simplifier.cc +++ b/tensorflow/compiler/xla/service/all_reduce_simplifier.cc @@ -31,27 +31,7 @@ StatusOr AllReduceSimplifier::Run(HloModule* module) { TF_ASSIGN_OR_RETURN( auto replication, HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/false)); - std::vector all_reduces_to_replace; - for (auto computation : module->computations()) { - for (HloInstruction* inst : computation->MakeInstructionPostOrder()) { - if (!inst->shape().IsArray()) { - // We currently do not change tuple-shaped all-reduce. - // Until XLA will support Token fed AllReduce(), the PyTorch client code - // uses a fake data token (constant) which relies on this pass to not - // optimize out (being fed within a tuple input). - continue; - } - if (inst->IsCrossReplicaAllReduce() && - replication->HloInstructionIsReplicatedAt(inst->operand(0), {})) { - all_reduces_to_replace.push_back(inst); - } - } - } - - bool changed = false; - if (all_reduces_to_replace.empty()) { - return changed; - } + std::vector> all_reduces_to_replace; // Returns the size of a replica group if all groups have the same size, or -1 // if they have different sizes. @@ -71,7 +51,40 @@ StatusOr AllReduceSimplifier::Run(HloModule* module) { return replica_group_size; }; - for (auto all_reduce : all_reduces_to_replace) { + for (auto computation : module->computations()) { + for (HloInstruction* inst : computation->MakeInstructionPostOrder()) { + if (!inst->shape().IsArray()) { + // We currently do not change tuple-shaped all-reduce. + // Until XLA will support Token fed AllReduce(), the PyTorch client code + // uses a fake data token (constant) which relies on this pass to not + // optimize out (being fed within a tuple input). 
+ continue; + } + if (!inst->IsCrossReplicaAllReduce()) { + continue; + } + int64 group_size = get_replica_group_size(inst); + if (group_size == -1) { + continue; + } + if (replication->HloInstructionIsReplicatedAt(inst->operand(0), {}) || + group_size == 1) { + all_reduces_to_replace.push_back({inst, group_size}); + } + } + } + + bool changed = false; + + for (auto all_reduce_and_group_size : all_reduces_to_replace) { + auto all_reduce = all_reduce_and_group_size.first; + const int64 replica_group_size = all_reduce_and_group_size.second; + if (replica_group_size == 1) { + TF_RETURN_IF_ERROR(all_reduce->parent()->ReplaceInstruction( + all_reduce, all_reduce->mutable_operand(0))); + changed = true; + continue; + } if (all_reduce->to_apply()->instruction_count() != 3 || all_reduce->to_apply()->num_parameters() != 2) { continue; @@ -79,10 +92,6 @@ StatusOr AllReduceSimplifier::Run(HloModule* module) { HloInstruction* replacement; switch (all_reduce->to_apply()->root_instruction()->opcode()) { case HloOpcode::kAdd: { - int64 replica_group_size = get_replica_group_size(all_reduce); - if (replica_group_size == -1) { - continue; - } // Create the multiplier: // broadcast(convert_to_matching_type(s32 group size)) auto multiplier = diff --git a/tensorflow/compiler/xla/service/all_reduce_simplifier_test.cc b/tensorflow/compiler/xla/service/all_reduce_simplifier_test.cc index 4914836b34a..1e938594cc3 100644 --- a/tensorflow/compiler/xla/service/all_reduce_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/all_reduce_simplifier_test.cc @@ -167,5 +167,30 @@ test { m::Parameter(0), m::AllReduce(m::Parameter(1))))); } +TEST_F(AllReduceSimplifierTest, TrivialSubgroupAllReduce) { + const char* kModuleStr = R"( +HloModule m + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) +} + + +test { + p0 = f32[8,16] parameter(0), parameter_replication={false} + ROOT all-reduce = f32[8,16] all-reduce(p0), + replica_groups={{0},{1},{2},{3},{4},{5},{6},{7}}, + to_apply=sum +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( + kModuleStr, /*replica_count=*/8)); + AllReduceSimplifier simplifier(/*replica_count=*/8); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Parameter(0))); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index cc501161ce9..19927ae1576 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -143,13 +143,10 @@ StatusOr> AllocationTracker::DeconstructTuple( // We only need to care about replica id 0 here, since the GlobalDataHandle is // the same for all buffers across replicas. const ShapedBuffer* shaped_buffer = replicated_buffers[0]; - if (!shaped_buffer->on_host_shape().IsTuple()) { + if (!shaped_buffer->on_device_shape().IsTuple()) { return InvalidArgument("global data handle %d is not a tuple", data.handle()); } - // If the on-host representation is a tuple, then the on-device one should be - // as well. 
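Returning to the all_reduce_simplifier change above: an all-reduce whose replica groups all have size one combines nothing and can simply be replaced by its operand, which is what the TrivialSubgroupAllReduce test checks. Below is a minimal sketch of the group-size check in plain C++ over a hypothetical vector-of-groups representation; CommonGroupSize is an illustrative name, not the pass's API.

#include <cassert>
#include <cstdint>
#include <vector>

// Returns the common size of all replica groups, or -1 if they differ.
// Mirrors the intent of get_replica_group_size in the pass above.
int64_t CommonGroupSize(const std::vector<std::vector<int64_t>>& groups,
                        int64_t replica_count) {
  if (groups.empty()) {
    return replica_count;  // no explicit groups: one group with every replica
  }
  int64_t size = groups[0].size();
  for (const auto& g : groups) {
    if (static_cast<int64_t>(g.size()) != size) return -1;
  }
  return size;
}

int main() {
  // {{0},{1},...,{7}} as in the TrivialSubgroupAllReduce test: every group has
  // size 1, so the all-reduce can be replaced by its operand.
  std::vector<std::vector<int64_t>> groups = {{0}, {1}, {2}, {3},
                                              {4}, {5}, {6}, {7}};
  assert(CommonGroupSize(groups, /*replica_count=*/8) == 1);
  return 0;
}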
- TF_RET_CHECK(shaped_buffer->on_device_shape().IsTuple()); if (ShapeUtil::IsNestedTuple(shaped_buffer->on_device_shape())) { return Unimplemented("Deconstructing nested tuples is not implemented."); @@ -160,7 +157,6 @@ StatusOr> AllocationTracker::DeconstructTuple( i < ShapeUtil::TupleElementCount(shaped_buffer->on_device_shape()); ++i) { auto element_buffer = ShapedBuffer( - ShapeUtil::GetTupleElementShape(shaped_buffer->on_host_shape(), i), ShapeUtil::GetTupleElementShape(shaped_buffer->on_device_shape(), i), shaped_buffer->platform(), shaped_buffer->device_ordinal()); element_buffer.set_buffer(shaped_buffer->buffer(/*index=*/{i}), diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc index 23d2a9225a8..73210e6b3dc 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -159,19 +160,20 @@ Status BFloat16ConversionFoldingVisitor::TryFoldBF16Conversions( Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) { // Do not fold BF16 conversions for instructions related to tuples, entry and - // exit of a computation, fusion, convert, side-effecting instructions and - // control flow. - if (hlo->opcode() == HloOpcode::kTuple || // - hlo->opcode() == HloOpcode::kGetTupleElement || // - hlo->opcode() == HloOpcode::kConstant || // - hlo->opcode() == HloOpcode::kParameter || // - hlo->opcode() == HloOpcode::kFusion || // - hlo->opcode() == HloOpcode::kBitcastConvert || // - hlo->opcode() == HloOpcode::kConvert || // - hlo->opcode() == HloOpcode::kCall || // - hlo->opcode() == HloOpcode::kCustomCall || // - hlo->opcode() == HloOpcode::kWhile || // - hlo->opcode() == HloOpcode::kConditional || // + // exit of a computation, fusion, convert, side-effecting instructions, + // in-place operations and control flow. + if (hlo->opcode() == HloOpcode::kTuple || // + hlo->opcode() == HloOpcode::kGetTupleElement || // + hlo->opcode() == HloOpcode::kConstant || // + hlo->opcode() == HloOpcode::kParameter || // + hlo->opcode() == HloOpcode::kFusion || // + hlo->opcode() == HloOpcode::kBitcastConvert || // + hlo->opcode() == HloOpcode::kConvert || // + hlo->opcode() == HloOpcode::kCall || // + hlo->opcode() == HloOpcode::kCustomCall || // + hlo->opcode() == HloOpcode::kWhile || // + hlo->opcode() == HloOpcode::kConditional || // + HloDataflowAnalysis::IsInPlaceOperation(hlo->opcode()) || // hlo->HasSideEffectNoRecurse()) { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index a0fe0eaa1d9..f9e19493a86 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -598,6 +598,31 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( type = F32; break; } + // In order to find aliases due to in-place operations, use + // GetInPlaceInputOutputPairs. 
Ideally, we'd use HloAliasAnalysis here, + // but this code works with HloModules that aren't ready yet to use + // HloAliasAnalysis (e.g., their computation graphs may not have been + // flattened yet). + for (const auto& operand_and_output_index : + HloDataflowAnalysis::GetInPlaceInputOutputPairs(hlo)) { + if (operand_and_output_index.second == index) { + const HloUse& operand = operand_and_output_index.first; + for (const auto* value : + dataflow_ + ->GetValueSet(hlo->operand(operand.operand_number), + operand.operand_index) + .values()) { + auto value_type = ValueTypeAfterChange(value); + if (value_type == BF16) { + continue; + } + CHECK_EQ(value_type, F32); + type = F32; + break; + } + } + } + // It's possible that a user has been changed from BF16 to F32 // during this final adjustment pass, so we need to check // AllUsersConsumeBF16() again. diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index 02d79025f1b..9a898833373 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -1156,4 +1156,30 @@ ENTRY entry { EXPECT_FALSE(PropagatePrecision(module.get())); } +TEST_F(BFloat16PropagationTest, DynamicUpdateSlice) { + // This test is crafted so that the DUS has an f32 input (due to parameter) + // and bf16 output (due to dot). But we should enforce DUS operand 0 and + // output to get the same precision since it's an in-place operation. + const string module_str = R"( +HloModule Module + +ENTRY main { + param = f32[128,128] parameter(0) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + dynamic-update-slice = f32[128,128] dynamic-update-slice(param, broadcast.6, constant.3, constant.3) + ROOT dot = f32[128,128] dot(dynamic-update-slice, dynamic-update-slice), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + EXPECT_FALSE(PropagatePrecision(module.get())); + + HloInstruction* dus = module->entry_computation()->GetInstructionWithName( + "dynamic-update-slice"); + EXPECT_FALSE(OutputsBF16(dus)); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index a0989d5765e..db34f054f35 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -1007,102 +1007,6 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, return true; } // namespace xla -Status BufferAssigner::MergeInplaceOpBuffers(BufferAssignment* assignment) { - // Try allocate same buffer for dynamic update slice's operand and output. - - // If memory_space_assignment is run and there is information about a color in - // preset assignments, don't merge those buffers. We expect - // memory_space_assignment to have merged these buffers. If - // memory_space_assignment didn't merge these buffers and have assigned - // different offsets to the operand and the output buffer, merging the buffers - // can cause memory corruption if memory_space_assignment assigned a different - // buffer at the same offset. 
- absl::flat_hash_set excluded_colors; - if (preset_assignments_) { - for (const auto& color_and_info : - preset_assignments_->assignment_informations()) { - excluded_colors.insert(color_and_info.first); - } - } - - // TODO(yunxing): Moving this logic to alias analysis and add must-alias rule - // to operations that can be done in place. - for (HloComputation* computation : assignment->module().computations()) { - for (HloInstruction* instruction : computation->instructions()) { - if (!(instruction->opcode() == HloOpcode::kDynamicUpdateSlice || - (instruction->opcode() == HloOpcode::kFusion && - (instruction->fused_expression_root()->opcode() == - HloOpcode::kDynamicUpdateSlice)))) { - continue; - } - if (instruction->parent()->IsFusionComputation()) { - continue; - } - if (instruction->operand_count() == 0) { - continue; - } - - // The operand can't share the same buffer with the user based on dataflow - // analysis. - if (!assignment->dataflow_analysis().CanShareOperandBufferWithUser( - instruction->mutable_operand(0), {}, instruction, {})) { - continue; - } - HloBuffer& instruction_buffer = - assignment->alias_analysis().GetUniqueBufferAt(instruction, {}); - - HloBuffer& operand_buffer = - assignment->alias_analysis().GetUniqueBufferAt( - instruction->operand(0), {}); - - // The instruction or operand color is excluded because it was assigned by - // memory_space_assignment. - if (excluded_colors.contains(instruction_buffer.color()) || - excluded_colors.contains(operand_buffer.color())) { - continue; - } - - // Already have the same buffer. No need to merge those. - if (instruction_buffer.id() == operand_buffer.id()) { - continue; - } - - // Do not perform in-place dynamic update slice if the operand buffer is - // read-only. - if (HloBufferIsReadOnly(operand_buffer)) { - continue; - } - - bool interfere = false; - - for (const HloValue* instruction_value : instruction_buffer.values()) { - for (const HloValue* operand_value : operand_buffer.values()) { - if (assignment->hlo_ordering().MayInterfere( - *instruction_value, *operand_value, - assignment->dataflow_analysis())) { - interfere = true; - break; - } - } - } - if (interfere) { - continue; - } - if (assignment->alias_analysis().BufferLivesOut(instruction_buffer)) { - continue; - } - if (instruction_buffer.color() != operand_buffer.color()) { - continue; - } - VLOG(3) << "Merging inplace " << instruction_buffer << " and " - << operand_buffer; - assignment->alias_analysis().MergeBuffers(instruction_buffer, - operand_buffer); - } - } - return Status::OK(); -} - Status BufferAssigner::AssignSingleHloBuffer( const HloBuffer* hlo_buffer, bool is_thread_local, absl::flat_hash_map> BufferAssigner::CreateAssignment( VLOG(3) << "After coloring:"; XLA_VLOG_LINES(3, assignment->alias_analysis().dataflow_analysis().ToString()); - TF_RETURN_IF_ERROR(MergeInplaceOpBuffers(assignment.get())); std::vector thread_local_computations; std::vector global_computations; diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 60422965832..dfde46ca4b1 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -635,10 +635,6 @@ class BufferAssigner { absl::flat_hash_set* assigned_buffers, BufferAssignment* assignment); - // Promotes operations (DUS, scatter) to be done in place: If an operation can - // be done in place, merge its buffer with its operand buffer. 
- Status MergeInplaceOpBuffers(BufferAssignment* assignment); - // Assigns a single hlo buffer to an HLO allocation. Status AssignSingleHloBuffer( const HloBuffer* hlo_buffer, bool is_thread_local, diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index bc024f7144b..b49ca649f9a 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -1925,8 +1925,10 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text)); HloInstruction* parameter = m->entry_computation()->GetInstructionWithName("get-tuple-element.4"); - HloInstruction* dus = + HloInstruction* dus1 = m->entry_computation()->GetInstructionWithName("dynamic-update-slice.5"); + HloInstruction* dus2 = + m->entry_computation()->GetInstructionWithName("dynamic-update-slice.9"); auto buffers = RunBufferAssignment(m.get()); @@ -1934,8 +1936,10 @@ ENTRY main { const BufferAllocation& parameter_alloc = GetTopLevelAllocation(*buffers, parameter); - const BufferAllocation& dus_alloc = GetTopLevelAllocation(*buffers, dus); - EXPECT_NE(parameter_alloc, dus_alloc); + const BufferAllocation& dus1_alloc = GetTopLevelAllocation(*buffers, dus1); + EXPECT_EQ(parameter_alloc, dus1_alloc); + const BufferAllocation& dus2_alloc = GetTopLevelAllocation(*buffers, dus2); + EXPECT_EQ(parameter_alloc, dus2_alloc); } } diff --git a/tensorflow/compiler/xla/service/cholesky_expander.cc b/tensorflow/compiler/xla/service/cholesky_expander.cc index 20576cdc52d..4abfe1b018e 100644 --- a/tensorflow/compiler/xla/service/cholesky_expander.cc +++ b/tensorflow/compiler/xla/service/cholesky_expander.cc @@ -35,8 +35,6 @@ limitations under the License. namespace xla { -namespace { - // The Cholesky–Banachiewicz algorithm. See // https://en.wikipedia.org/wiki/Cholesky_decomposition#The_Cholesky–Banachiewicz_and_Cholesky–Crout_algorithms // for a description. @@ -54,78 +52,81 @@ namespace { // l = temp / l[..., j, j) * mask + l // return l // Returns a (result, error) pair. -std::pair CholeskyUnblocked( +StatusOr> CholeskyExpander::CholeskyUnblocked( XlaOp a, PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); - auto result = [&]() -> StatusOr> { - TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); - const int n_dims = a_shape.rank(); - const int64 n = ShapeUtil::GetDimension(a_shape, -1); - auto major_dims = AsInt64Slice(a_shape.dimensions()) - .subspan( - /*pos=*/0, - /*len=*/n_dims - 2); + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + const int ndims = a_shape.rank(); + const int64 n = ShapeUtil::GetDimension(a_shape, -1); + std::vector error_dims(a_shape.dimensions().begin(), + a_shape.dimensions().end()); + error_dims.back() = error_dims.at(ndims - 2) = 1; - auto matrix_dims = AsInt64Slice(a_shape.dimensions()) - .subspan( - /*pos=*/0, - /*len=*/n_dims); + auto major_dims = AsInt64Slice(a_shape.dimensions()) + .subspan( + /*pos=*/0, + /*len=*/ndims - 2); - XlaOp l = ZerosLike(a); + auto matrix_dims = AsInt64Slice(a_shape.dimensions()) + .subspan( + /*pos=*/0, + /*len=*/ndims); - // Construct the for loop body to iterate over rows. 
- auto body_fn = - [&](XlaOp i, absl::Span loop_vars, - XlaBuilder* body_builder) -> StatusOr> { - std::vector row_shape_dims(major_dims.begin(), major_dims.end()); - std::vector col_shape_dims(major_dims.begin(), major_dims.end()); - auto body_a = loop_vars[0]; - auto body_l = loop_vars[1]; - auto seen_error = loop_vars[2]; - auto iota_row = Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), - n_dims - 1); - auto iota_col = Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), - n_dims - 2); + XlaOp l = ZerosLike(a); - auto mask_pred = Ge(iota_col, iota_row); - mask_pred = And(mask_pred, Eq(iota_row, i)); - auto mask_zeros = - Zeros(body_builder, - ShapeUtil::MakeShape(a_shape.element_type(), matrix_dims)); - // L * L.T, This matrix has of a lot of multiplying with zero - // (namely, L[:, j:] = 0) and redundant computation, but it is faster - // than slice. - auto l_square = BatchDot(body_l, false, body_l, true, precision); + // Construct the for loop body to iterate over rows. + auto body_fn = [&](XlaOp i, absl::Span loop_vars, + XlaBuilder* body_builder) -> StatusOr> { + std::vector row_shape_dims(major_dims.begin(), major_dims.end()); + std::vector col_shape_dims(major_dims.begin(), major_dims.end()); + auto body_a = loop_vars[0]; + auto body_l = loop_vars[1]; + auto seen_error = loop_vars[2]; + auto iota_row = + Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), ndims - 1); + auto iota_col = + Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), ndims - 2); - // A - L*L.T - l_square = body_a - l_square; - auto l_ii = DynamicSliceInMinorDims(l_square, {i, i}, {1, 1}); + auto mask_pred = Ge(iota_col, iota_row); + mask_pred = And(mask_pred, Eq(iota_row, i)); + auto mask_zeros = + Zeros(body_builder, + ShapeUtil::MakeShape(a_shape.element_type(), matrix_dims)); + // L * L.T, This matrix has of a lot of multiplying with zero + // (namely, L[:, j:] = 0) and redundant computation, but it is faster + // than slice. 
+ auto l_square = + BatchDot(body_l, false, MaybeConjugate(body_l, true), true, precision); + + // A - L*L.T + l_square = body_a - l_square; + auto l_ii = DynamicSliceInMinorDims(l_square, {i, i}, {1, 1}); + if (ShapeUtil::ElementIsComplex(a_shape)) { + auto sqrt = Sqrt(Real(l_ii)); + l_ii = Complex(sqrt, ZerosLike(sqrt)); + seen_error = Or(seen_error, IsNan(sqrt)); + } else { l_ii = Sqrt(l_ii); - // L = (A - L*L.T) / l_ii * mask + L - body_l = Select(mask_pred, l_square / l_ii, mask_zeros) + body_l; + seen_error = Or(seen_error, IsNan(l_ii)); + } + // L = (A - L*L.T) / l_ii * mask + L + body_l = Select(mask_pred, l_square / l_ii, mask_zeros) + body_l; - seen_error = - Or(seen_error, Any(Or(Le(l_ii, ZerosLike(l_ii)), IsNan(l_ii)))); + return std::vector{body_a, body_l, seen_error}; + }; - return std::vector{body_a, body_l, seen_error}; - }; + TF_ASSIGN_OR_RETURN( + auto cholesky_while, + ForEachIndex( + n, S32, body_fn, + {a, l, Zeros(builder, ShapeUtil::MakeShape(PRED, error_dims))}, + "unblocked", builder)); - TF_ASSIGN_OR_RETURN( - auto cholesky_while, - ForEachIndex(n, S32, body_fn, {a, l, ConstantR0(builder, false)}, - "unblocked", builder)); - - return std::make_pair(cholesky_while[1], cholesky_while[2]); - }(); - if (!result.ok()) { - XlaOp error = builder->ReportError(result.status()); - return {error, error}; - } - return result.ValueOrDie(); + return std::make_pair(cholesky_while[1], cholesky_while[2]); } -XlaOp BuildCholesky(XlaOp a, int64 block_size, - PrecisionConfig::Precision precision) { +XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64 block_size, + PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); @@ -143,64 +144,77 @@ XlaOp BuildCholesky(XlaOp a, int64 block_size, ShapeUtil::HumanString(a_shape)); } - if (primitive_util::IsComplexType(a_shape.element_type())) { - return Unimplemented( - "Complex types are not implemented in Cholesky; got shape %s", - ShapeUtil::HumanString(a_shape)); - } - if (block_size < 1) { return InvalidArgument( "block_size argument to Cholesky must be >= 1; got %d", block_size); } + std::vector error_dims(a_shape.dimensions().begin(), + a_shape.dimensions().end()); + error_dims.back() = error_dims.at(ndims - 2) = 1; + std::vector error_dim_indices(ndims); + absl::c_iota(error_dim_indices, 0); + // Blocked left-looking Cholesky factorization. // Algorithm 1 from // Haidar, Azzam, et al. "High-performance Cholesky factorization for // GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017. XlaOp l = ZerosLike(a); - XlaOp seen_error = ConstantR0(builder, false); + XlaOp seen_error = Zeros(builder, ShapeUtil::MakeShape(PRED, error_dims)); for (int64 i = 0; i < n; i += block_size) { int64 k = std::min(block_size, n - i); + auto panel = SliceInMinorDims(a, {i, i}, {n, i + k}); if (i > 0) { // TODO(phawkins): consider implementing SYRK for the diagonal part of // the panel. 
// a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i])) auto lhs = SliceInMinorDims(l, {i, 0}, {n, i}); auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i}); - auto delta = BatchDot(lhs, false, rhs, true, precision); - auto before = SliceInMinorDims(a, {i, i}, {n, i + k}); - a = UpdateSliceInMinorDims(a, before - delta, {i, i}); + auto delta = + BatchDot(lhs, false, MaybeConjugate(rhs, true), true, precision); + panel = panel - delta; } // l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k]) - auto x = SliceInMinorDims(a, {i, i}, {i + k, i + k}); + auto x = SliceInMinorDims(panel, {0, 0}, {k, k}); XlaOp factorized; + // TODO(b/167896062): A failure in one element of a batch shouldn't fail + // other elements. XlaOp factorized_error; - std::tie(factorized, factorized_error) = CholeskyUnblocked(x, precision); + if (k == 1) { + if (ShapeUtil::ElementIsComplex(a_shape)) { + auto sqrt = Sqrt(Real(x)); + factorized = Complex(sqrt, ZerosLike(sqrt)); + factorized_error = IsNan(sqrt); + } else { + factorized = Sqrt(x); + factorized_error = IsNan(factorized); + } + } else { + TF_ASSIGN_OR_RETURN(auto tile_output, CholeskyUnblocked(x, precision)); + std::tie(factorized, factorized_error) = tile_output; + } seen_error = Or(seen_error, factorized_error); l = UpdateSliceInMinorDims(l, factorized, {i, i}); if (i + k < n) { // l[i+k:, i:i+k] = // trsm_right_transpose(l[i:i+k, i:i+k], a[i+k:, i:i+k]) - auto panel = SliceInMinorDims(a, {i + k, i}, {n, i + k}); - auto update = - TriangularSolve(factorized, panel, - /*left_side=*/false, - /*lower=*/true, - /*unit_diagonal=*/false, - /*transpose_a=*/TriangularSolveOptions::TRANSPOSE); + auto update = TriangularSolve( + factorized, SliceInMinorDims(panel, {k, 0}, {n - i, k}), + /*left_side=*/false, + /*lower=*/true, + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::ADJOINT); l = UpdateSliceInMinorDims(l, update, {i + k, i}); } } - return Select(seen_error, - FullLike(l, std::numeric_limits::quiet_NaN()), l); + return Select( + BroadcastInDim(seen_error, a_shape.dimensions(), error_dim_indices), + FullLike(l, std::numeric_limits::quiet_NaN()), l); }); } -} // namespace - bool CholeskyExpander::InstructionMatchesPattern(HloInstruction* instruction) { return instruction->opcode() == HloOpcode::kCholesky; } diff --git a/tensorflow/compiler/xla/service/cholesky_expander.h b/tensorflow/compiler/xla/service/cholesky_expander.h index d2958db1b8c..ee8531d0f48 100644 --- a/tensorflow/compiler/xla/service/cholesky_expander.h +++ b/tensorflow/compiler/xla/service/cholesky_expander.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_CHOLESKY_EXPANDER_H_ #include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/service/op_expander_pass.h" namespace xla { @@ -31,7 +32,13 @@ class CholeskyExpander : public OpExpanderPass { StatusOr ExpandInstruction( HloInstruction* instruction) override; + virtual StatusOr> CholeskyUnblocked( + XlaOp a, PrecisionConfig::Precision precision); + private: + XlaOp BuildCholesky(XlaOp a, int64 block_size, + PrecisionConfig::Precision precision); + // Mapping from op signatures to existing computations. 
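For reference, the expander above lowers kCholesky to a blocked left-looking factorization built from the unblocked Cholesky–Banachiewicz update, signalling failure through NaNs accumulated in a seen_error flag instead of aborting. A minimal scalar sketch of the unblocked algorithm on a plain row-major matrix follows, assuming a real symmetric positive-definite input; it is illustrative only, since the pass itself emits the equivalent XLA ops and also handles complex inputs via the conjugate transpose.

#include <cassert>
#include <cmath>
#include <vector>

// Computes the lower-triangular factor L of a symmetric positive-definite
// n x n matrix A (row-major), so that A = L * L^T. Returns true if a negative
// pivot produced a NaN, mirroring how the expander folds failures into a
// `seen_error` flag rather than stopping.
bool CholeskyBanachiewicz(const std::vector<double>& a, int n,
                          std::vector<double>& l) {
  bool seen_error = false;
  l.assign(n * n, 0.0);
  for (int i = 0; i < n; ++i) {
    for (int j = 0; j <= i; ++j) {
      double sum = 0.0;
      for (int k = 0; k < j; ++k) sum += l[i * n + k] * l[j * n + k];
      if (i == j) {
        l[i * n + i] = std::sqrt(a[i * n + i] - sum);  // NaN if pivot < 0
        seen_error |= std::isnan(l[i * n + i]);
      } else {
        l[i * n + j] = (a[i * n + j] - sum) / l[j * n + j];
      }
    }
  }
  return seen_error;
}

int main() {
  std::vector<double> a = {4.0, 2.0, 2.0, 3.0};  // SPD 2x2 matrix
  std::vector<double> l;
  bool err = CholeskyBanachiewicz(a, 2, l);
  assert(!err);
  // l == {2, 0, 1, sqrt(2)}, so that a == l * l^T.
  return 0;
}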
absl::flat_hash_map computation_cache_; }; diff --git a/tensorflow/compiler/xla/service/collective_ops_utils.h b/tensorflow/compiler/xla/service/collective_ops_utils.h index eb96d537fa8..4eaa9101cc4 100644 --- a/tensorflow/compiler/xla/service/collective_ops_utils.h +++ b/tensorflow/compiler/xla/service/collective_ops_utils.h @@ -82,22 +82,6 @@ struct RendezvousKey { collective_op_kind(collective_op_kind), op_id(op_id) {} - static RendezvousKey FromInstruction( - const RunId& run_id, std::vector global_devices, - int num_local_participants, const HloInstruction* instr) { - CollectiveOpKind collective_op_kind; - int64 op_id; - - std::tie(collective_op_kind, op_id) = - instr->channel_id().has_value() - ? std::make_pair(kCrossModule, instr->channel_id().value()) - : std::make_pair( - kCrossReplica, - static_cast(instr->GetModule()->unique_id())); - return RendezvousKey(run_id, std::move(global_devices), - num_local_participants, collective_op_kind, op_id); - } - template friend H AbslHashValue(H h, const RendezvousKey& k) { return H::combine(std::move(h), k.run_id, k.global_devices, diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index f03b27cdcc7..653f4555a77 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -28,14 +28,6 @@ namespace xla { /* static */ tensorflow::mutex Compiler::platform_compiler_mutex_( tensorflow::LINKER_INITIALIZED); -StatusOr< - std::tuple, std::unique_ptr>> -Compiler::RunHloPassesAndBufferAssignement( - std::unique_ptr module, se::StreamExecutor* executor, - se::DeviceMemoryAllocator* device_allocator) { - return Unimplemented("This compiler does not support this method"); -} - std::vector> Compiler::ComputeBackendConfigs(const HloInstruction& hlo, se::StreamExecutor* executor) const { diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index 312a068ba65..253caac195c 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -188,7 +188,10 @@ class Compiler { std::tuple, std::unique_ptr>> RunHloPassesAndBufferAssignement(std::unique_ptr module, se::StreamExecutor* executor, - se::DeviceMemoryAllocator* device_allocator); + se::DeviceMemoryAllocator* device_allocator, + bool optimize) { + return Unimplemented("This compiler does not support this method"); + } // Compiles the HLO module for execution on a device given by the executor, // and returns an executable object or an error status. No HLO passes are diff --git a/tensorflow/compiler/xla/service/conditional_code_motion.cc b/tensorflow/compiler/xla/service/conditional_code_motion.cc index ce80b4cfc15..855e75a76e0 100644 --- a/tensorflow/compiler/xla/service/conditional_code_motion.cc +++ b/tensorflow/compiler/xla/service/conditional_code_motion.cc @@ -23,17 +23,20 @@ limitations under the License. 
#include "absl/algorithm/container.h" #include "absl/strings/str_cat.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_cse.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -95,31 +98,63 @@ class BoundaryVisitor { absl::flat_hash_set visited_; }; +template +int64 CountNonLeafOps(const OpCollection& ops) { + absl::flat_hash_set op_set; + for (auto op : ops) { + if (!op_set.contains(op) && op->opcode() != HloOpcode::kConstant) { + op_set.insert(op); + } + } + return op_set.size(); +} + // Returns estimation of potential reuses carried by a given pair of // instructions. Use different integers to classify different levels // of reuses This is used as a placeholder only, assuming all // instructions can be fused to enable data reuses int64 ReusesCarriedBy(HloInstruction* op, HloInstruction* user) { + // Reuses in some way work like forces that pull instructions + // towards each other. We use a number 0-10 to classify how strong the force + // is between a pair of operations. Given a group of instructions that can be + // moved together, if the forces inside a conditional are stronger, the group + // will be moved incide or remain inside the conditional; otherwise, it will + // be moved outside to or remain outside of the conditional. VLOG(2) << "ConditionalCodeMotion: Add reuses carried by instr: " << op->ToString() << "=>" << user->ToString() << "\n"; switch (user->opcode()) { case HloOpcode::kGetTupleElement: - case HloOpcode::kTuple: return 0; + case HloOpcode::kConvert: + // Because convert is treated not moveable when following Dot or + // convolution, here if op is dot or convolution, they must be separated + // by a conditional boundary. Here we do not try to pull convert inside + // conditionals to be together with the dot or convolution. + switch (op->opcode()) { + case HloOpcode::kConvolution: + case HloOpcode::kDot: + return 0; + default: + break; + } + break; default: break; } switch (op->opcode()) { - // These instructions are lightweight and easy to fuse. + // These instructions do not carry weight of reuse themselves. + case HloOpcode::kParameter: case HloOpcode::kConstant: case HloOpcode::kGetTupleElement: return 0; - default: - // Assume fusion will not happen anyway if user count > 1) - if (op->user_count() > 1) { - return 0; - } + case HloOpcode::kConditional: return 10; + default: { + // Assume the reuse decreases with increasing user count. 
+ int count1 = CountNonLeafOps(op->users()); + int count2 = CountNonLeafOps(user->operands()); + return 10 / count1 / count2; + } } } @@ -177,17 +212,35 @@ Status CopyInOrOutOfConditional( absl::InlinedVector new_operands; for (int i = 0; i < op->operands().size(); ++i) { auto op_i = op->operands()[i]; - VLOG(2) << "Looking for operand:" << op_i->ToString() << "\n"; + VLOG(2) << "Looking for " << op_i->ToString() << "\n"; if (ContainsKey(hoisted_instructions, op_i)) { auto new_op_i = FindOrDie(hoisted_instructions, op_i).operands()[dest_index]; - VLOG(2) << "new operand:" << new_op_i->ToString() << "\n"; + VLOG(2) << "new instruction:" << new_op_i->ToString() << "\n"; new_operands.push_back(new_op_i); } else { - CHECK(op_i->opcode() == HloOpcode::kConstant); - auto new_op_i = parent->AddInstruction(op_i->Clone()); - VLOG(2) << "new operand:" << new_op_i->ToString() << "\n"; - new_operands.push_back(new_op_i); + switch (op_i->opcode()) { + case HloOpcode::kConstant: { + auto new_op_i = parent->AddInstruction(op_i->Clone()); + VLOG(2) << "new instruction:" << new_op_i->ToString() << "\n"; + new_operands.push_back(new_op_i); + break; + } + case HloOpcode::kGetTupleElement: { + auto gte = Cast(op_i); + int64 index = gte->tuple_index(); + HloInstruction* root = parent->root_instruction(); + CHECK(root->opcode() == HloOpcode::kTuple && + index < root->operand_count()); + auto new_op_i = root->mutable_operand(index); + VLOG(2) << "new instruction:" << new_op_i->ToString() << "\n"; + new_operands.push_back(new_op_i); + break; + } + default: + LOG(FATAL) << "Unexpected out-of-boundary instruction:" + << op_i->ToString() << "\n"; + } } } HloInstruction* new_instruction = parent->AddInstruction( @@ -298,125 +351,130 @@ StatusOr ConvertSpecialMove(HloInstruction* conditional, return false; } - HloInstruction* old_root = - conditional->branch_computation(0)->root_instruction(); - if (old_root->opcode() != HloOpcode::kTuple) { - return false; - } else { - VLOG(2) << "BEFORE :" << conditional->parent()->parent()->ToString(); - // Identify the gte using `index'. - auto find_gte = [](const HloInstruction* conditional_result, - int64 index) -> HloInstruction* { - for (HloInstruction* instr : conditional_result->users()) { - if (instr->opcode() != HloOpcode::kGetTupleElement) { - return nullptr; - } - if (instr->tuple_index() == index) { - return instr; - } - } - return nullptr; - }; - - // Captures tuple indices refering to converts to be rematerialized/hoisted. - absl::flat_hash_set kspecial_convert = FindSpecialConverts( - old_root, branch_count, conditional, is_layout_sensitive); - - // Exit if we cannot find any converts to be hoisted. 
- if (kspecial_convert.empty()) { + // Determining whether all branch roots are tuples + for (int branch_num = 0; branch_num < branch_count; ++branch_num) { + HloInstruction* branch_root = + conditional->branch_computation(branch_num)->root_instruction(); + if (branch_root->opcode() != HloOpcode::kTuple) { return false; } + } - TF_RETURN_IF_ERROR( - RestructureConditionalInstruction(conditional->parent(), conditional)); - - for (int branch = 0; branch < branch_count; branch++) { - old_root = conditional->branch_computation(branch)->root_instruction(); - absl::flat_hash_map map_inst_to_tuple_index; - std::vector new_operands(old_root->operand_count()); - absl::flat_hash_set to_hoist_set; - - for (int64 operand_num = 0; operand_num < old_root->operand_count(); - ++operand_num) { - map_inst_to_tuple_index[old_root->mutable_operand(operand_num)] = - operand_num; + HloInstruction* old_root = + conditional->branch_computation(0)->root_instruction(); + VLOG(2) << "BEFORE :" << conditional->parent()->parent()->ToString(); + // Identify the gte using `index'. + auto find_gte = [](const HloInstruction* conditional_result, + int64 index) -> HloInstruction* { + for (HloInstruction* instr : conditional_result->users()) { + if (instr->opcode() != HloOpcode::kGetTupleElement) { + return nullptr; } - for (int64 operand_num = 0; operand_num < old_root->operand_count(); - ++operand_num) { - HloInstruction* hoist = old_root->mutable_operand(operand_num); - if (!kspecial_convert.contains(operand_num)) { - new_operands[operand_num] = old_root->mutable_operand(operand_num); - continue; - } - - to_hoist_set.insert(hoist); - int64 new_tuple_count = old_root->operand_count(); - - // Replace the hoisted instr in the tuple with the operand/operands. - // We will replace at least one of the operands of the hoist at the - // tuple place; the rest will be added at the end. - bool inplace = true; - CHECK(!hoist->operands().empty()); - for (HloInstruction* prod : hoist->operands()) { - if (inplace) { - map_inst_to_tuple_index[prod] = map_inst_to_tuple_index[hoist]; - new_operands[map_inst_to_tuple_index[hoist]] = prod; - inplace = false; - } else { - map_inst_to_tuple_index[prod] = new_tuple_count++; - new_operands.push_back(prod); - } - } + if (instr->tuple_index() == index) { + return instr; } + } + return nullptr; + }; - // Create the new root instruction. - HloComputation* cur_branch = conditional->branch_computation(branch); - HloInstruction* new_branch_root = - cur_branch->AddInstruction(HloInstruction::CreateTuple(new_operands)); - // The shape can vary since the operands to convert are now - // being returned through the branches' root. - cur_branch->set_root_instruction(new_branch_root, true /*new shape*/); - TF_CHECK_OK(cur_branch->RemoveInstruction(old_root)); + // Captures tuple indices refering to converts to be rematerialized/hoisted. + absl::flat_hash_set kspecial_convert = FindSpecialConverts( + old_root, branch_count, conditional, is_layout_sensitive); - // Only one of the branches needs to change the conditional->parent(). - if (branch != 0) { + // Exit if we cannot find any converts to be hoisted. 
+ if (kspecial_convert.empty()) { + return false; + } + + TF_RETURN_IF_ERROR( + RestructureConditionalInstruction(conditional->parent(), conditional)); + + for (int branch = 0; branch < branch_count; branch++) { + old_root = conditional->branch_computation(branch)->root_instruction(); + absl::flat_hash_map map_inst_to_tuple_index; + std::vector new_operands(old_root->operand_count()); + absl::flat_hash_set to_hoist_set; + + for (int64 operand_num = 0; operand_num < old_root->operand_count(); + ++operand_num) { + map_inst_to_tuple_index[old_root->mutable_operand(operand_num)] = + operand_num; + } + for (int64 operand_num = 0; operand_num < old_root->operand_count(); + ++operand_num) { + HloInstruction* hoist = old_root->mutable_operand(operand_num); + if (!kspecial_convert.contains(operand_num)) { + new_operands[operand_num] = old_root->mutable_operand(operand_num); continue; } - HloComputation* conditional_parent = conditional->parent(); - HloInstruction* newconditional = - conditional_parent->AddInstruction(HloInstruction::CreateConditional( - cur_branch->root_instruction()->shape(), - conditional->mutable_operand(0), - absl::MakeSpan(conditional->branch_computations()), - absl::MakeSpan(conditional->operands()).subspan(1))); - // Ensure that all the users of conditional refer to the new one. - TF_RETURN_IF_ERROR( - conditional->ReplaceAllUsesWithDifferentShape(newconditional)); - TF_CHECK_OK(conditional_parent->RemoveInstruction(conditional)); - conditional = newconditional; - // Add the hoisted instructions in the parent. - for (HloInstruction* hoist : to_hoist_set) { - VLOG(2) << "Hoisting instruction:" << hoist->ToString(); - int64 hoist_index = map_inst_to_tuple_index[hoist]; - // Find out the gte that captured the hoisted instr result. - HloInstruction* gte_hoist = find_gte(conditional, hoist_index); - CHECK(gte_hoist != nullptr); - std::vector new_operands; - for (HloInstruction* op : hoist->operands()) { - HloInstruction* gte = conditional_parent->AddInstruction( - HloInstruction::CreateGetTupleElement( - op->shape(), conditional, map_inst_to_tuple_index[op])); - new_operands.push_back(gte); + + to_hoist_set.insert(hoist); + int64 new_tuple_count = old_root->operand_count(); + + // Replace the hoisted instr in the tuple with the operand/operands. + // We will replace at least one of the operands of the hoist at the + // tuple place; the rest will be added at the end. + bool inplace = true; + CHECK(!hoist->operands().empty()); + for (HloInstruction* prod : hoist->operands()) { + if (inplace) { + map_inst_to_tuple_index[prod] = map_inst_to_tuple_index[hoist]; + new_operands[map_inst_to_tuple_index[hoist]] = prod; + inplace = false; + } else { + map_inst_to_tuple_index[prod] = new_tuple_count++; + new_operands.push_back(prod); } - HloInstruction* hoisted = conditional_parent->AddInstruction( - hoist->CloneWithNewOperands(hoist->shape(), new_operands)); - VLOG(2) << "Hoisted instruction in parent:" << hoisted->ToString(); - TF_RETURN_IF_ERROR(gte_hoist->ReplaceAllUsesWith(hoisted)); - TF_CHECK_OK(conditional_parent->RemoveInstruction(gte_hoist)); } - // No need to explicitly delete a hoisted instruction since if its dead - // then the subsequent DCE will remove it. } + + // Create the new root instruction. 
+ HloComputation* cur_branch = conditional->branch_computation(branch); + HloInstruction* new_branch_root = + cur_branch->AddInstruction(HloInstruction::CreateTuple(new_operands)); + // The shape can vary since the operands to convert are now + // being returned through the branches' root. + cur_branch->set_root_instruction(new_branch_root, true /*new shape*/); + TF_CHECK_OK(cur_branch->RemoveInstruction(old_root)); + + // Only one of the branches needs to change the conditional->parent(). + if (branch != 0) { + continue; + } + HloComputation* conditional_parent = conditional->parent(); + HloInstruction* newconditional = + conditional_parent->AddInstruction(HloInstruction::CreateConditional( + cur_branch->root_instruction()->shape(), + conditional->mutable_operand(0), + absl::MakeSpan(conditional->branch_computations()), + absl::MakeSpan(conditional->operands()).subspan(1))); + // Ensure that all the users of conditional refer to the new one. + TF_RETURN_IF_ERROR( + conditional->ReplaceAllUsesWithDifferentShape(newconditional)); + TF_CHECK_OK(conditional_parent->RemoveInstruction(conditional)); + conditional = newconditional; + // Add the hoisted instructions in the parent. + for (HloInstruction* hoist : to_hoist_set) { + VLOG(2) << "Hoisting instruction:" << hoist->ToString(); + int64 hoist_index = map_inst_to_tuple_index[hoist]; + // Find out the gte that captured the hoisted instr result. + HloInstruction* gte_hoist = find_gte(conditional, hoist_index); + CHECK(gte_hoist != nullptr); + std::vector new_operands; + for (HloInstruction* op : hoist->operands()) { + HloInstruction* gte = conditional_parent->AddInstruction( + HloInstruction::CreateGetTupleElement(op->shape(), conditional, + map_inst_to_tuple_index[op])); + new_operands.push_back(gte); + } + HloInstruction* hoisted = conditional_parent->AddInstruction( + hoist->CloneWithNewOperands(hoist->shape(), new_operands)); + VLOG(2) << "Hoisted instruction in parent:" << hoisted->ToString(); + TF_RETURN_IF_ERROR(gte_hoist->ReplaceAllUsesWith(hoisted)); + TF_CHECK_OK(conditional_parent->RemoveInstruction(gte_hoist)); + } + // No need to explicitly delete a hoisted instruction since if its dead + // then the subsequent DCE will remove it. 
} VLOG(2) << "AFTER :" << conditional->parent()->parent()->ToString(); return true; @@ -446,7 +504,7 @@ StatusOr ConditionalCodeMotion::MoveInstructionOut( << conditional_parent->ToString(HloPrintOptions::Fingerprint()) << "\n"; int64 op_index = 0; - for (Boundary b : new_boundaries) { + for (const Boundary& b : new_boundaries) { HloInstruction* op = b.operands()[0]; CHECK(op != nullptr); VLOG(2) << "Mapping new boundary instr: " << op->ToString() << "\n"; @@ -477,6 +535,7 @@ StatusOr ConditionalCodeMotion::MoveInstructionOut( int64 index = tuple_opd->tuple_index(); CHECK(old_root->operands().size() > index); HloInstruction* old_opd = old_root->operands()[index]; + VLOG(2) << "old opd = " << old_opd << "\n"; CHECK(ContainsKey(hoisted_instructions, old_opd)); HloInstruction* new_opd = hoisted_instructions[old_opd].operands()[0]; CHECK(old_opd != nullptr); @@ -492,7 +551,7 @@ StatusOr ConditionalCodeMotion::MoveInstructionOut( for (int i = 0; i < branch_count; i++) { auto computation = conditional->branch_computation(i); std::vector elements; - for (auto b1 : new_boundaries) { + for (const auto& b1 : new_boundaries) { HloInstruction* op = b1.operands()[i]; CHECK(op != nullptr); VLOG(2) << "Adding to root " << i << " with " << op->ToString() << "\n"; @@ -503,15 +562,24 @@ StatusOr ConditionalCodeMotion::MoveInstructionOut( computation->set_root_instruction(tuple, true); VLOG(2) << "computation is :" << computation->ToString() << "\n"; // Remove hoisted instructions from the branches. - for (auto b2 : to_move_out) { - VLOG(2) << "Removing boundary:" << b2.ToString() << "\n"; - TF_RETURN_IF_ERROR(computation->RemoveInstruction(b2.operands()[i])); + for (const auto& b2 : to_move_out) { + auto instr_to_remove = b2.operands()[i]; + // Double check to make sure it is safe to delete the instruction. + // Complications may arise due to some operations in the alternative + // branches (branches 1..n) being placed into the boundaries multiple + // times. + if (!computation->IsMarkedAsDead(instr_to_remove) && + instr_to_remove->user_count() == 0) { + VLOG(2) << "Removing boundary:" << b2.ToString() << "\n"; + TF_RETURN_IF_ERROR(computation->RemoveInstruction(instr_to_remove)); + } } } // Change conditional instruction shape to the shape of the new root. HloInstruction* new_root = conditional->branch_computation(0)->root_instruction(); *conditional->mutable_shape() = new_root->shape(); + // VLOG(1) << "done moving instructions out of branches\n" << conditional_parent->ToString(HloPrintOptions::Fingerprint()) @@ -535,16 +603,26 @@ StatusOr ConditionalCodeMotion::MoveInstructionIn( absl::flat_hash_map hoisted_instructions; int64 to_move_in_size = to_move_in.size(); int64 branch_count = conditional->branch_count(); + HloGetTupleElementInstruction* tuple_use = + DynCast(to_move_in[0].operands()[0]); + // If use_index is -1, the old conditional root entry used by to_move_in + // instructions still need to be included as an entry of the modified + // conditional root, and the new result of the to_move_in instructions + // need to be added as an extra entry of the modified root; otherwise, the + // old root entry will be replaced with the new result in the modified root. + // The entry replacement should be allowed only if tuple_use has <=1 users. + int64 use_index = (tuple_use != nullptr && tuple_use->user_count() == 1) + ? tuple_use->tuple_index() + : -1; + VLOG(2) << "Tuple use index = " << use_index << "\n"; // Number of old conditional entries still to be used outside. 
// If conditional shape is not tuple, will create a tuple and use subscript // 0 to save the old operand being used. - int64 op_index = conditional->shape().IsTuple() - ? conditional->shape().tuple_shapes_size() - 1 - : 0; - HloGetTupleElementInstruction* tuple_use = - dynamic_cast(to_move_in[0].operands()[0]); - int64 use_index = (tuple_use != nullptr) ? tuple_use->tuple_index() : -1; - VLOG(2) << "Tuple use index = " << use_index << "\n"; + int64 op_index = + conditional->shape().IsTuple() + ? ((use_index >= 0) ? conditional->shape().tuple_shapes_size() - 1 + : conditional->shape().tuple_shapes_size()) + : 0; // Use to map the tuple_use instruction to its operand; Boundary b_opd_use(Boundary::Position::kInsideBranch); Boundary b_old_root(Boundary::Position::kInsideBranch); @@ -582,6 +660,7 @@ StatusOr ConditionalCodeMotion::MoveInstructionIn( // to replace the conditional directly in the new computation. b_opd_use.mutable_operands().push_back(conditional); } + HloInstruction* new_root = computation->AddInstruction(HloInstruction::CreateTuple(operands)); VLOG(2) << "setting new root: " << new_root->ToString() << "\n"; @@ -592,29 +671,41 @@ StatusOr ConditionalCodeMotion::MoveInstructionIn( } VLOG(2) << "new branch computation: " << computation->ToString() << "\n"; } + // Update get tuple element index of the conditional. + if (use_index != -1) { + for (auto* user : conditional->users()) { + if (user->opcode() == HloOpcode::kGetTupleElement && + user->tuple_index() > use_index) { + user->set_tuple_index(user->tuple_index() - 1); + } + } + } hoisted_instructions[conditional] = b_old_root; int64 cp_start = 0; if (use_index >= 0) { + VLOG(2) << "Mapping GTE: " << tuple_use->ToString() << "\n"; hoisted_instructions[tuple_use] = b_opd_use; - cp_start = 1; } - for (int64 i = cp_start; i < to_move_in_size; i++) { - Boundary b_to_move = to_move_in[i]; + cp_start = (tuple_use != nullptr) ? 1 : 0; + for (int64 to_move_index = cp_start; to_move_index < to_move_in_size; + to_move_index++) { + Boundary b_to_move = to_move_in[to_move_index]; HloInstruction* op = b_to_move.operands()[0]; CHECK(op != nullptr); bool to_be_used_outside = true; VLOG(2) << "Mapping new boundary instr: " << op->ToString() << "\n"; - if (i < to_move_in_size - 1 && op->user_count() == 1 && - op->users()[0] == to_move_in[i + 1].operands()[0]) { + if (to_move_index < to_move_in_size - 1 && op->user_count() == 1 && + op->users()[0] == to_move_in[to_move_index + 1].operands()[0]) { to_be_used_outside = false; VLOG(2) << "Instruction is not to be used outside the branch\n"; } Boundary b(Boundary::Position::kInsideBranch); for (int i = 0; i < branch_count; i++) { auto computation = conditional->branch_computation(i); + VLOG(2) << "Copying to branch: " << i << "\n"; TF_RETURN_IF_ERROR(CopyInOrOutOfConditional(b_to_move, i, computation, hoisted_instructions)); - VLOG(2) << "After Copying to branch: " << computation->ToString() << "\n"; + VLOG(2) << "Done:" << computation->ToString() << "\n"; if (to_be_used_outside) { auto new_op = hoisted_instructions[op].operands()[i]; auto new_root = computation->root_instruction(); @@ -648,12 +739,23 @@ StatusOr ConditionalCodeMotion::MoveInstructionIn( // Remove hoisted instructions from the branches. 
for (int64 i = to_move_in_size - 1; i >= 0; i--) { Boundary boundary_to_move_in = to_move_in[i]; - VLOG(2) << "Removing boundary:" << boundary_to_move_in.ToString() << "\n"; HloInstruction* op = boundary_to_move_in.operands()[0]; - for (auto user : op->users()) { - VLOG(2) << "Has User: " << user->ToString() << "\n"; + if (op->user_count() == 0) { + VLOG(2) << "Removing boundary:" << boundary_to_move_in.ToString() << "\n"; + TF_RETURN_IF_ERROR(conditional->parent()->RemoveInstruction(op)); + VLOG(2) << "Done removing boundary.\n"; + } + } + + // Reset shapes of user gtes to the new shape. + if (use_index != -1) { + for (auto* user : conditional->users()) { + if (user->opcode() == HloOpcode::kGetTupleElement) { + VLOG(2) << "Resetting shape of user: " << user->ToString() << "\n"; + *user->mutable_shape() = + conditional->shape().tuple_shapes(user->tuple_index()); + } } - TF_RETURN_IF_ERROR(conditional->parent()->RemoveInstruction(op)); } VLOG(1) << "Done moving instructions inside branches\n" << conditional->parent()->ToString(HloPrintOptions::Fingerprint()) @@ -669,16 +771,23 @@ class GroupConnectedBoundaries { HloComputation* conditional_parent_; bool is_layout_sensitive_; // Instructions that have been visited but are not going to be moved. - absl::flat_hash_set visited_; + absl::flat_hash_map& visited_count_; public: - explicit GroupConnectedBoundaries(HloInstruction* conditional, - bool is_layout_sensitive) + explicit GroupConnectedBoundaries( + HloInstruction* conditional, bool is_layout_sensitive, + absl::flat_hash_map& visited_count) : conditional_(conditional), conditional_parent_(conditional->parent()), - is_layout_sensitive_(is_layout_sensitive) {} - // Returns true if `instruction` is worth hoisting out. - bool WorthHoisting(HloInstruction* instruction) { + is_layout_sensitive_(is_layout_sensitive), + visited_count_(visited_count) {} + void clear_recently_visited() { + for (const auto& boundary : new_boundaries_) { + visited_count_.erase(boundary.operands()[0]); + } + } + // Returns true if `instruction` is worth hoisting. + bool WorthHoisting(HloInstruction* instruction, bool is_inside_branch) { // This is needed for the "moving-in" transformation, to prevent the root // of the parent computation (which contains the conditional) to be moved // inside the conditional. @@ -686,6 +795,8 @@ class GroupConnectedBoundaries { instruction == conditional_parent_->root_instruction()) { return false; } + // TODO(b/169182921): The following cost model is rather incomplete. It will + // need to be extended to cover most element-wise ops. switch (instruction->opcode()) { case HloOpcode::kConvert: // If Convert is after AllReduce, it is worth moving out AllReduce // out of conditional. For convert following other // ops such as Dot or Convolutional, it is better to keep convert // within conditional so that convert can be fused with Dot or // Convolutional. - // - // TODO(b/154283721): figure out the scenario when convert can be - // fused with AllReduce out of conditional.
switch (instruction->operand(0)->opcode()) { case HloOpcode::kAllReduce: case HloOpcode::kReshape: + case HloOpcode::kGetTupleElement: return true; default: - VLOG(2) << "Instruction is convert and its operand is not know to " + VLOG(2) << "Instruction is convert and its operand is not known to " "be worth hoisting\n"; return false; } + case HloOpcode::kGetTupleElement: + switch (instruction->operand(0)->opcode()) { + // do not move GTE if its operand is a parameter + case HloOpcode::kParameter: + return false; + default: + return true; + } case HloOpcode::kAllReduce: + // It is not safe to move collective ops from outside to inside + // conditional branches, as it may cause synchronization problems, + // when different layouts are assigned to different branches. + return is_inside_branch; + case HloOpcode::kAbs: + case HloOpcode::kReduce: case HloOpcode::kAdd: case HloOpcode::kPower: + case HloOpcode::kCopy: case HloOpcode::kConstant: case HloOpcode::kSubtract: case HloOpcode::kMultiply: case HloOpcode::kDivide: case HloOpcode::kTuple: case HloOpcode::kSqrt: + case HloOpcode::kRsqrt: case HloOpcode::kReshape: - case HloOpcode::kGetTupleElement: + case HloOpcode::kMinimum: + case HloOpcode::kMaximum: return true; default: VLOG(2) << "Instruction is not known to be worth hoisting\n"; @@ -728,14 +854,20 @@ class GroupConnectedBoundaries { // The operand must be an instruction that is not going to be moved (if // user is inside the conditional); otherwise it must be the conditional // itself and its user must be outside of the conditional. - if (!ContainsKey(visited_, op) && op != conditional_) { + if (!ContainsKey(visited_count_, op) && op != conditional_) { continue; } - // Only consider single-user cases as reuseable. - if (user->opcode() == HloOpcode::kGetTupleElement && - user->user_count() == 1) { + if (auto tuple_gte = DynCast(user)) { + if (op->opcode() == HloOpcode::kConditional) { + auto tuple = op->branch_computation(0)->root_instruction(); + if (tuple->opcode() == HloOpcode::kTuple) { + auto index = tuple_gte->tuple_index(); + CHECK(index < tuple->operand_count()); + op = tuple->mutable_operand(index); + } + } reuses += ReusesCarriedBy(op, user->users()[0]); - } else if (op->user_count() == 1) { + } else { reuses += ReusesCarriedBy(op, user); } } @@ -753,6 +885,7 @@ class GroupConnectedBoundaries { // some aspects of the overall algorithm need to be redesigned to // accommandate the change. if (all_users.size() > 1) { + VLOG(2) << "Having multiple users from: " << user->ToString() << "\n"; return 0; } if (!all_users.empty()) { @@ -774,7 +907,7 @@ class GroupConnectedBoundaries { } } } - } else if (ContainsKey(visited_, op)) { + } else if (ContainsKey(visited_count_, op)) { reuses += ReusesCarriedBy(user, op); } VLOG(2) << "reuses after instruction " << user->ToString() << ":" @@ -822,16 +955,49 @@ class GroupConnectedBoundaries { } return b2; } - int64 CountNonLeafOps(const xla::HloInstruction::InstructionVector& ops) { - int64 count = 0; - absl::flat_hash_set op_set; - for (auto op : ops) { - if (!op_set.contains(op) && op->opcode() != HloOpcode::kConstant) { - count++; - op_set.insert(op); + + // Checking whether it is safe to move a boundary when visited through a + // dependent already considered for moving. + bool IsSafeToMoveBoundary(const Boundary& next_boundary) { + int64 next_boundary_count = + (next_boundary.IsInsideBranch()) + ? 
next_boundary.operands()[0]->user_count() + : CountNonLeafOps(next_boundary.operands()[0]->operands()); + if (next_boundary_count <= 1) { + // If boundary has only a single or no dependent, safe to move. + return true; + } else { + if (!ContainsKey(visited_count_, next_boundary.operands()[0])) { + VLOG(2) << "Skip next boundary " << next_boundary.ToString() << "\n" + << " because it has multiple dependents: " + << next_boundary_count << "\n"; + visited_count_[next_boundary.operands()[0]] = 1; + new_boundaries_.push_back(next_boundary); + } else { + auto pos = std::find(new_boundaries_.begin(), new_boundaries_.end(), + next_boundary); + if (pos != new_boundaries_.end() || + next_boundary.operands().size() == 1) { + int count = ++visited_count_[next_boundary.operands()[0]]; + if (count == next_boundary_count) { + VLOG(2) << "Recovering next boundary " << next_boundary.ToString() + << "\n" + << " because all of its dependents have been visited: " + << next_boundary_count << "\n"; + visited_count_.erase(next_boundary.operands()[0]); + if (pos != new_boundaries_.end()) { + new_boundaries_.erase(pos); + } + return true; + } + } else { + VLOG(2) << "Skip incompatible multi-dependent boundary: " + << next_boundary.ToString() << ":" << next_boundary_count + << "\n"; + } } } - return count; + return false; } // This function is reused both for moving the boundary outside or into a // conditional. As the result, the readability is somewhat compromised. @@ -846,7 +1012,8 @@ class GroupConnectedBoundaries { VLOG(2) << "visiting boundary " << b.ToString() << "\n"; if ((b.IsOutsideBranch() || InstructionWithinBranchIdentical( b.operands(), is_layout_sensitive_)) && - WorthHoisting(b.operands()[0])) { + IsSafeToMoveBoundary(b) && + WorthHoisting(b.operands()[0], b.IsInsideBranch())) { connected_boundaries_.push_back(b); VLOG(2) << "boundary can be moved\n"; int64 operand_count = (b.IsInsideBranch()) @@ -854,38 +1021,25 @@ class GroupConnectedBoundaries { : b.operands()[0]->users().size(); for (int i = 0; i < operand_count; i++) { Boundary next_boundary = GetNextBoundary(b, i); - int64 next_boundary_count = - (next_boundary.IsInsideBranch()) - ? next_boundary.operands()[0]->user_count() - : CountNonLeafOps(next_boundary.operands()[0]->operands()); - // only consider adding an exclusive producor into the same group. - if (next_boundary_count == 1) { - VLOG(2) << "Add operand " << i << " to visit later\n"; - visitor.AddToWorkList(next_boundary); - } else { - VLOG(2) << "Next boundary " << i - << " has multiple uses: " << next_boundary_count << "\n"; - if (!ContainsKey(visited_, next_boundary.operands()[0])) { - visited_.insert(next_boundary.operands()[0]); - new_boundaries_.push_back(next_boundary); - } - } + VLOG(2) << "Add operand/user " << i << " to visit later\n"; + visitor.AddToWorkList(next_boundary); } } else { VLOG(2) << "boundary cannot be moved\n"; - visited_.insert(b.operands()[0]); + visited_count_[b.operands()[0]] = 1; new_boundaries_.push_back(b); } } } - std::vector BoundariesToMoveInOrOut(const Boundary& b) { + std::vector BoundariesToMoveInOrOut(HloInstruction* conditional, + const Boundary& b) { // At the beginning of optimization, a conditional itself is added to a // worklist. Here the conditional is expanded into two sets of boundaries: // the first set contains the boundary that is inside branches and // contains the root of all branches; the second set of boundaries // contains all the users of the conditional. 
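    // Illustrative sketch of that expansion (not the code in this patch; it
    // assumes Boundary::Position also provides kOutsideBranch, as
    // IsOutsideBranch() suggests, and it omits the worklist bookkeeping):
    //
    //   std::vector<Boundary> seeds;
    //   Boundary roots(Boundary::Position::kInsideBranch);
    //   for (int i = 0; i < conditional->branch_count(); ++i) {
    //     roots.mutable_operands().push_back(
    //         conditional->branch_computation(i)->root_instruction());
    //   }
    //   seeds.push_back(roots);  // one inside-branch seed holding all roots
    //   for (HloInstruction* user : conditional->users()) {
    //     Boundary b(Boundary::Position::kOutsideBranch);
    //     b.mutable_operands().push_back(user);
    //     seeds.push_back(b);    // one outside-branch seed per user
    //   }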
HloInstruction* inst = b.operands()[0]; - if (inst->opcode() == HloOpcode::kConditional) { + if (inst == conditional) { int branch_count = inst->branch_count(); // Add conditional roots as a new boundary to visit. Boundary boundary_in(Boundary::Position::kInsideBranch); @@ -914,9 +1068,12 @@ class GroupConnectedBoundaries { ConditionalCodeMotion::Decision ConditionalCodeMotion::ConsiderCodeMotion( HloInstruction* conditional, const Boundary& cur_boundary, - std::vector& to_move, std::vector& new_boundaries) { - GroupConnectedBoundaries connect(conditional, is_layout_sensitive_); - auto move_in_or_out = connect.BoundariesToMoveInOrOut(cur_boundary); + std::vector& to_move, std::vector& new_boundaries, + absl::flat_hash_map& visited_count) { + GroupConnectedBoundaries connect(conditional, is_layout_sensitive_, + visited_count); + auto move_in_or_out = + connect.BoundariesToMoveInOrOut(conditional, cur_boundary); if (!move_in_or_out.empty()) { auto benefit = connect.BenefitForMovingBoundaries(move_in_or_out); VLOG(2) << "benefit of moving in or out " @@ -929,16 +1086,37 @@ ConditionalCodeMotion::Decision ConditionalCodeMotion::ConsiderCodeMotion( // at the first entry of the sequence is sufficient to know which // direction the move is intended. to_move = move_in_or_out; - return to_move[0].IsInsideBranch() ? Decision::kMoveOutOfBranch - : Decision::kMoveIntoBranch; + return Decision(to_move[0].IsInsideBranch() + ? Decision::Direction::kMoveOutOfBranch + : Decision::Direction::kMoveIntoBranch, + benefit); + } else { + connect.clear_recently_visited(); } } else { connect.AddNewBoundaries(new_boundaries); } - return ConditionalCodeMotion::Decision::kNoChange; + return Decision(Decision::Direction::kNoChange, 0); } StatusOr ConditionalCodeMotion::Run(HloModule* module) { + VLOG(2) << "Begin a new pass of conditional code motion optimization.\n"; + // Use to support debugging of optimization, by disabling the opt after it has + // been applied a pre-determined times (to isolate impact of transformations). + if (!ConsumeFuel("conditional_code_motion", [&] { + return "Skipping conditional opt after allowed limit reaching 0.\n"; + })) { + return false; + } + bool changed = false; + bool cleanup_changed = false; + { + HloPassPipeline subpipeline("before_conditional_code_motion"); + subpipeline.AddPass(/*is_layout_sensitive=*/is_layout_sensitive_); + subpipeline.AddPass(); + TF_ASSIGN_OR_RETURN(auto cleanup_changed_now, subpipeline.Run(module)); + cleanup_changed |= cleanup_changed_now; + } // Gather all the conditional ops in the module ahead of time, to avoid // potential complications of modifying the code that affecting traversal. std::vector conditional_ops; @@ -956,12 +1134,26 @@ StatusOr ConditionalCodeMotion::Run(HloModule* module) { conditional_computations[branch_i] = 0; } } - conditional_ops.push_back(instr); + if (instr->shape().IsTuple()) { + bool can_change_tuple_shape = true; + for (auto user : instr->users()) { + VLOG(2) << "user is : " << user->ToString() << "\n"; + if (user->opcode() != HloOpcode::kGetTupleElement) { + can_change_tuple_shape = false; + } + } + if (can_change_tuple_shape) { + conditional_ops.push_back(instr); + } + } else { + conditional_ops.push_back(instr); + } } } } - bool changed = false; + // Use to collect mappings between cloned instructions. 
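    // (Each branch clone below is created with Clone("clone", &clone_context),
    // so clone_context records the original-to-clone instruction mapping that
    // update_boundary later queries through FindInstruction.)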
+ HloCloneContext clone_context(module); for (HloInstruction* conditional : conditional_ops) { int branch_count = conditional->branch_count(); // check for shared conditional computations @@ -975,7 +1167,13 @@ StatusOr ConditionalCodeMotion::Run(HloModule* module) { } // Boundaries to move out or to move into the branches. - std::vector to_move_out, to_move_in, new_boundaries; + std::vector > to_move_out, to_move_in; + std::vector > new_boundaries_for_moveout; + std::vector > new_boundaries_for_movein; + // Number of times each instruction has been visited for moving. + absl::flat_hash_map visited_count; + int benefit_move_out = 0, benefit_move_in = 0; + Decision::Direction final_d = Decision::Direction::kNoChange; // The conditional is moved into a worklist as the seed (starting point). // The conditional will be expanded into multiple seeds (starting points), // its roots and its users, when it is visited by GroupConnectedBoundaries. @@ -983,76 +1181,130 @@ StatusOr ConditionalCodeMotion::Run(HloModule* module) { // so that the other seeding boundaries can be visited in turn. BoundaryVisitor visitor(conditional); VLOG(2) << "Analyzing conditional:" << conditional->ToString() << "\n"; - ConditionalCodeMotion::Decision d = Decision::kNoChange; - // The following loop breaks out as soon as a decision to modify the - // conditional is reached --- irrespective of whether visitor is empty. - while (d == Decision::kNoChange && visitor.HasNextBoundary()) { + // Try visit all the boundaries, collect the analysis results, and save + // all the benefitical non-conflicting decisions. If two decisions conflict + // with each other, save the more benefitical one. + while (visitor.HasNextBoundary()) { std::vector to_move, next_boundary; Boundary boundary = visitor.PopNextBoundary(); VLOG(2) << "Analyzing boundary:" << boundary.ToString() << "\n"; - d = ConsiderCodeMotion(conditional, boundary, to_move, next_boundary); - if (d != Decision::kNoChange && conditional_is_shared) { - for (int i = 0; i < branch_count; ++i) { - HloComputation* branch_i = conditional->branch_computation(i); - if (conditional_computations[branch_i] > 0) { - // Cloning is absolutely needed if the computation is shared by - // different branches, but the cloning can be potentially avoided - // if the sharing is only among branches of the same conditional. - // If cloning these branches causes a problem due to space issues, - // a fix can pass a vector of unique branches to the actual - // transformations, as an alternative representation of the - // conditional branches to be modified. Right now we assume the - // overhead of cloning is minimal since later stages of the compiler - // inline all the computations anyway. 
- HloComputation* clone_i = - conditional->parent()->parent()->AddEmbeddedComputation( - branch_i->Clone()); - conditional->set_branch_computation(i, clone_i); - conditional_computations[branch_i]--; + auto d = ConsiderCodeMotion(conditional, boundary, to_move, next_boundary, + visited_count); + switch (d.GetDirection()) { + case Decision::Direction::kMoveOutOfBranch: + VLOG(2) << "Local Decision is move out of branch\n"; + to_move_out.push_back(to_move); + new_boundaries_for_moveout.push_back(next_boundary); + benefit_move_out += d.GetBenefit(); + if (benefit_move_out >= benefit_move_in) { + final_d = Decision::Direction::kMoveOutOfBranch; + VLOG(2) << "Current Decision is move out of branch (" + << to_move_out.size() << ")\n"; + } else { + VLOG(2) << "Current Decision remains move into branch\n"; } - } - to_move.clear(); - next_boundary.clear(); - VLOG(2) << "Cloned branches as needed: " << conditional->ToString() - << "\n"; - // Need to reanalyze the cloned code to generate correct result. - d = ConsiderCodeMotion(conditional, boundary, to_move, next_boundary); - } - switch (d) { - case Decision::kMoveOutOfBranch: - VLOG(2) << "Decision is move out of branch\n"; - to_move_out.insert(to_move_out.end(), to_move.begin(), to_move.end()); - new_boundaries.insert(new_boundaries.end(), next_boundary.begin(), - next_boundary.end()); break; - case Decision::kMoveIntoBranch: + case Decision::Direction::kMoveIntoBranch: VLOG(2) << "Decision is move into branch\n"; - to_move_in.insert(to_move_in.end(), to_move.begin(), to_move.end()); - new_boundaries.insert(new_boundaries.end(), next_boundary.begin(), - next_boundary.end()); + to_move_in.push_back(to_move); + new_boundaries_for_movein.push_back(next_boundary); + benefit_move_in += d.GetBenefit(); + if (benefit_move_out >= benefit_move_in) { + VLOG(2) << "Current Decision remains move out of branch\n"; + } else { + final_d = Decision::Direction::kMoveIntoBranch; + VLOG(2) << "Current Decision is move into branch (" + << to_move_in.size() << ")\n"; + } break; - case Decision::kNoChange: + case Decision::Direction::kNoChange: VLOG(2) << "Decision is no change\n"; for (const Boundary& b : next_boundary) { visitor.AddToWorkList(b); + VLOG(2) << "Adding new boundary to worklist:" << b.ToString() + << "\n"; } break; } } + // If modification is to be made, need to clone the shared branches. + if (final_d != Decision::Direction::kNoChange && conditional_is_shared) { + for (int i = 0; i < branch_count; ++i) { + HloComputation* branch_i = conditional->branch_computation(i); + if (conditional_computations[branch_i] > 0) { + // Cloning is absolutely needed if the computation is shared by + // different branches, but the cloning can be potentially avoided + // if the sharing is only among branches of the same conditional. + // If cloning these branches causes a problem due to space issues, + // a fix can pass a vector of unique branches to the actual + // transformations, as an alternative representation of the + // conditional branches to be modified. Right now we assume the + // overhead of cloning is minimal since later stages of the compiler + // inline all the computations anyway. + HloComputation* clone_i = + conditional->parent()->parent()->AddEmbeddedComputation( + branch_i->Clone("clone", &clone_context)); + conditional->set_branch_computation(i, clone_i); + conditional_computations[branch_i]--; + // Need to translate the analysis result to generate correct result. 
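          // Worked example (hypothetical names): if a saved boundary holds
          // {add.1 in on_true, add.2 in on_false} and branch i=0 is cloned,
          // the entry at index 0 must be rewritten to the clone of add.1 found
          // via clone_context, so that the later MoveInstructionOut call edits
          // the branch computation now attached to the conditional rather than
          // the shared original it no longer uses.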
+ auto update_boundary = [&](Boundary& boundary) { + auto cloned_instr = + clone_context.FindInstruction(boundary.operands()[i]); + CHECK(cloned_instr != nullptr); + VLOG(2) << "boundary before cloning:" << boundary.operands()[i] + << "\n"; + boundary.mutable_operands()[i] = cloned_instr; + VLOG(2) << "boundary after cloning:" << boundary.operands()[i] + << "\n"; + }; + // Only boundaries to move out need to be updated. + if (final_d == Decision::Direction::kMoveOutOfBranch) { + for (int i = 0; i < to_move_out.size(); ++i) { + std::vector& m = to_move_out[i]; + std::for_each(m.begin(), m.end(), update_boundary); + } + for (int i = 0; i < new_boundaries_for_moveout.size(); ++i) { + std::vector& m = new_boundaries_for_moveout[i]; + std::for_each(m.begin(), m.end(), update_boundary); + } + } + } + } + VLOG(2) << "Cloned branches as needed: " << conditional->ToString() + << "\n"; + } // At most one of to_move_out or to_move_in can be non-empty, since there is // only one optimization decision. - if (!to_move_out.empty()) { - TF_ASSIGN_OR_RETURN( - bool result, - MoveInstructionOut(conditional, to_move_out, new_boundaries)); - VLOG(2) << "moving out result:" << result << "\n"; - changed |= result; - } else if (!to_move_in.empty()) { - TF_ASSIGN_OR_RETURN( - bool result, - MoveInstructionIn(conditional, to_move_in, new_boundaries)); - VLOG(2) << "moving in result:" << result << "\n"; - changed |= result; + if (final_d == Decision::Direction::kMoveOutOfBranch) { + CHECK(to_move_out.size() == new_boundaries_for_moveout.size()); + for (int i = 0; i < to_move_out.size(); ++i) { + TF_ASSIGN_OR_RETURN(bool result, + MoveInstructionOut(conditional, to_move_out[i], + new_boundaries_for_moveout[i])); + changed |= result; + } + VLOG(2) << "Done moving out of branches " << to_move_out.size() + << " times. \n"; + if (!ConsumeFuel("conditional_code_motion", [&] { + return "Skipping conditional opt after allowed limit reaching 0.\n"; + })) { + break; + } + } else if (final_d == Decision::Direction::kMoveIntoBranch) { + CHECK(to_move_in.size() == new_boundaries_for_movein.size()); + for (int i = 0; i < to_move_in.size(); ++i) { + TF_ASSIGN_OR_RETURN(bool result, + MoveInstructionIn(conditional, to_move_in[i], + new_boundaries_for_movein[i])); + changed |= result; + } + VLOG(2) << "Done moving into branches " << to_move_in.size() + << " times. 
\n"; + if (!ConsumeFuel("conditional_code_motion", [&] { + return "Skipping conditional opt after allowed limit reaching 0.\n"; + })) { + break; + } } else if (pursue_full_conditional_code_motion_ && !conditional_is_shared) { // Invoke special handling for convert rematerialization/hoisting // We need to make sure no sharing is present in the branches because no @@ -1061,17 +1313,30 @@ StatusOr ConditionalCodeMotion::Run(HloModule* module) { TF_ASSIGN_OR_RETURN( bool convert_result, ConvertSpecialMove(conditional, is_layout_sensitive_)); + if (convert_result) { + VLOG(2) << "Done special moving of convert\n"; + if (!ConsumeFuel("conditional_code_motion", [&] { + return "Skipping conditional opt after allowed limit reaching " + "0.\n"; + })) { + break; + } + } changed |= convert_result; } } if (changed) { HloPassPipeline subpipeline( "after_conditional_code_motion_after_convert_hoisting"); + VLOG(2) << "starting after motion passes: DCE\n"; subpipeline.AddPass(); subpipeline.AddPass(); subpipeline.AddPass(); - TF_ASSIGN_OR_RETURN(bool cleanup_changed, subpipeline.Run(module)); - changed |= cleanup_changed; + TF_ASSIGN_OR_RETURN(auto cleanup_changed_now, subpipeline.Run(module)); + cleanup_changed |= cleanup_changed_now; + } + if (cleanup_changed) { + VLOG(2) << "subpipeline cleanup have modified code\n"; } return changed; } diff --git a/tensorflow/compiler/xla/service/conditional_code_motion.h b/tensorflow/compiler/xla/service/conditional_code_motion.h index 68a2aa58235..eaec91cfb00 100644 --- a/tensorflow/compiler/xla/service/conditional_code_motion.h +++ b/tensorflow/compiler/xla/service/conditional_code_motion.h @@ -52,6 +52,9 @@ class Boundary { } return res; } + bool operator==(const Boundary& that) { + return ContainersEqual(operands_, that.operands_); + } private: // Boundary instructions in the conditional branches, one from each branch @@ -78,13 +81,30 @@ class ConditionalCodeMotion : public HloModulePass { StatusOr Run(HloModule* module) override; // Optimization decision for each boundary of the conditional instruction. - enum class Decision { kMoveOutOfBranch, kMoveIntoBranch, kNoChange }; + class Decision { + public: + enum class Direction : uint8 { + kMoveOutOfBranch, + kMoveIntoBranch, + kNoChange + }; + + public: + Decision(Direction direction, int benefit) + : direction_(direction), benefit_(benefit) {} + Direction GetDirection() const { return direction_; } + int GetBenefit() const { return benefit_; } + + private: + Direction direction_; + int benefit_; + }; // If the optimization decision is NO_CHANGE, new_boundary is set to nullptr; // otherwise, it is set to the new boundary after proposed optimization. 
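  // A minimal usage sketch for the Decision object (illustrative only; the
  // real accumulation logic is in ConditionalCodeMotion::Run in the .cc change
  // above):
  //
  //   Decision d = ConsiderCodeMotion(conditional, boundary, to_move,
  //                                   new_boundaries, visited_count);
  //   if (d.GetDirection() == Decision::Direction::kMoveOutOfBranch) {
  //     benefit_move_out += d.GetBenefit();
  //   }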
- virtual Decision ConsiderCodeMotion(HloInstruction* conditional, - const Boundary& cur_boundary, - std::vector& to_move, - std::vector& new_boundaries); + virtual Decision ConsiderCodeMotion( + HloInstruction* conditional, const Boundary& cur_boundary, + std::vector& to_move, std::vector& new_boundaries, + absl::flat_hash_map& visited_count); private: const bool is_layout_sensitive_; diff --git a/tensorflow/compiler/xla/service/conditional_code_motion_test.cc b/tensorflow/compiler/xla/service/conditional_code_motion_test.cc index b91f3813980..3b40acf54e3 100644 --- a/tensorflow/compiler/xla/service/conditional_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/conditional_code_motion_test.cc @@ -78,6 +78,52 @@ ENTRY main { EXPECT_THAT(root, AllOf(op::Tuple(op::Convert(), op::GetTupleElement()))); } +TEST_F(ConditionalCodeMotionTest, VerifyConditionalAnalysisWithWhileTuple) { + absl::string_view hlo_string = + R"( +HloModule RemoveDotOpOut + + body { + %p_body = (f32[2], bf16[2], s32[]) parameter(0) + %val = f32[2] get-tuple-element(p_body), index=0 + %val2 = bf16[2] get-tuple-element(p_body), index=1 + %const = s32[] constant(-1) + ROOT root = (f32[2], bf16[], s32[]) tuple(%val, %val2, %const) + } + + condition { + %p_cond = (f32[2], bf16[2], s32[]) parameter(0) + %gte = s32[] get-tuple-element(%p_cond), index=2 + %const = s32[] constant(42) + ROOT result = pred[] compare(%gte, %const), direction=EQ + } + + on_true { + %arg_tuple.1 = f32[2] parameter(0) + %const = s32[] constant(42) + %add.8493 = f32[2] add(f32[2] %arg_tuple.1, f32[2] %arg_tuple.1) + %convert.2894 = bf16[2] convert(f32[2] %add.8493) + ROOT %tuple.1 = (f32[2], bf16[2], s32[]) tuple(%add.8493, %convert.2894, %const) + } + on_false { + %arg_tuple.1 = f32[2] parameter(0) + %const = s32[] constant(42) + %add.8493 = f32[2] add(f32[2] %arg_tuple.1, f32[2] %arg_tuple.1) + %convert.2894 = bf16[2] convert(f32[2] %add.8493) + %while_init = (f32[2], bf16[2], s32[]) tuple(%add.8493, %convert.2894, %const) + ROOT while = (f32[2], bf16[2], s32[]) while(%while_init), condition=condition, body=body + } + ENTRY main { + pred.1 = pred[] parameter(0) + arg_tuple.11 = f32[2] parameter(1) + ROOT conditional = (f32[2], bf16[2], s32[]) conditional(pred.1, arg_tuple.11, arg_tuple.11), true_computation=on_true, false_computation=on_false + } +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass(true, true); + ASSERT_FALSE(pass.Run(&*module).ValueOrDie()); +} + TEST_F(ConditionalCodeMotionTest, MoveConvertOutConditionalRoot) { absl::string_view hlo_string = R"( @@ -158,6 +204,44 @@ ENTRY main { EXPECT_THAT(root, AllOf(op::Tuple(op::Convert()))); } +TEST_F(ConditionalCodeMotionTest, ConditionalShapeNotMutable) { + absl::string_view hlo_string = + R"( +HloModule RemoveDotOpOut + +on_true { + %arg_tuple.1 = (f32[93184,4]{1,0}) parameter(0) + %get-tuple-element.1 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.1), index=0 + %reshape.8493 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.1) + %add.8493 = f32[2,512,364]{2,1,0} add(f32[2,512,364]{2,1,0} %reshape.8493, f32[2,512,364]{2,1,0} %reshape.8493) + %convert.2894 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %add.8493) + ROOT %tuple.1 = ( bf16[2,512,364]{2,1,0}) tuple(%convert.2894) +} + +on_false { + %arg_tuple.2 = (f32[93184,4]{1,0}) parameter(0) + %get-tuple-element.3 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.2), index=0 + %reshape.9717 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} 
%get-tuple-element.3) + %add.8493 = f32[2,512,364]{2,1,0} add(f32[2,512,364]{2,1,0} %reshape.9717, f32[2,512,364]{2,1,0} %reshape.9717) + %sub.8493 = f32[2,512,364]{2,1,0} subtract(f32[2,512,364]{2,1,0} %add.8493, f32[2,512,364]{2,1,0} %reshape.9717) + %convert.3604 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.9717), metadata={op_type="Cast" op_name="gradients/Cast_125_grad/Cast"} + ROOT %tuple.2 = (bf16[2,512,364]{2,1,0}) tuple(%convert.3604) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + arg_tuple.11 = (f32[93184,4]{1,0}) parameter(1) + arg_tuple.22 = (f32[93184,4]{1,0}) parameter(2) + conditional = (bf16[2,512,364]{2,1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false + get-first-index = bf16[2,512,364]{2,1,0} get-tuple-element(conditional), index=0 + ROOT result = (bf16[2,512,364]{2,1,0}, (bf16[2,512,364]{2,1,0})) tuple(get-first-index, conditional) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass(true, true); + ASSERT_FALSE(pass.Run(&*module).ValueOrDie()); +} + TEST_F(ConditionalCodeMotionTest, MoveConvertOut) { absl::string_view hlo_string = R"( @@ -196,17 +280,16 @@ ENTRY main { const HloInstruction* conditional = FindInstruction(module.get(), "conditional"); const HloComputation* on_true = conditional->branch_computation(0); - ASSERT_EQ(on_true->instruction_count(), 2); + ASSERT_EQ(on_true->instruction_count(), 1); const HloComputation* on_false = conditional->branch_computation(1); - ASSERT_EQ(on_false->instruction_count(), 2); + ASSERT_EQ(on_false->instruction_count(), 1); HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_THAT( root, - AllOf(op::Tuple(op::Add(op::Convert(op::Reshape(op::GetTupleElement( - op::GetTupleElement(op::Conditional())))), - op::Convert(op::Reshape(op::GetTupleElement( - op::GetTupleElement(op::Conditional())))))))); + AllOf(op::Tuple(op::Add( + op::Convert(op::Reshape(op::GetTupleElement(op::Conditional()))), + op::Convert(op::Reshape(op::GetTupleElement(op::Conditional()))))))); } TEST_F(ConditionalCodeMotionTest, UserShareOperandCannotBeMoved) { @@ -297,7 +380,7 @@ on_false { get-tuple-element.2 = f32[] get-tuple-element(arg_tuple.2), index=0 constant.3 = f32[] constant(1) constant.4 = f32[] constant(2) - add.4 = f32[] add(get-tuple-element.2, constant.3) + add.4 = f32[] add(constant.4, constant.3) add.5 = f32[] add(get-tuple-element.2, constant.4) add.6 = f32[] add(add.4, add.5) ROOT tuple.4 = (f32[]) tuple(add.6) @@ -322,7 +405,7 @@ ENTRY main { const HloComputation* on_true = conditional->branch_computation(0); ASSERT_EQ(on_true->instruction_count(), 1); const HloComputation* on_false = conditional->branch_computation(1); - ASSERT_EQ(on_false->instruction_count(), 1); + ASSERT_EQ(on_false->instruction_count(), 3); HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_THAT( @@ -505,6 +588,7 @@ ENTRY main { pred.1 = pred[] parameter(0) arg_tuple.3 = (bf16[2,54,168,128], bf16[2,52,168,128]) parameter(1) arg_tuple.4 = (bf16[2,86,104,128], bf16[2,84,104,128]) parameter(2) + arg_tuple.5 = f32[3,3,128,128] parameter(3) conditional = (f32[3,3,128,128]) conditional(pred.1, arg_tuple.3, arg_tuple.4), true_computation=on_true, false_computation=on_false @@ -519,6 +603,7 @@ ENTRY main { ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); const HloInstruction* conditional = FindInstruction(module.get(), "conditional"); + CHECK(conditional != nullptr); const HloComputation* 
on_true = conditional->branch_computation(0); ASSERT_EQ(on_true->instruction_count(), 5); const HloComputation* on_false = conditional->branch_computation(1); @@ -537,6 +622,89 @@ ENTRY main { op::AllReduce(op::GetTupleElement(op::Conditional()))))))); } +TEST_F(ConditionalCodeMotionTest, DoNotMoveAllReduceIn) { + absl::string_view hlo_string = + R"( +HloModule RemoveIdenticalInstruction + +%add.64 (x.139: bf16[], y.139: bf16[]) -> bf16[] { + %x.139 = bf16[]{:T(512)} parameter(0) + %y.139 = bf16[]{:T(512)} parameter(1) + ROOT %add.44073 = bf16[]{:T(512)} add(bf16[]{:T(512)} %x.139, bf16[]{:T(512)} %y.139) +} + +%add.181 (x.256: bf16[], y.256: bf16[]) -> bf16[] { + %x.256 = bf16[]{:T(512)} parameter(0) + %y.256 = bf16[]{:T(512)} parameter(1) + ROOT %add.44842 = bf16[]{:T(512)} add(bf16[]{:T(512)} %x.256, bf16[]{:T(512)} %y.256) +} + +on_true { + arg_tuple.1 = (bf16[2,54,168,128], bf16[2,52,168,128]) parameter(0) + get-tuple-element.11 = bf16[2,54,168,128] get-tuple-element(arg_tuple.1), index=0 + get-tuple-element.12 = bf16[2,52,168,128] get-tuple-element(arg_tuple.1), index=1 + convolution.1 = bf16[3,3,128,128] convolution(bf16[2,54,168,128] + get-tuple-element.11, bf16[2,52,168,128] + get-tuple-element.12), window={size=52x168 pad=0_0x1_1}, + dim_labels=f01b_i01o->01bf + add.1 = bf16[3,3,128,128] add(bf16[3,3,128,128] convolution.1, bf16[3,3,128,128] convolution.1) + ROOT tuple.1 = (bf16[3,3,128,128]) tuple(add.1) +} + +on_false { + arg_tuple.2 = (bf16[2,86,104,128], bf16[2,84,104,128]) parameter(0) + get-tuple-element.21 = bf16[2,86,104,128] + get-tuple-element(arg_tuple.2), index=0 + get-tuple-element.22 = bf16[2,84,104,128] + get-tuple-element(arg_tuple.2), index=1 + convolution.2 = bf16[3,3,128,128] + convolution(bf16[2,86,104,128] get-tuple-element.21, bf16[2,84,104,128] + get-tuple-element.22), window={size=84x104 pad=0_0x1_1}, + dim_labels=f01b_i01o->01bf + add.2 = bf16[3,3,128,128] add(bf16[3,3,128,128] convolution.2, bf16[3,3,128,128] convolution.2) + ROOT tuple.2 = (bf16[3,3,128,128]) tuple(add.2) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + arg_tuple.3 = (bf16[2,54,168,128], bf16[2,52,168,128]) parameter(1) + arg_tuple.4 = (bf16[2,86,104,128], bf16[2,84,104,128]) parameter(2) + arg_tuple.5 = f32[3,3,128,128] parameter(3) + conditional = (bf16[3,3,128,128]) + conditional(pred.1, arg_tuple.3, arg_tuple.4), true_computation=on_true, + false_computation=on_false + get-first-index = bf16[3,3,128,128] get-tuple-element(conditional), index=0 + all-reduce.2 = bf16[3,3,128,128] + all-reduce(bf16[3,3,128,128] %get-first-index), + channel_id=485, replica_groups={{0,1}}, use_global_device_ids=true, + to_apply=%add.181, metadata={op_type="Conv2DBackpropFilter" + op_name="gradients/resnet50/conv2d_22/Conv2D_grad/Conv2DBackpropFilter"} + convert.2 = f32[3,3,128,128] + convert(bf16[3,3,128,128] %all-reduce.2), + metadata={op_type="Cast" op_name="Cast_15"} + ROOT result = (f32[3,3,128,128]) tuple(convert.2) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass(true, true); + ASSERT_FALSE(pass.Run(&*module).ValueOrDie()); + const HloInstruction* conditional = + FindInstruction(module.get(), "conditional"); + CHECK(conditional != nullptr); + const HloComputation* on_true = conditional->branch_computation(0); + ASSERT_EQ(on_true->instruction_count(), 6); + const HloComputation* on_false = conditional->branch_computation(1); + ASSERT_EQ(on_false->instruction_count(), 6); + + // Checks if conditional shape has changed. 
+ ASSERT_TRUE(ShapeUtil::Compatible( + conditional->shape(), ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape( + BF16, {3, 3, 128, 128})}))); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Tuple(op::Convert(op::AllReduce( + op::GetTupleElement(op::Conditional())))))); +} + TEST_F(ConditionalCodeMotionTest, MovePowOpIn) { absl::string_view hlo_string = R"( @@ -581,7 +749,47 @@ ENTRY main { EXPECT_THAT(root, AllOf(op::GetTupleElement(op::Conditional()))); } -TEST_F(ConditionalCodeMotionTest, MovePowInWithSharedBranch) { +TEST_F(ConditionalCodeMotionTest, MoveInWithMultipleGTE) { + absl::string_view hlo_string = + R"( +HloModule RemoveIdenticalInstruction + +on_true { + arg_tuple.1 = (f32[10]) parameter(0) + get-tuple-element.1 = f32[10] get-tuple-element(arg_tuple.1), index=0 + add.1 = f32[10] add(get-tuple-element.1, get-tuple-element.1) + ROOT tuple.3 = (f32[10]) tuple(add.1) +} + +on_false { + arg_tuple.2 = (f32[10]) parameter(0) + get-tuple-element.2 = f32[10] get-tuple-element(arg_tuple.2), index=0 + mul.1 = f32[10] multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple.4 = (f32[10]) tuple(mul.1) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + tuple.1 = (f32[10]) parameter(1) + tuple.2 = (f32[10]) parameter(2) + conditional = (f32[10]) + conditional(pred.1, tuple.1, tuple.2), true_computation=on_true, + false_computation=on_false + get-first-index = f32[10] get-tuple-element(conditional), index=0 + get-first-index.2 = f32[10] get-tuple-element(conditional), index=0 + pow.1 = f32[10] power(get-first-index, get-first-index.2) + ROOT tuple.3 = (f32[10], f32[10]) tuple(pow.1, get-first-index.2) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass(true, true); + ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Tuple(op::GetTupleElement(op::Conditional()), + op::GetTupleElement(op::Conditional()))); +} + +TEST_F(ConditionalCodeMotionTest, MoveOutWithSharedBranch) { absl::string_view hlo_string = R"( HloModule RemoveIdenticalInstruction @@ -610,12 +818,16 @@ ENTRY main { const HloInstruction* conditional = FindInstruction(module.get(), "conditional"); const HloComputation* on_true = conditional->branch_computation(0); - ASSERT_EQ(on_true->instruction_count(), 5); + ASSERT_EQ(on_true->instruction_count(), 1); const HloComputation* on_false = conditional->branch_computation(1); - ASSERT_EQ(on_false->instruction_count(), 5); + ASSERT_EQ(on_false->instruction_count(), 1); HloInstruction* root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, AllOf(op::GetTupleElement(op::Conditional()))); + EXPECT_THAT( + root, AllOf(op::Power(op::Add(op::GetTupleElement(op::Conditional()), + op::GetTupleElement(op::Conditional())), + op::Add(op::GetTupleElement(op::Conditional()), + op::GetTupleElement(op::Conditional()))))); } TEST_F(ConditionalCodeMotionTest, MovePowInWithNonTupleRoot) { @@ -728,6 +940,257 @@ ENTRY main { EXPECT_THAT(root, AllOf(op::GetTupleElement(op::Conditional()))); } +TEST_F(ConditionalCodeMotionTest, MoveCopyInBranch) { + absl::string_view hlo_string = + R"( +HloModule RemoveIdenticalInstruction + +branch1 { + arg_tuple.1 = (s32[], f32[10,3]{0,1}) parameter(0) + constant.1 = s32[] constant(4) + get-tuple-element.1 = s32[] get-tuple-element(arg_tuple.1), index=0 + add.1 = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = f32[10,3]{0,1} 
get-tuple-element(arg_tuple.1), index=1 + slice.1 = f32[4,3]{0,1} slice(get-tuple-element.2), + slice={[0:4:1], [0:3:1]} + constant.2 = f32[] constant(0.0) + ROOT tuple.1 = (f32[4,3]{0,1}, s32[],f32[]) tuple(slice.1, add.1, constant.2) +} + +branch2 { + arg_tuple.2 = (s32[], f32[4,3]{1,0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(arg_tuple.2), index=0 + copy.1 = s32[] copy(get-tuple-element.3) + get-tuple-element.4 = f32[4,3]{1,0} get-tuple-element(arg_tuple.2), index=1 + copy.2 = f32[4,3]{0,1} copy(get-tuple-element.4) + constant.2 = f32[] constant(0.0) + ROOT tuple.2 = (f32[4,3]{0,1}, s32[], f32[]) tuple(copy.2, copy.1, constant.2) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + tuple.3 = (s32[], f32[10,3]{0,1}) parameter(1) + tuple.4 = (s32[], f32[4,3]{1,0}) parameter(2) + conditional = (f32[4,3]{0,1}, s32[], f32[]) + conditional(pred.1, tuple.3, tuple.4), true_computation=branch1, + false_computation=branch2 + get-zero-index = f32[4,3]{0,1} get-tuple-element(conditional), index=0 + get-first-index = s32[] get-tuple-element(conditional), index=1 + get-second-index = f32[] get-tuple-element(conditional), index=2 + copy.3 = f32[4,3]{1,0} copy(get-zero-index) + ROOT tuple.5 = (f32[4,3]{0,1}, s32[], f32[]) tuple(copy.3, get-first-index, + get-second-index) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass(true, true); + ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); + VLOG(1) << module->ToString(); + + const HloInstruction* conditional = + FindInstruction(module.get(), "conditional"); + const HloComputation* on_true = conditional->branch_computation(0); + ASSERT_EQ(on_true->instruction_count(), 9); + const HloComputation* on_false = conditional->branch_computation(1); + ASSERT_EQ(on_false->instruction_count(), 8); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + AllOf(op::Tuple(op::GetTupleElement(op::Conditional(), 2), + op::GetTupleElement(op::Conditional(), 0), + op::GetTupleElement(op::Conditional(), 1)))); +} + +TEST_F(ConditionalCodeMotionTest, MoveReplicatedTupleEntryOut) { + absl::string_view hlo_string = + R"( +HloModule RemoveIdenticalInstruction + +%add.64 (x.139: bf16[], y.139: bf16[]) -> bf16[] { + %x.139 = bf16[]{:T(512)} parameter(0) + %y.139 = bf16[]{:T(512)} parameter(1) + ROOT %add.44073 = bf16[]{:T(512)} add(bf16[]{:T(512)} %x.139, bf16[]{:T(512)} %y.139) +} + +%add.181 (x.256: bf16[], y.256: bf16[]) -> bf16[] { + %x.256 = bf16[]{:T(512)} parameter(0) + %y.256 = bf16[]{:T(512)} parameter(1) + ROOT %add.44842 = bf16[]{:T(512)} add(bf16[]{:T(512)} %x.256, bf16[]{:T(512)} %y.256) +} + +on_true { + arg_tuple.1 = (bf16[2,54,168,128], bf16[2,52,168,128]) parameter(0) + get-tuple-element.11 = bf16[2,54,168,128] get-tuple-element(arg_tuple.1), index=0 + get-tuple-element.12 = bf16[2,52,168,128] get-tuple-element(arg_tuple.1), index=1 + convolution.1 = bf16[3,3,128,128] convolution(bf16[2,54,168,128] + get-tuple-element.11, bf16[2,52,168,128] + get-tuple-element.12), window={size=52x168 pad=0_0x1_1}, + dim_labels=f01b_i01o->01bf + all-reduce.1 = bf16[3,3,128,128] + all-reduce(bf16[3,3,128,128] %convolution.1), + channel_id=188, replica_groups={{0,1}}, use_global_device_ids=true, + to_apply=%add.64 + convert.1 = f32[3,3,128,128] convert(bf16[3,3,128,128] %all-reduce.1) + all-reduce.3 = bf16[3,3,128,128] + all-reduce(bf16[3,3,128,128] %convolution.1), + channel_id=188, replica_groups={{0,1}}, use_global_device_ids=true, + to_apply=%add.64 + convert.3 = 
f32[3,3,128,128] convert(bf16[3,3,128,128] %all-reduce.3) + ROOT tuple.1 = (f32[3,3,128,128], f32[3,3,128,128]) tuple(convert.1, convert.3) +} + +on_false { + arg_tuple.2 = (bf16[2,86,104,128], bf16[2,84,104,128]) parameter(0) + get-tuple-element.21 = bf16[2,86,104,128] + get-tuple-element(arg_tuple.2), index=0 + get-tuple-element.22 = bf16[2,84,104,128] + get-tuple-element(arg_tuple.2), index=1 + convolution.2 = bf16[3,3,128,128] + convolution(bf16[2,86,104,128] get-tuple-element.21, bf16[2,84,104,128] + get-tuple-element.22), window={size=84x104 pad=0_0x1_1}, + dim_labels=f01b_i01o->01bf + all-reduce.2 = bf16[3,3,128,128] + all-reduce(bf16[3,3,128,128] %convolution.2), + channel_id=485, replica_groups={{0,1}}, use_global_device_ids=true, + to_apply=%add.181 + convert.2 = f32[3,3,128,128] + convert(bf16[3,3,128,128] %all-reduce.2) + ROOT tuple.2 = (f32[3,3,128,128], f32[3,3,128,128]) tuple(convert.2, convert.2) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + arg_tuple.3 = (bf16[2,54,168,128], bf16[2,52,168,128]) parameter(1) + arg_tuple.4 = (bf16[2,86,104,128], bf16[2,84,104,128]) parameter(2) + conditional = (f32[3,3,128,128], f32[3,3,128,128]) + conditional(pred.1, arg_tuple.3, arg_tuple.4), true_computation=on_true, + false_computation=on_false + get-first-index = f32[3,3,128,128] + get-tuple-element(conditional), index=0 + add.1 = f32[3,3,128,128] add(f32[3,3,128,128] get-first-index, f32[3,3,128,128] get-first-index) + ROOT result = (f32[3,3,128,128]) tuple(add.1) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass(true, true); + ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); + const HloInstruction* conditional = + FindInstruction(module.get(), "conditional"); + const HloComputation* on_true = conditional->branch_computation(0); + ASSERT_EQ(on_true->instruction_count(), 5); + const HloComputation* on_false = conditional->branch_computation(1); + ASSERT_EQ(on_false->instruction_count(), 5); + + // Checks if conditional shape has changed. 
+ ASSERT_TRUE(ShapeUtil::Compatible( + conditional->shape(), ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape( + BF16, {3, 3, 128, 128})}))); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf(op::Tuple(op::Add( + op::Convert(op::AllReduce(op::GetTupleElement(op::Conditional()))), + op::Convert( + op::AllReduce(op::GetTupleElement(op::Conditional()))))))); +} + +TEST_F(ConditionalCodeMotionTest, DoNotMoveWithExtraOperand) { + absl::string_view hlo_string = + R"( +HloModule RemoveIdenticalInstruction + +branch { + arg.1 = f32[10] parameter(0) + ROOT add.1 = f32[10] add(arg.1, arg.1) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + tuple.1 = f32[10] parameter(1) + tuple.2 = f32[10] parameter(2) + conditional = f32[10] + conditional(pred.1, tuple.1, tuple.2), true_computation=branch, + false_computation=branch + ROOT pow.1 = f32[10] power(conditional, tuple.2) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass(true, true); + ASSERT_FALSE(pass.Run(&*module).ValueOrDie()); +} + +TEST_F(ConditionalCodeMotionTest, MultipleIndependentMoveIns) { + absl::string_view hlo_string = + R"( +HloModule FromNMT + +%add.31755 (x.139: f32[], y.139: bf16[]) -> bf16[] { + %x.139 = bf16[]{:T(512)} parameter(0) + %y.139 = bf16[]{:T(512)} parameter(1) + ROOT %add.44073 = bf16[]{:T(512)} add(bf16[]{:T(512)} %x.139, bf16[]{:T(512)} %y.139) +} + +%nmt.1 { + %wide_param.3 = (bf16[1024,4096]{1,0}, bf16[18,64,1024]{2,1,0}, s32[]) parameter(0) + %get-tuple-element.16525 = bf16[1024,4096]{1,0} get-tuple-element((bf16[1024,4096]{1,0}, bf16[18,64,1024]{2,1,0}, s32[]) %wide_param.3), index=0 + %get-tuple-element.16527 = bf16[18,64,1024]{2,1,0} get-tuple-element((bf16[1024,4096]{1,0}, bf16[18,64,1024]{2,1,0}, s32[]) %wide_param.3), index=1 + %get-tuple-element.16588 = s32[] get-tuple-element((bf16[1024,4096]{1,0}, bf16[18,64,1024]{2,1,0}, s32[]) %wide_param.3), index=2 + %add.3764 = s32[] add(s32[] %get-tuple-element.16588, s32[] %get-tuple-element.16588), metadata={op_type="Sub" op_name="sub"} + %reshape.9821 = s32[1]{0} reshape(s32[] %add.3764) + %reshape.9822 = s32[] reshape(s32[1]{0} %reshape.9821) + %constant.13127 = s32[] constant(0) + %dynamic-slice.1245 = bf16[1,64,1024]{2,1,0} dynamic-slice(bf16[18,64,1024]{2,1,0} %get-tuple-element.16527, s32[] %reshape.9822, s32[] %constant.13127, s32[] %constant.13127), dynamic_slice_sizes={1,64,1024} + %reshape.9825 = bf16[64,1024]{1,0} reshape(bf16[1,64,1024]{2,1,0} %dynamic-slice.1245), metadata={op_type="GatherV2" op_name="GatherV2"} + %logistic.814 = bf16[64,1024]{1,0} logistic(bf16[64,1024]{1,0} %reshape.9825), metadata={op_type="Sigmoid" op_name="Sigmoid"} + %multiply.4890 = bf16[64,1024]{1,0} multiply(bf16[64,1024]{1,0} %reshape.9825, bf16[64,1024]{1,0} %logistic.814), metadata={op_type="Mul" op_name="mul"} + %tanh.573 = bf16[64,1024]{1,0} tanh(bf16[64,1024]{1,0} %reshape.9825), metadata={op_type="Tanh" op_name="Tanh"} + %multiply.4891 = bf16[64,1024]{1,0} multiply(bf16[64,1024]{1,0} %logistic.814, bf16[64,1024]{1,0} %tanh.573), metadata={op_type="Mul" op_name="mul_1"} + %add.3766 = bf16[64,1024]{1,0} add(bf16[64,1024]{1,0} %multiply.4890, bf16[64,1024]{1,0} %multiply.4891), metadata={op_type="AddV2" op_name="add_1"} + %multiply.4894 = bf16[64,1024]{1,0} multiply(bf16[64,1024]{1,0} %add.3766, bf16[64,1024]{1,0} %logistic.814), metadata={op_type="Mul" op_name="gradients_1/mul_grad/Mul"} + %constant.10568 = bf16[] constant(1), metadata={op_type="TanhGrad" 
op_name="gradients/Tanh_1_grad/TanhGrad"} + %broadcast.7198 = bf16[64,1024]{1,0} broadcast(bf16[] %constant.10568), dimensions={}, metadata={op_type="TanhGrad" op_name="gradients/Tanh_1_grad/TanhGrad"} + %multiply.4896 = bf16[64,1024]{1,0} multiply(bf16[64,1024]{1,0} %tanh.573, bf16[64,1024]{1,0} %tanh.573), metadata={op_type="TanhGrad" op_name="gradients/Tanh_1_grad/TanhGrad"} + %constant.10571 = bf16[] constant(1), metadata={op_type="SigmoidGrad" op_name="gradients/Sigmoid_grad/SigmoidGrad"} + %broadcast.7201 = bf16[64,1024]{1,0} broadcast(bf16[] %constant.10571), dimensions={}, metadata={op_type="SigmoidGrad" op_name="gradients/Sigmoid_grad/SigmoidGrad"} + %subtract.1702 = bf16[64,1024]{1,0} subtract(bf16[64,1024]{1,0} %broadcast.7201, bf16[64,1024]{1,0} %logistic.814), metadata={op_type="SigmoidGrad" op_name="gradients/Sigmoid_grad/SigmoidGrad"} + %multiply.4907 = bf16[64,1024]{1,0} multiply(bf16[64,1024]{1,0} %tanh.573, bf16[64,1024]{1,0} %add.3766), metadata={op_type="Mul" op_name="gradients/mul_2_grad/Mul_1"} + %multiply.4908 = bf16[64,1024]{1,0} multiply(bf16[64,1024]{1,0} %multiply.4907, bf16[64,1024]{1,0} %logistic.814), metadata={op_type="SigmoidGrad" op_name="gradients/Sigmoid_2_grad/SigmoidGrad"} + %dot.781 = bf16[64,4096]{1,0} dot(bf16[64,1024]{1,0} %multiply.4908, bf16[1024,4096]{1,0} %get-tuple-element.16525), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_type="MatMul" op_name="MatMul"} + ROOT %tuple.3200 = (bf16[64,1024]{1,0}, bf16[64,4096]{1,0}, s32[]) tuple(bf16[64,1024]{1,0} %multiply.4894, bf16[64,4096]{1,0} %dot.781, s32[] %reshape.9822) + } +ENTRY main { + pred.1 = pred[] parameter(0) + arg_tuple.3 = (bf16[1024,4096]{1,0}, bf16[18,64,1024]{2,1,0}, s32[]) parameter(1) + arg_tuple.4 = (bf16[1024,4096]{1,0}, bf16[18,64,1024]{2,1,0}, s32[]) parameter(2) + %arg.2 = s32[] parameter(3) + %conditional.3 = (bf16[64,1024]{1,0}, bf16[64,4096]{1,0}, s32[]) conditional(pred.1, arg_tuple.3, arg_tuple.4), true_computation=nmt.1, false_computation=nmt.1 + %get-tuple-element.15889 = bf16[64,1024]{1,0} get-tuple-element((bf16[64,1024]{1,0}, bf16[64,4096]{1,0}, s32[]) %conditional.3), index=0, metadata={op_type="Case" op_name="switch_case/indexed_case"} + %multiply.4596 = bf16[64,1024]{1,0} multiply(bf16[64,1024]{1,0} %get-tuple-element.15889, bf16[64,1024]{1,0} %get-tuple-element.15889), metadata={op_type="L2Loss" op_name="global_norm/L2Loss"} + %constant.10279 = bf16[] constant(0), metadata={op_type="L2Loss" op_name="global_norm/L2Loss"} + %reduce.844 = bf16[] reduce(bf16[64,1024]{1,0} %multiply.4596, bf16[] %constant.10279), dimensions={0,1}, to_apply=%add.31755, metadata={op_type="L2Loss" op_name="global_norm/L2Loss"} + %get-tuple-element.15890 = bf16[64,4096]{1,0} get-tuple-element((bf16[64,1024]{1,0}, bf16[64,4096]{1,0}, s32[]) %conditional.3), index=1, metadata={op_type="Case" op_name="switch_case/indexed_case"} + %multiply.4597 = bf16[64,4096]{1,0} multiply(bf16[64,4096]{1,0} %get-tuple-element.15890, bf16[64,4096]{1,0} %get-tuple-element.15890), metadata={op_type="L2Loss" op_name="global_norm/L2Loss"} + %constant.10280 = bf16[] constant(0), metadata={op_type="L2Loss" op_name="global_norm/L2Loss"} + %reduce.845 = bf16[] reduce(bf16[64,4096]{1,0} %multiply.4597, bf16[] %constant.10280), dimensions={0,1}, to_apply=%add.31755, metadata={op_type="L2Loss" op_name="global_norm/L2Loss"} + %multiply.4667 = bf16[] multiply(bf16[] %reduce.845, bf16[]{:T(128)} %reduce.844), metadata={op_type="L2Loss" op_name="global_norm/L2Loss"} + ROOT %tuple.3200 = (bf16[], s32[]) 
tuple(%multiply.4667, s32[] %arg.2) + } +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass(true, true); + ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); + const HloInstruction* conditional = + FindInstruction(module.get(), "conditional.3"); + CHECK(conditional != nullptr); + const HloComputation* on_true = conditional->branch_computation(0); + ASSERT_EQ(on_true->instruction_count(), 27); + const HloComputation* on_false = conditional->branch_computation(1); + ASSERT_EQ(on_false->instruction_count(), 27); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Tuple(op::GetTupleElement(op::Conditional()), + op::Parameter()))); +} + } // namespace conditional_opt } // namespace xla diff --git a/tensorflow/compiler/xla/service/convolution_group_converter.cc b/tensorflow/compiler/xla/service/convolution_group_converter.cc index 323bf44dcd3..f5506b894fd 100644 --- a/tensorflow/compiler/xla/service/convolution_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_group_converter.cc @@ -300,7 +300,8 @@ Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) { window_dim->set_window_dilation(1); HloInstruction* new_convolution = MakeConvolveHlo(activation, filter, convolution->feature_group_count(), - window, dim_numbers, convolution->precision_config()) + /*batch_group_count=*/1, window, dim_numbers, + convolution->precision_config()) .ValueOrDie(); convolution->SetupDerivedInstruction(new_convolution); TF_CHECK_OK(computation_->ReplaceInstruction( @@ -649,7 +650,8 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { window_dim->set_window_reversal(false); window_dim->set_window_dilation(1); HloInstruction* new_convolution = - MakeConvolveHlo(activation, filter, 1, window, dim_numbers, + MakeConvolveHlo(activation, filter, /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dim_numbers, convolution->precision_config()) .ValueOrDie(); convolution->SetupDerivedInstruction(new_convolution); diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index b88120d8128..e313dbe2415 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -217,10 +217,9 @@ bool IndicesToCopyForConditional(const HloDataflowAnalysis& dataflow, // Add kCopy instructions around the given kWhile instruction to eliminate any // possible live range interference of HLO values assuming a dependency-based -// ordering (HloDependencyOrdering). Copies are added conservatively. There -// likely are copies which are not strictly necessary, but they are removed -// later in the pass via RemoveUnnecessaryCopies. -// +// ordering. Copies are added conservatively. There likely are copies which are +// not strictly necessary, but they are removed later in the pass via +// RemoveUnnecessaryCopies. // // Elements (each ShapeIndex) in the loop state are considered independently. A // copy is added to each element of the loop state which is modified in the @@ -362,6 +361,19 @@ Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis, return Status::OK(); } +// Add copies for the operands of in-place operations. RemoveUnnecessaryCopies +// will remove the unnecessary copies. 
+Status AddCopiesForInPlaceOperation(const HloAliasAnalysis& alias_analysis, + HloInstruction* in_place_op, + int64 operand_number) { + VLOG(2) << "Adding copies for in-place operation " << in_place_op->name(); + HloInstruction* operand = in_place_op->mutable_operand(operand_number); + TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy, + in_place_op->parent()->DeepCopyInstruction(operand)); + TF_RETURN_IF_ERROR(operand->ReplaceUseWith(in_place_op, deep_copy)); + return Status::OK(); +} + // Conservatively adds copies before root instruction of entry computation and // each aliased parameter to resolve interference of aliased input and output // buffer. We later rely on RemoveUnnecessaryCopies to drop the unnecessary @@ -509,6 +521,12 @@ class CopyRemover { // value. The map is used to construct the copy info map below. absl::flat_hash_map value_to_node; for (const HloBuffer& buffer : alias_analysis.buffers()) { + // No copies should have been inserted within fused computations, so no + // need to remove them. HloOrdering isn't compatible with HloValues inside + // fusions, so skip copy removal for them. + if (buffer.values().at(0)->defining_instruction()->IsFused()) { + continue; + } // Verify values contained in the buffer are strictly ordered. This // should always be the case after adding copies to eliminate // interference. Specifically, the addition of the control flow edges @@ -591,7 +609,7 @@ class CopyRemover { void CreateCopyMap( const HloModule& module, const absl::flat_hash_map& value_to_node) { - for (HloComputation* computation : module.computations()) { + for (HloComputation* computation : module.MakeNonfusionComputations()) { for (HloInstruction* instruction : computation->instructions()) { // Add copies with unambiguous source values to the map. Copies with // ambiguous sources are not removable. @@ -858,30 +876,13 @@ class CopyRemover { // We cannot use LiveRangeStrictlyBefore because HloValue::uses() is not // updated as copies are removed. bool LiveRangeBefore(const ValueNode& a, const ValueNode& b) { - VLOG(3) << "Checking live range of " << *a.value << " WRT " << *b.value; - bool is_live_range_before = [&] { - if (a.uses.empty()) { - VLOG(2) << "Empty uses for " << *a.value; - return ordering_.IsDefinedBefore(*a.value, *b.value); - } - for (const HloUse* use : a.uses) { - VLOG(3) << "Checking use " << *use << " against " << *b.value; - if (!ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_)) { - VLOG(2) << "Use " << *use << " is NOT before " << *b.value; - return false; - } - VLOG(3) << "Use " << *use << " is before " << *b.value; - } - return true; - }(); - if (is_live_range_before) { - VLOG(2) << " Live range of " << a.value->ToShortString() << " is before " - << b.value->ToShortString(); - } else { - VLOG(2) << " Live range of " << a.value->ToShortString() - << " is not before " << b.value->ToShortString(); + if (a.uses.empty()) { + VLOG(2) << "Empty uses for " << *a.value; + return ordering_.IsDefinedBefore(*a.value, *b.value); } - return is_live_range_before; + return absl::c_all_of(a.uses, [&](const HloUse* use) { + return ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_); + }); } // Returns whether 'node' is the last node in its list. 
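The AddCopiesForInPlaceOperation helper added above is wired into AddCopiesToResolveInterference in the next hunk: for every (operand, output) pair reported by HloDataflowAnalysis::GetInPlaceInputOutputPairs, the operand (currently required to sit at the top-level ShapeIndex) is deep-copied, and RemoveUnnecessaryCopies later strips whichever of these copies turn out to be unneeded. A rough before/after picture, in hypothetical HLO abbreviated from the tests below (update and idx stand in for the broadcast and constant operands):

  // Before: operand 0 of the dynamic-update-slice aliases the parameter.
  param = f32[1280,1,128] parameter(0)
  ROOT dus = f32[1280,1,128] dynamic-update-slice(param, update, idx, idx, idx)

  // After AddCopiesForInPlaceOperation: the aliased operand is deep-copied.
  param = f32[1280,1,128] parameter(0)
  copy.0 = f32[1280,1,128] copy(param)
  ROOT dus = f32[1280,1,128] dynamic-update-slice(copy.0, update, idx, idx, idx)

This is the behavior the DynamicUpdateSliceParameterShareCopy test below expects (exactly one surviving copy); when the sliced-into operand is a locally produced value such as a negate, the copy is removable and DynamicUpdateSliceNoCopy expects zero copies.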
@@ -1005,7 +1006,7 @@ Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) { TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, HloAliasAnalysis::Run(module, can_share_buffer_)); - for (HloComputation* computation : module->MakeComputationPostOrder()) { + for (HloComputation* computation : module->MakeNonfusionComputations()) { for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) { if (instruction->opcode() == HloOpcode::kWhile) { @@ -1013,6 +1014,15 @@ Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) { } else if (instruction->opcode() == HloOpcode::kConditional) { TF_RETURN_IF_ERROR( AddCopiesForConditional(*alias_analysis, instruction)); + } else { + for (const auto& operand_and_output_index : + HloDataflowAnalysis::GetInPlaceInputOutputPairs(instruction)) { + const HloUse& operand = operand_and_output_index.first; + CHECK_EQ(operand.operand_index, ShapeIndex{}) + << "Support for non-{} shape operand not currently implemented."; + TF_RETURN_IF_ERROR(AddCopiesForInPlaceOperation( + *alias_analysis, instruction, operand.operand_number)); + } } } } diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 3ee6b200da5..78730cbdcb8 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -2530,5 +2530,250 @@ ENTRY Entry { EXPECT_EQ(CountCopies(*module), 1); } +TEST_F(CopyInsertionTest, DynamicUpdateSliceNoCopy) { + absl::string_view hlo_string = R"( +HloModule Module + +ENTRY main { + param = f32[1280,1,128] parameter(0) + negate = f32[1280,1,128] negate(param) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(negate, broadcast.6, constant.3, constant.3, constant.3) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 0); +} + +TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceNoCopy) { + absl::string_view hlo_string = R"( +HloModule Module + +fused_computation { + param0 = f32[1280,1,128] parameter(0) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3) +} + +ENTRY main { + param = f32[1280,1,128] parameter(0) + negate = f32[1280,1,128] negate(param) + ROOT fusion = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 0); +} + +TEST_F(CopyInsertionTest, DynamicUpdateSliceCopy) { + absl::string_view hlo_string = R"( +HloModule Module + +ENTRY main { + param = f32[1280,1,128] parameter(0) + negate = f32[1280,1,128] negate(param) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + add = f32[1280,1,128] add(negate, negate) + dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(negate, broadcast.6, constant.3, constant.3, constant.3) + ROOT tuple = (f32[1280,1,128], f32[1280,1,128]) tuple(add, dynamic-update-slice.5) +} +)"; + 
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 1); +} + +TEST_F(CopyInsertionTest, DynamicUpdateSliceParameterShareCopy) { + absl::string_view hlo_string = R"( +HloModule Module + +ENTRY main { + param = f32[1280,1,128] parameter(0) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param, broadcast.6, constant.3, constant.3, constant.3) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 1); +} + +TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceCopy) { + absl::string_view hlo_string = R"( +HloModule Module + +fused_computation { + param0 = f32[1280,1,128] parameter(0) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3) +} + +ENTRY main { + param = f32[1280,1,128] parameter(0) + negate = f32[1280,1,128] negate(param) + add = f32[1280,1,128] add(negate, negate) + fusion = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation + ROOT tuple = (f32[1280,1,128], f32[1280,1,128]) tuple(negate, fusion) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 1); +} + +TEST_F(CopyInsertionTest, ChainDynamicUpdateSliceCopy) { + absl::string_view hlo_string = R"( +HloModule Module + +ENTRY main { + state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128]{2,1,0} broadcast(constant.1), dimensions={} + get-tuple-element.4 = f32[1280,1,128]{2,1,0} get-tuple-element(state), index=1 + get-tuple-element.3 = s32[] get-tuple-element(state), index=0 + constant.2 = s32[] constant(128) + add.5 = s32[] add(get-tuple-element.3, constant.2) + constant.3 = s32[] constant(0) + dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3) + dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3) + ROOT tuple.85 = (s32[], s32[], s32[2]{0}, f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 1); +} + +TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceCopy2) { + absl::string_view hlo_string = R"( +HloModule Module + +fused_computation.1 { + param0 = f32[1280,1,128] parameter(0) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3) +} + +fused_computation.2 { + param0 = f32[1280,1,128] parameter(0) + param1 = f32[1280,1,128] parameter(1) + slice = f32[128,1,128] slice(param1), slice={[0:128], [0:1], [0:128]} + constant.3 = s32[] constant(0) + ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, slice, 
constant.3, constant.3, constant.3) +} + +ENTRY main { + param = f32[1280,1,128] parameter(0) + negate = f32[1280,1,128] negate(param) + add = f32[1280,1,128] add(negate, negate) + fusion1 = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation.1 + ROOT fusion2 = f32[1280,1,128] fusion(fusion1, negate), kind=kLoop, calls=fused_computation.2 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 1); +} + +TEST_F(CopyInsertionTest, MultiOutputFusedDynamicUpdateSliceCopy) { + // Tests multi-output fusion with two DUS outputs, requiring two copies. + absl::string_view hlo_string = R"( +HloModule Module + +fused_computation { + param0 = f32[1280,1,128] parameter(0) + param1 = f32[1280,1,128] parameter(1) + param2 = f32[1280,1,128] parameter(2) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + add.1 = f32[1280,1,128] add(param0, param0) + dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param1, broadcast.6, constant.3, constant.3, constant.3) + dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(param2, broadcast.6, constant.3, constant.3, constant.3) + ROOT tuple.1 = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add.1, dynamic-update-slice.5, dynamic-update-slice.6) +} + +ENTRY main { + param = f32[1280,1,128] parameter(0) + negate0 = f32[1280,1,128] negate(param) + negate1 = f32[1280,1,128] negate(param) + negate2 = f32[1280,1,128] negate(param) + fusion = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) fusion(negate0, negate1, negate2), kind=kLoop, calls=fused_computation + gte0 = f32[1280,1,128] get-tuple-element(fusion), index=0 + gte1 = f32[1280,1,128] get-tuple-element(fusion), index=1 + gte2 = f32[1280,1,128] get-tuple-element(fusion), index=2 + add0 = f32[1280,1,128] add(negate0, gte0) + add1 = f32[1280,1,128] add(negate1, gte1) + add2 = f32[1280,1,128] add(negate2, gte2) + ROOT tuple = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add0, add1, add2) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 2); +} + +TEST_F(CopyInsertionTest, MultiOutputFusedDynamicUpdateSliceNoCopy) { + // Same as above, but negate1 is not used beyond fusion, so it only needs one + // copy for negate0. 
+ absl::string_view hlo_string = R"( +HloModule Module + +fused_computation { + param0 = f32[1280,1,128] parameter(0) + param1 = f32[1280,1,128] parameter(1) + param2 = f32[1280,1,128] parameter(2) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + add.1 = f32[1280,1,128] add(param0, param0) + dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param1, broadcast.6, constant.3, constant.3, constant.3) + dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(param2, broadcast.6, constant.3, constant.3, constant.3) + ROOT tuple.1 = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add.1, dynamic-update-slice.5, dynamic-update-slice.6) +} + +ENTRY main { + param = f32[1280,1,128] parameter(0) + negate0 = f32[1280,1,128] negate(param) + negate1 = f32[1280,1,128] negate(param) + negate2 = f32[1280,1,128] negate(param) + fusion = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) fusion(negate0, negate1, negate2), kind=kLoop, calls=fused_computation + gte0 = f32[1280,1,128] get-tuple-element(fusion), index=0 + gte1 = f32[1280,1,128] get-tuple-element(fusion), index=1 + gte2 = f32[1280,1,128] get-tuple-element(fusion), index=2 + add0 = f32[1280,1,128] add(negate0, gte0) + add1 = f32[1280,1,128] add(gte1, gte1) + add2 = f32[1280,1,128] add(negate2, gte2) + ROOT tuple = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add0, add1, add2) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 1); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index b622b712f82..0cc27e32749 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -1,11 +1,15 @@ # Description: # LLVM-based CPU backend for XLA. 
+load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS") load( "//third_party/mkl:build_defs.bzl", "mkl_deps", ) + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "filegroup") load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_openmp_copts") load(":build_defs.bzl", "runtime_copts") load("//tensorflow/core/platform:build_config.bzl", "if_llvm_system_z_available") @@ -87,7 +91,7 @@ cc_library( "//tensorflow/compiler/xla/service:generic_transfer_manager", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor", "@com_google_absl//absl/base", "@com_google_absl//absl/memory", @@ -130,11 +134,14 @@ cc_library( ":target_machine_features", "@com_google_absl//absl/base", "@com_google_absl//absl/types:span", + "@llvm-project//mlir:Affine", "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", - "@llvm-project//mlir:ExecutionEngineUtils", "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LinalgOps", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:VectorOps", "//tensorflow/compiler/xla/service:copy_insertion", - "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:dump", "//tensorflow/compiler/xla/service:topk_rewriter", "//tensorflow/compiler/xla/service:map_inliner", @@ -161,6 +168,7 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:call_inliner", "//tensorflow/compiler/xla/service:cholesky_expander", + "//tensorflow/compiler/xla/service:qr_expander", "//tensorflow/compiler/xla/service:conditional_simplifier", "//tensorflow/compiler/xla/service:convolution_group_converter", "//tensorflow/compiler/xla/service:dot_decomposer", @@ -196,9 +204,8 @@ cc_library( "//tensorflow/compiler/xla/service:zero_sized_hlo_elimination", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "@llvm-project//llvm:Core", - "@llvm-project//llvm:MC", "@llvm-project//llvm:Object", "@llvm-project//llvm:Support", "@llvm-project//llvm:Target", @@ -314,11 +321,11 @@ cc_library( "//tensorflow/compiler/xla/service:maybe_owning_device_memory", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/platform:logging", "//tensorflow/core/platform:macros", "//tensorflow/core/platform:mutex", "//tensorflow/core/platform:platform_port", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/core/platform:types", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/stream_executor:device_memory_allocator", @@ -480,7 +487,6 @@ cc_library( ":cpu_runtime", ":ir_emission_utils", ":mlir_emitter", - ":mlir_matmul_codegen_strategy", ":target_machine_features", ":tiled_dot_emitter", ":vector_support_library", @@ -502,6 +508,7 @@ cc_library( "@llvm-project//mlir:EDSC", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgOps", + "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:StandardOps", ], ) @@ -1136,24 +1143,3 @@ cc_library( "@llvm-project//mlir:VectorToLLVM", ], ) - -cc_library( - name = "mlir_matmul_codegen_strategy", - srcs = ["mlir_matmul_codegen_strategy.cc"], - 
hdrs = ["mlir_matmul_codegen_strategy.h"], - deps = [ - "@llvm-project//llvm:Support", - "@llvm-project//mlir:Affine", - "@llvm-project//mlir:Analysis", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LinalgOps", - "@llvm-project//mlir:LinalgTransforms", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:StandardOps", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:Transforms", - "@llvm-project//mlir:VectorOps", - "@llvm-project//mlir:VectorToSCF", - ], -) diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index d8bf15ecdeb..1ffafd37a27 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -42,7 +42,12 @@ limitations under the License. #include "llvm/Support/TargetSelect.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Target/TargetOptions.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" // from @llvm-project +#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/Dialect/Vector/VectorOps.h" // from @llvm-project #include "mlir/InitAllDialects.h" // from @llvm-project #include "tensorflow/compiler/xla/cpu_function_runtime.h" #include "tensorflow/compiler/xla/literal.h" @@ -98,6 +103,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/logistic_expander.h" #include "tensorflow/compiler/xla/service/map_inliner.h" +#include "tensorflow/compiler/xla/service/qr_expander.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h" #include "tensorflow/compiler/xla/service/rng_expander.h" @@ -121,6 +127,21 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/dynamic_annotations.h" +namespace { + +// We need to explicitly load all the dialects we will involved in emitting the +// IR. This is only needed because of how MLIR is bolted into XLA and does not +// make use of the MLIR infrastructure (like using a proper pass pipeline). +// Hopefully this will all go away at some point in favor of a better +// integration. +void LoadMLIRDialects(mlir::MLIRContext& context) { + context.loadDialect(); +} + +} // namespace + namespace xla { namespace cpu { using BufferInfo = cpu_function_runtime::BufferInfo; @@ -164,8 +185,6 @@ CpuCompiler::CpuCompiler() { // Initialize LLVM's MC layer for the native target. llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); - - mlir::registerAllDialects(); } namespace { @@ -263,6 +282,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(); pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(); // Inline computations with a single call site. 
@@ -542,9 +562,11 @@ StatusOr< std::tuple, std::unique_ptr>> CpuCompiler::RunHloPassesAndBufferAssignement( std::unique_ptr module, se::StreamExecutor* executor, - se::DeviceMemoryAllocator* device_allocator) { - TF_ASSIGN_OR_RETURN( - module, RunHloPasses(std::move(module), executor, device_allocator)); + se::DeviceMemoryAllocator* device_allocator, bool optimize) { + if (optimize) { + TF_ASSIGN_OR_RETURN( + module, RunHloPasses(std::move(module), executor, device_allocator)); + } // Select an order for emitting the HLO instructions for each computation. // Using this sequence enables tighter buffer liveness analysis and reduced @@ -622,7 +644,7 @@ StatusOr> CpuCompiler::RunBackend( // Compile must be thread-safe so create a new LLVM context for the module. mlir::MLIRContext mlir_context; - mlir_context.loadAllGloballyRegisteredDialects(); + LoadMLIRDialects(mlir_context); llvm::LLVMContext llvm_context; auto llvm_module = absl::make_unique("__compute_module", llvm_context); @@ -834,7 +856,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, // Compile must be thread-safe so create a new LLVM context for the module. mlir::MLIRContext mlir_context; - mlir_context.loadAllGloballyRegisteredDialects(); + LoadMLIRDialects(mlir_context); llvm::LLVMContext llvm_context; llvm::Module llvm_module("__compute_module", llvm_context); llvm_module.setDataLayout(target_machine->createDataLayout()); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index d28ccd985a3..5c056fcacaa 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -138,9 +138,10 @@ class CpuCompiler : public LLVMCompiler { StatusOr< std::tuple, std::unique_ptr>> - RunHloPassesAndBufferAssignement( - std::unique_ptr module, se::StreamExecutor* executor, - se::DeviceMemoryAllocator* device_allocator) override; + RunHloPassesAndBufferAssignement(std::unique_ptr module, + se::StreamExecutor* executor, + se::DeviceMemoryAllocator* device_allocator, + bool optimize) override; StatusOr> RunBackend( std::unique_ptr module, se::StreamExecutor* stream_exec, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 7431e829b8e..02bc445ce9a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -210,8 +210,7 @@ StatusOr CpuExecutable::CreateResultShapedBuffer( absl::Span buffers, absl::Span arguments) { se::Stream* stream = run_options->stream(); - ExecutionOutput result(/*on_host_shape=*/result_shape(), - /*on_device_shape=*/result_shape(), + ExecutionOutput result(/*on_device_shape=*/result_shape(), run_options->allocator(), stream->parent()->device_ordinal()); const HloInputOutputAliasConfig& input_output_alias = diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc index 9460cc55e10..42c6c9839bf 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc @@ -95,7 +95,7 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, // Cost condition: not fuse (simple, expensive producers) and (consumers who // reuse operand elements). 
if (producer->opcode() != HloOpcode::kFusion && is_expensive(*producer) && - consumer->ReusesOperandElements(operand_index)) { + ReusesOperandElements(consumer, operand_index)) { VLOG(2) << "Fusion is not profitable."; return false; } @@ -132,7 +132,7 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, fusion_node_evaluations_.emplace(consumer, FusionNodeIndexingEvaluation(consumer)); } - if (fusion_node_evaluations_.at(consumer).AverageCodeDuplicationTooHigh( + if (fusion_node_evaluations_.at(consumer).CodeDuplicationTooHigh( producer)) { return false; } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index e21ed7ad60e..bfd8e9e111a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -250,9 +250,9 @@ StatusOr CpuTransferManager::TransferBuffersFromOutfeedInternal( size); } - if (size <= 0) { - return InvalidArgument("Outfeed shape must have positive size; got %d", - size); + if (size < 0) { + return InvalidArgument( + "Outfeed shape must have non-negative size; got %d", size); } int32 size_32 = static_cast(size); diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 2b3865b4dba..ba8b74a64a5 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -24,6 +24,7 @@ limitations under the License. #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" // from @llvm-project +#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h" // from @llvm-project #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" // from @llvm-project #include "mlir/EDSC/Builders.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project @@ -36,7 +37,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/mlir_emitter.h" -#include "tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h" #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" @@ -304,14 +304,17 @@ Status DotOpEmitter::EmitLinalgMatmul() { } } - llvm::SmallVector types( + llvm::SmallVector iteratorTypes( parallel_exprs.size(), mlir::IteratorType::Parallel); - types.push_back(mlir::IteratorType::Reduction); + iteratorTypes.push_back(mlir::IteratorType::Reduction); mlir::edsc::StructuredIndexed s_a(a), s_b(b), s_c(c); - mlir::edsc::makeGenericLinalgOp(types, {s_b(b_exprs), s_c(c_exprs)}, - {s_a(parallel_exprs)}, - mlir::edsc::ops::macRegionBuilder); + mlir::edsc::makeGenericLinalgOp( + /*iteratorTypes=*/iteratorTypes, + /*inputs=*/{s_b(b_exprs), s_c(c_exprs)}, + /*outputBuffers=*/{s_a(parallel_exprs)}, + /*initTensors=*/{}, + /*resultTensorTypes=*/{}, mlir::edsc::ops::macRegionBuilder); mlir::edsc::intrinsics::std_ret(); mlir::linalg::LinalgTilingOptions tilingOptions; @@ -319,7 +322,7 @@ Status DotOpEmitter::EmitLinalgMatmul() { int64 alignment = target_machine_features_.minimum_alignment_for_allocation( ShapeUtil::ByteSizeOf(dot_info_.result_shape)); - mlir_strategy::MatmulCodegenStrategy strategy; + mlir::linalg::CodegenStrategy strategy; strategy.tile(tilingOptions) .promote( mlir::linalg::LinalgPromotionOptions() diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 36566d6c25f..54822323137 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -449,7 +449,7 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) { Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, llvm::Value* program_buffer_address) { int64 length = ByteSizeOf(shape); - if (length <= 0 || length > std::numeric_limits::max()) { + if (length < 0 || length > std::numeric_limits::max()) { return InvalidArgument( "xfeed (infeed or outfeed) buffer length %d is outside the valid " "size range", diff --git a/tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.cc b/tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.cc deleted file mode 100644 index ea89071a967..00000000000 --- a/tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.cc +++ /dev/null @@ -1,269 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.h" - -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Debug.h" -#include "mlir/Analysis/SliceAnalysis.h" // from @llvm-project -#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" // from @llvm-project -#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project -#include "mlir/Dialect/Linalg/Transforms/Hoisting.h" // from @llvm-project -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project -#include "mlir/Dialect/Linalg/Utils/Utils.h" // from @llvm-project -#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project -#include "mlir/Dialect/SCF/Utils.h" // from @llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/Dialect/Vector/EDSC/Intrinsics.h" // from @llvm-project -#include "mlir/Dialect/Vector/VectorOps.h" // from @llvm-project -#include "mlir/Dialect/Vector/VectorTransforms.h" // from @llvm-project -#include "mlir/IR/AffineExpr.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project -#include "mlir/IR/Dominance.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/OperationSupport.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/IR/Visitors.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Transforms/LoopUtils.h" // from @llvm-project -#include "mlir/Transforms/Passes.h" // from @llvm-project - -// TODO(kramerb): Remove this once strategy is in mlir core. - -using namespace mlir; // NOLINT -using namespace mlir::linalg; // NOLINT - -#define DEBUG_TYPE "matmul-codegen-strategy" - -namespace xla { -namespace cpu { -namespace mlir_strategy { - -//===----------------------------------------------------------------------===// -// TODO: Cleanup and upstream these to go into core. Please ignore for now ! -//===----------------------------------------------------------------------===// -static void hoistRedundantCopies(FuncOp func) { - bool changed = true; - while (changed) { - changed = false; - func.walk([&](linalg::FillOp op) { - auto loop = op.getParentOfType(); - if (!loop) return; - - for (auto operand : op.getOperands()) - if (!loop.isDefinedOutsideOfLoop(operand)) return; - - // Hoist fill before. - op.getOperation()->moveBefore(loop); - changed = true; - }); - - func.walk([&](linalg::CopyOp op) { - auto loop = op.getParentOfType(); - if (!loop) return; - - for (auto operand : op.getOperands()) - if (!loop.isDefinedOutsideOfLoop(operand)) return; - - Value sourceView = op.getInput(0); - while (auto subViewOp = sourceView.getDefiningOp()) - sourceView = subViewOp.getViewSource(); - - // Source traces back to a block argument. 
- if (sourceView.isa()) { - op.getOperation()->moveBefore(loop); - } else { - assert(sourceView.getDefiningOp() || - sourceView.getDefiningOp() || - sourceView.getDefiningOp()); - op.getOperation()->moveAfter(loop); - } - changed = true; - }); - } -} - -/// Substitute scf.for = %lb to %ub step %step by an AffineExpr expressing: -/// `%lb + %step * new_dim` where -/// 1. the AffineExpr for %lb is either an AffineConstantExpr or an -/// AffineDimExpr depending on whether the value is constant or not. -/// 2. the AffineExpr for %step is either an AffineConstantExpr or an -/// AffineSymbolExpr depending on whether the value is constant or not. -/// -static void substitute(scf::ForOp forOp, SmallVectorImpl &exprs, - SmallVectorImpl &dims, - SmallVectorImpl &symbols) { - MLIRContext *ctx = forOp.getContext(); - auto lbConstant = forOp.lowerBound().getDefiningOp(); - AffineExpr lb = lbConstant ? getAffineConstantExpr(lbConstant.getValue(), ctx) - : getAffineDimExpr(dims.size(), ctx); - - auto stepConstant = forOp.step().getDefiningOp(); - AffineExpr step = stepConstant - ? getAffineConstantExpr(stepConstant.getValue(), ctx) - : getAffineSymbolExpr(symbols.size(), ctx); - - if (!lbConstant) dims.push_back(forOp.lowerBound()); - if (!stepConstant) symbols.push_back(forOp.step()); - exprs.push_back(lb + step * getAffineDimExpr(dims.size(), ctx)); - - auto ubConstant = forOp.upperBound().getDefiningOp(); - AffineExpr ub = ubConstant ? getAffineConstantExpr(ubConstant.getValue(), ctx) - : getAffineDimExpr(dims.size(), ctx); - if (!ubConstant) dims.push_back(forOp.upperBound()); - exprs.push_back(ub); - - dims.push_back(forOp.getInductionVar()); -} - -/// Traverse the . -static void substitute(AffineMinOp minOp, SmallVectorImpl &exprs, - SmallVectorImpl &dims, - SmallVectorImpl &symbols) { - MLIRContext *ctx = minOp.getContext(); - for (Value v : minOp.getDimOperands()) { - if (auto forOp = scf::getForInductionVarOwner(v)) { - substitute(forOp, exprs, dims, symbols); - continue; - } - if (auto parentMinOp = v.getDefiningOp()) { - substitute(parentMinOp, exprs, dims, symbols); - continue; - } - exprs.push_back(getAffineDimExpr(dims.size(), ctx)); - dims.push_back(v); - } -} - -/// Perform folding of chains of AffineMinOp. 
-struct AffineMinCanonicalizationPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(AffineMinOp minOp, - PatternRewriter &rewriter) const override; -}; - -LogicalResult AffineMinCanonicalizationPattern::matchAndRewrite( - AffineMinOp minOp, PatternRewriter &rewriter) const { - LLVM_DEBUG(llvm::dbgs() << "\nCanonicalize AffineMin: " - << *minOp.getOperation() << "\n"); - - int64_t min = std::numeric_limits::max(); - for (auto e : minOp.map().getResults()) - if (auto cstExpr = e.dyn_cast()) - min = std::min(min, cstExpr.getValue()); - if (min == std::numeric_limits::max()) return failure(); - - SmallVector exprs; - SmallVector dims, symbols; - substitute(minOp, exprs, dims, symbols); - - SmallVector operands = dims; - operands.append(symbols.begin(), symbols.end()); - - MLIRContext *ctx = minOp.getContext(); - auto map = AffineMap::get(dims.size(), symbols.size(), exprs, ctx); - LLVM_DEBUG(llvm::dbgs() << "Substitution map: " << map << "\n"); - - SmallVector modExprs; - for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx) - modExprs.push_back(getAffineDimExpr(idx, ctx) % min); - map = AffineMap::get(map.getNumResults(), 0, modExprs, ctx).compose(map); - canonicalizeMapAndOperands(&map, &operands); - map = simplifyAffineMap(map); - - LLVM_DEBUG(llvm::dbgs() << "Post mod: " << map << "\n"; - llvm::interleaveComma(operands, llvm::dbgs())); - - if (!llvm::all_of(map.getResults(), [](AffineExpr e) { - if (auto cst = e.dyn_cast()) - return cst.getValue() == 0; - return false; - })) - return failure(); - - rewriter.replaceOpWithNewOp(minOp, min); - return success(); -} -//===----------------------------------------------------------------------===// -// END TODO -//===----------------------------------------------------------------------===// - -void MatmulCodegenStrategy::transform(FuncOp func) const { - MLIRContext *context = func.getContext(); - // Emplace patterns one at a time while also maintaining a simple chained - // state transition. - unsigned stepCount = 0; - SmallVector stage1Patterns; - auto zeroState = Identifier::get(std::to_string(stepCount), context); - auto currentState = zeroState; - for (auto &t : transformation_sequence) { - auto nextState = Identifier::get(std::to_string(++stepCount), context); - auto marker = (currentState == zeroState) - ? linalg::LinalgMarker({}, nextState) - : linalg::LinalgMarker(currentState, nextState); - stage1Patterns.emplace_back(t->buildRewritePatterns(context, marker)); - currentState = nextState; - } - - OwningRewritePatternList stage2Patterns = - linalg::getLinalgTilingCanonicalizationPatterns(context); - stage2Patterns.insert(context); - - auto stage3Transforms = [](Operation *op) { - // Some of these may be too aggressive as a stage 3 that is applied on each - // stage 1 application and may have to be split out to post staged patterns - // application (in which case they could just be passes, TBD). 
- PassManager pm(op->getContext()); - pm.addPass(createLoopInvariantCodeMotionPass()); - if (failed(pm.run(op->getParentOfType()))) - llvm_unreachable("Unexpected failure in cleanup pass pipeline."); - promoteSingleIterationLoops(cast(op)); - hoistViewAllocOps(cast(op)); - hoistRedundantVectorTransfers(cast(op)); - hoistRedundantCopies(cast(op)); - return success(); - }; - linalg::applyStagedPatterns(func, stage1Patterns, stage2Patterns, - stage3Transforms); - - //===--------------------------------------------------------------------===// - // Post staged patterns transforms - //===--------------------------------------------------------------------===// - // Programmatic controlled lowering of vector.contract only. - OwningRewritePatternList vectorContractLoweringPatterns; - vectorContractLoweringPatterns - .insert( - vector_transforms_options, context); - applyPatternsAndFoldGreedily(func, vectorContractLoweringPatterns); - - // Programmatic controlled lowering of vector.transfer only. - OwningRewritePatternList vectorToLoopsPatterns; - populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context, - vector_to_scf_options); - applyPatternsAndFoldGreedily(func, vectorToLoopsPatterns); -} - -} // namespace mlir_strategy -} // namespace cpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.h b/tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.h deleted file mode 100644 index 3b11b750c47..00000000000 --- a/tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.h +++ /dev/null @@ -1,188 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef MLIR_EDGE_BENCHMARKS_STRATEGIES_MATMULCODEGENSTRATEGIES_H_ -#define MLIR_EDGE_BENCHMARKS_STRATEGIES_MATMULCODEGENSTRATEGIES_H_ - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringSwitch.h" -#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" // from @llvm-project -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project -#include "mlir/Dialect/Vector/VectorOps.h" // from @llvm-project -#include "mlir/Dialect/Vector/VectorTransforms.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project - -// TODO(kramerb): Remove this once strategy is in mlir core. - -namespace xla { -namespace cpu { -namespace mlir_strategy { - -/// Abstract Transformation class applied in a sequence that also handles state -/// through markers. -struct Transformation { - virtual ~Transformation() = default; - virtual mlir::OwningRewritePatternList buildRewritePatterns( - mlir::MLIRContext *context, mlir::linalg::LinalgMarker m) = 0; - mlir::linalg::LinalgMarker marker; -}; - -/// Promotion transformation enqueues a particular stage-1 pattern for -/// `Tile`with the appropriate `options`. -// TODO: variadic LinalgOpTypes. 
-template -struct Tile : public Transformation { - explicit Tile(mlir::linalg::LinalgTilingOptions options) : options(options) {} - - mlir::OwningRewritePatternList buildRewritePatterns( - mlir::MLIRContext *context, mlir::linalg::LinalgMarker m) override { - mlir::OwningRewritePatternList tiling_patterns; - tiling_patterns.insert>( - context, options, m); - return tiling_patterns; - } - - private: - mlir::linalg::LinalgTilingOptions options; -}; - -/// Promotion transformation enqueues a particular stage-1 pattern for -/// `Promote`with the appropriate `options`. -// TODO: variadic LinalgOpTypes. -template -struct Promote : public Transformation { - explicit Promote(mlir::linalg::LinalgPromotionOptions options) - : options(options) {} - - mlir::OwningRewritePatternList buildRewritePatterns( - mlir::MLIRContext *context, mlir::linalg::LinalgMarker m) override { - mlir::OwningRewritePatternList promotion_patterns; - promotion_patterns - .insert>(context, - options, m); - return promotion_patterns; - } - - private: - mlir::linalg::LinalgPromotionOptions options; -}; - -/// Vectorization transformation enqueues a particular stage-1 pattern for -/// `LinalgVectorizationPattern` as well as copy to vector -/// transfer rewrite forwarding patterns. -// TODO: variadic LinalgOpTypes. -template -struct Vectorize : public Transformation { - mlir::OwningRewritePatternList buildRewritePatterns( - mlir::MLIRContext *context, mlir::linalg::LinalgMarker m) override { - mlir::OwningRewritePatternList vectorization_patterns; - // FillOp may interfere with forwarding patterns atm, so we bump up the - // priority of LinalgCopyVTRForwardingPattern / - // LinalgCopyVTWForwardingPattern. - vectorization_patterns - .insert>(context, - m); - vectorization_patterns.insert( - context, - /*benefit=*/2); - return vectorization_patterns; - } -}; - -/// Matmul-specific strategy object controls how a linalg.matmul is -/// progressively lowered. -/// The strategy uses a 3-level staged patterns strategy which allows ordering -/// transformations by using the Linalg `applyStagedPatterns` function, where: -/// 1. The first stage consists of the successive `tile`, `promote` and -/// `vectorize` patterns, applied sequentially. -/// 2. The second stage consists of common local canonicalization patterns -/// that are applied eagerly after each stage-1 pattern. -/// 3. the third stage consists of more global transformation, also applied -/// eagerly, after all stage-2 patterns. Such more global transformations -struct MatmulCodegenStrategy { - /// Append a pattern to add a level of tiling for `LinalgOpType` with tiling - /// `options`. - template - MatmulCodegenStrategy &tile(mlir::linalg::LinalgTilingOptions options) { - transformation_sequence.emplace_back(new Tile(options)); - return *this; - } - /// Conditionally append a pattern to add a level of tiling for `LinalgOpType` - /// with tiling `options`. - template - MatmulCodegenStrategy &tileIf(bool b, - mlir::linalg::LinalgTilingOptions options) { - return b ? tile(options) : *this; - } - /// Append a pattern to add a level of promotion for `LinalgOpType` with - /// promotion `options`. - template - MatmulCodegenStrategy &promote(mlir::linalg::LinalgPromotionOptions options) { - transformation_sequence.emplace_back(new Promote(options)); - return *this; - } - /// Conditionally append a pattern to add a level of promotion for - /// `LinalgOpType` with promotion `options`. 
- template - MatmulCodegenStrategy &promoteIf( - bool b, mlir::linalg::LinalgPromotionOptions options) { - return b ? promote(options) : *this; - return *this; - } - /// Append a pattern to rewrite `LinalgOpType` as a vector operation. - template - MatmulCodegenStrategy &vectorize() { - transformation_sequence.emplace_back(new Vectorize()); - return *this; - } - /// Conditionally append a pattern to rewrite `LinalgOpType` as a vector - /// operation. - template - MatmulCodegenStrategy &vectorizeIf(bool b) { - return b ? vectorize() : *this; - return *this; - } - /// Configure the post staged-patterns late vector transformations. - MatmulCodegenStrategy &setVectorTransformsOptions( - mlir::vector::VectorTransformsOptions options) { - vector_transforms_options = options; - return *this; - } - /// Configure the post staged-patterns late vector.transfer to scf conversion. - MatmulCodegenStrategy &setVectorTransferToSCFOptions( - mlir::VectorTransferToSCFOptions options) { - vector_to_scf_options = options; - return *this; - } - - /// Apply the transformation patterns in sequence with cleanup transformations - /// interleaved. - void transform(mlir::FuncOp func) const; - - private: - mlir::LogicalResult postPatternTransforms(mlir::Operation *func) const; - - mlir::vector::VectorTransformsOptions vector_transforms_options; - mlir::VectorTransferToSCFOptions vector_to_scf_options; - llvm::SmallVector, 4> transformation_sequence; -}; - -} // namespace mlir_strategy -} // namespace cpu -} // namespace xla - -#endif // MLIR_EDGE_BENCHMARKS_STRATEGIES_MATMULCODEGENSTRATEGIES_H_ diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc index ffbd0d68ce9..23f5a5c434f 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc @@ -31,9 +31,15 @@ ParallelLoopEmitter::ParallelLoopEmitter( std::vector ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, - llvm::Type* index_type) { + llvm::Type* index_type, + llvm::Value* base_index) { CHECK_NE(index_type, nullptr); + CHECK_EQ(base_index, nullptr) + << "XLA CPU implementation of" + << " ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock doesn't support" + << " base_index, but it was requested."; + CHECK(!shape_.IsTuple()); CHECK(!ShapeUtil::IsScalar(shape_)); diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h index a604e1db222..a11fd44f1ce 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h @@ -61,7 +61,8 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { ~ParallelLoopEmitter() override = default; std::vector EmitIndexAndSetExitBasicBlock( - absl::string_view loop_name, llvm::Type* index_type) override; + absl::string_view loop_name, llvm::Type* index_type, + llvm::Value* base_index) override; private: const DynamicLoopBounds* dynamic_loop_bounds_; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index 225102e6ae6..48f2248d2d7 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -143,7 +143,8 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( // TODO(b/27458679) Parallelize instructions which are 
skipped here. auto opcode = instruction->opcode(); if (llvm_ir::MayBeImplementedAsInPlaceDynamicUpdateSlice(instruction) || - instruction->shape().IsTuple() || opcode == HloOpcode::kRng) { + instruction->shape().IsTuple() || opcode == HloOpcode::kRng || + opcode == HloOpcode::kConstant) { return 1; } diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index e22210a61f2..5b454379876 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -191,5 +191,19 @@ TEST_F(ParallelTaskAssignmentTest, AllReduceNotParallelized) { EXPECT_FALSE(changed); } +TEST_F(ParallelTaskAssignmentTest, ConstantNotParallelized) { + constexpr char hlo_string[] = R"( + HloModule TestTaskParallel_constant + ENTRY const { + ROOT constant = f32[1234567] constant({...}) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(m.get())); + EXPECT_FALSE(changed); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index 527071d5f31..aab9556d135 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -1,6 +1,8 @@ # Description: # Tests for LLVM-based CPU backend for XLA. +load("//tensorflow:tensorflow.bzl", "filegroup") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") package( diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc index b2ed9bd5f31..f6925ce5c80 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc @@ -56,6 +56,34 @@ CHECK: private unnamed_addr constant [48 x i8] /*match_optimized_ir=*/false); } +TEST_F(CpuOutfeedTest, OutfeedEmpty) { + const string hlo_text = R"( +HloModule Outfeed + +ENTRY main { + const_a = f32[2,0] constant({{}, {}}) + token0 = token[] after-all() + outfeed = token[] outfeed(f32[2,0] const_a, token0) + ROOT root = () tuple() +} +)"; + + string filecheck_pattern = R"( +CHECK: private unnamed_addr constant [0 x i8] +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text)); + + CpuAotCompilationOptions options{ + /*triple=*/kTargetTripleForHost, /*cpu_name=*/kTargetCpuForHost, + /*features=*/"", + /*entry_point_name=*/"entry", + /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; + + CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern, + /*match_optimized_ir=*/false); +} + TEST_F(CpuOutfeedTest, OutfeedTokenInTuple) { const string hlo_text = R"( HloModule OutfeedTokenInTuple diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc b/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc index e36eff09009..e364c0f1b42 100644 --- a/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc +++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc @@ -128,6 +128,22 @@ TEST_F(InfeedManagerTest, MultiThreaded) { ProcessNextBuffer(length); } +TEST_F(InfeedManagerTest, OutfeedBasic) { + TestInfeedBuffer* b = new TestInfeedBuffer(32, /*expect_shape_match=*/true); + cpu::runtime::XfeedManager* xfeed = 
cpu::runtime::GetXfeedManager(0); + xfeed->outfeed()->EnqueueBuffersAtomically({b}); + + ProcessNextOutfeedBuffer(32, ShapeUtil::MakeShape(U8, {32})); +} + +TEST_F(InfeedManagerTest, OutfeedEmpty) { + TestInfeedBuffer* b = new TestInfeedBuffer(0, /*expect_shape_match=*/true); + cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0); + xfeed->outfeed()->EnqueueBuffersAtomically({b}); + + ProcessNextOutfeedBuffer(0, ShapeUtil::MakeShape(U8, {0})); +} + TEST_F(InfeedManagerTest, OutfeedWrongShape) { TestInfeedBuffer* b = new TestInfeedBuffer(32, /*expect_shape_match=*/false); cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0); diff --git a/tensorflow/compiler/xla/service/dot_as_convolution_util.cc b/tensorflow/compiler/xla/service/dot_as_convolution_util.cc index b95636c7039..3adde5f7d48 100644 --- a/tensorflow/compiler/xla/service/dot_as_convolution_util.cc +++ b/tensorflow/compiler/xla/service/dot_as_convolution_util.cc @@ -49,14 +49,11 @@ bool ConvSpatialDimensionIsParallel(const WindowDimension& wd, int64 lhs_size) { return false; } -/* static */ absl::optional -ParseDotGeneralFromConvolution(const HloInstruction* conv) { +/* static */ DotConvolutionDimsInfo ParseConvolutionDimsInfo( + const HloInstruction* conv) { CHECK_EQ(conv->opcode(), HloOpcode::kConvolution); - if (conv->feature_group_count() != 1 || conv->batch_group_count() != 1) { - return absl::nullopt; - } const auto& conv_dims = conv->convolution_dimension_numbers(); - DotGeneralAsConvolutionDimsInfo dims; + DotConvolutionDimsInfo dims; dims.lhs_non_contracting_dims.push_back( {conv_dims.input_batch_dimension(), -1, conv_dims.output_batch_dimension(), -1}); @@ -98,10 +95,10 @@ ParseDotGeneralFromConvolution(const HloInstruction* conv) { // padding N - 1, high padding N - 1 and window reversal. 
dims.rhs_non_contracting_dims.push_back({lhs, rhs, output, i}); } else { - return absl::nullopt; + dims.conv_spatial_dims.push_back({lhs, rhs, output, i}); } } else { - return absl::nullopt; + dims.conv_spatial_dims.push_back({lhs, rhs, output, i}); } } @@ -110,8 +107,7 @@ ParseDotGeneralFromConvolution(const HloInstruction* conv) { StatusOr> CreateShardedConvForDotGeneralConvolution( - const HloInstruction& conv, - const DotGeneralAsConvolutionDimsInfo& dot_dnums, + const HloInstruction& conv, const DotConvolutionDimsInfo& dot_dnums, HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo) { CHECK_EQ(conv.opcode(), HloOpcode::kConvolution); const auto& conv_dnums = conv.convolution_dimension_numbers(); @@ -141,22 +137,23 @@ CreateShardedConvForDotGeneralConvolution( wd->set_padding_high(wd->size() - 1); wd->set_padding_low(wd->size() - 1); } - TF_ASSIGN_OR_RETURN(Shape sharded_conv_shape, - ShapeInference::InferConvolveShape( - sharded_lhs_hlo->shape(), sharded_rhs_hlo->shape(), - /*feature_group_count=*/1, - /*batch_group_count=*/1, window, conv_dnums)); + TF_ASSIGN_OR_RETURN( + Shape sharded_conv_shape, + ShapeInference::InferConvolveShape( + sharded_lhs_hlo->shape(), sharded_rhs_hlo->shape(), + /*feature_group_count=*/conv.feature_group_count(), + /*batch_group_count=*/conv.batch_group_count(), window, conv_dnums)); *sharded_conv_shape.mutable_layout() = conv.shape().layout(); return HloInstruction::CreateConvolve( sharded_conv_shape, sharded_lhs_hlo, sharded_rhs_hlo, - /*feature_group_count=*/1, - /*batch_group_count=*/1, window, conv_dnums, conv.precision_config()); + /*feature_group_count=*/conv.feature_group_count(), + /*batch_group_count=*/conv.batch_group_count(), window, conv_dnums, + conv.precision_config()); } -DotGeneralAsConvolutionDimsInfo ParseDotGeneralFromDot( - const HloInstruction* dot) { +DotConvolutionDimsInfo ParseDotGeneralFromDot(const HloInstruction* dot) { const auto& dot_dim_numbs = dot->dot_dimension_numbers(); - dot_as_convolution_util::DotGeneralAsConvolutionDimsInfo dnums; + dot_as_convolution_util::DotConvolutionDimsInfo dnums; for (int64 i = 0; i < dot_dim_numbs.lhs_batch_dimensions().size(); ++i) { dnums.batch_dims.emplace_back(); dnums.batch_dims.back().lhs = dot_dim_numbs.lhs_batch_dimensions(i); diff --git a/tensorflow/compiler/xla/service/dot_as_convolution_util.h b/tensorflow/compiler/xla/service/dot_as_convolution_util.h index 81914b193a3..16a542208d2 100644 --- a/tensorflow/compiler/xla/service/dot_as_convolution_util.h +++ b/tensorflow/compiler/xla/service/dot_as_convolution_util.h @@ -25,8 +25,9 @@ limitations under the License. namespace xla { namespace dot_as_convolution_util { -// Describes the dimensions of a convolution that can be interpreted as a dot. -struct DotGeneralAsConvolutionDimsInfo { +// Describes the dimensions of a convolution that can be interpreted as a dot +// or a normal convolution. +struct DotConvolutionDimsInfo { // The dimension numbers for the operands and output corresponding to a // logical dimension (e.g., batch, contracting, non-contracting). If an // operand or the output doesn't have the logical dimension, it is set to @@ -43,23 +44,22 @@ struct DotGeneralAsConvolutionDimsInfo { std::vector contracting_dims; std::vector lhs_non_contracting_dims; std::vector rhs_non_contracting_dims; + std::vector conv_spatial_dims; }; -// Parses a convolution and returns a DotGeneralAsConvolutionDimsInfo if it can -// be interpreted as a dot, or absl::nullopt otherwise. 
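As a usage sketch of the reworked API: ParseConvolutionDimsInfo now always returns a DotConvolutionDimsInfo, and callers that previously branched on absl::nullopt instead check whether any spatial dimensions were left unclassified. This is illustrative only; `conv` is assumed to be an HloInstruction* with opcode kConvolution.

  // Sketch: classify a convolution as dot-like or genuinely spatial.
  const dot_as_convolution_util::DotConvolutionDimsInfo dims =
      dot_as_convolution_util::ParseConvolutionDimsInfo(conv);
  if (dims.conv_spatial_dims.empty()) {
    // Every dimension mapped to a batch/contracting/non-contracting role,
    // so the convolution can be handled like a dot_general.
  } else {
    // Some dimensions remain spatial; treat it as a real convolution.
  }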
-absl::optional ParseDotGeneralFromConvolution( - const HloInstruction* conv); +// Parses a convolution and returns a DotGeneralAsConvolutionDimsInfo. If it can +// be interpreted as a dot, there is no conv_spatial_dims. +DotConvolutionDimsInfo ParseConvolutionDimsInfo(const HloInstruction* conv); // Creates sharded convolution instruction that can be interpreted as a dot. // This is a utility for per-op partitioners. // - 'conv' is the original convolution instruction. -// - 'dot_dnums' is the result of ParseDotGeneralFromConvolution() for 'conv'. +// - 'dot_dnums' is the result of ParseDotConvolutionDimsInfo() for 'conv'. // - 'sharded_lhs_hlo' and 'sharded_rhs_hlo' are sharded inputs for the result // convolution instruction. StatusOr> CreateShardedConvForDotGeneralConvolution( - const HloInstruction& conv, - const DotGeneralAsConvolutionDimsInfo& dot_dnums, + const HloInstruction& conv, const DotConvolutionDimsInfo& dot_dnums, HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo); // Check if a spatial dim is parallel batch dimension. @@ -68,10 +68,9 @@ CreateShardedConvForDotGeneralConvolution( // dilation B. bool ConvSpatialDimensionIsParallel(const WindowDimension& wd, int64 lhs_size); -// Returns a DotGeneralAsConvolutionDimsInfo from a kDot instruction, where all +// Returns a DotConvolutionDimsInfo from a kDot instruction, where all // the spatial_dim values are set to -1. -DotGeneralAsConvolutionDimsInfo ParseDotGeneralFromDot( - const HloInstruction* dot); +DotConvolutionDimsInfo ParseDotGeneralFromDot(const HloInstruction* dot); } // namespace dot_as_convolution_util } // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_padder.cc b/tensorflow/compiler/xla/service/dynamic_padder.cc index 9b4d24bbbe9..e728cd75caf 100644 --- a/tensorflow/compiler/xla/service/dynamic_padder.cc +++ b/tensorflow/compiler/xla/service/dynamic_padder.cc @@ -39,12 +39,17 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/monitoring/gauge.h" #include "tensorflow/core/platform/errors.h" namespace xla { namespace { +auto* dynamic_padding_gauge = tensorflow::monitoring::Gauge::New( + "/tensorflow/core/use_dynamic_padding_gauge", + "Tracks if dynamic padder is used."); + // ChooseIdentityValue looks at the instruction's operand, returns a // identity value which, when padded, doesn't change the result of the // instruction. 
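The template arguments of the gauge defined above are likewise stripped in this rendering. A minimal sketch of the tensorflow::monitoring pattern being used, assuming a bool-valued, label-free gauge (which is what the Set(true) calls below imply):

  #include "tensorflow/core/lib/monitoring/gauge.h"

  // Bool gauge with zero labels; GetCell() returns its single cell.
  auto* dynamic_padding_gauge = tensorflow::monitoring::Gauge<bool, 0>::New(
      "/tensorflow/core/use_dynamic_padding_gauge",
      "Tracks if dynamic padder is used.");

  void MarkDynamicPaddingUsed() {
    dynamic_padding_gauge->GetCell()->Set(true);
  }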
@@ -179,6 +184,22 @@ StatusOr ReplaceSetSize(HloInstruction* instr) { return true; } +StatusOr ReplaceSetBound(HloInstruction* instr) { + if (instr->opcode() != HloOpcode::kCustomCall || + instr->custom_call_target() != "SetBound") { + return false; + } + + TF_RET_CHECK(Shape::Equal().IgnoreDynamicDimension()( + instr->shape(), instr->operand(0)->shape())) + << "instr->shape() " << instr->shape().ToString() << " , " + << "instruction operand shape " << instr->operand(0)->shape(); + HloInstruction* operand = instr->mutable_operand(0); + + TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(operand)); + return true; +} + bool ShouldSkipPadOnOperand(const HloInstruction* inst, int64 operand_num, int64 dimension) { if ((inst->opcode() == HloOpcode::kReduceWindow || @@ -1335,6 +1356,7 @@ StatusOr DynamicPadder::Run(HloModule* module) { operand, input_dim, operand_dynamic_size, identity_value); TF_RETURN_IF_ERROR(inst->ReplaceOperandWith(operand_num, padded)); operand = inst->mutable_operand(operand_num); + dynamic_padding_gauge->GetCell()->Set(true); changed = true; } } @@ -1370,7 +1392,10 @@ StatusOr DynamicPadder::Run(HloModule* module) { for (auto* computation : module->computations()) { for (auto instruction : computation->MakeInstructionPostOrder()) { TF_ASSIGN_OR_RETURN(bool replaced_set_size, ReplaceSetSize(instruction)); + TF_ASSIGN_OR_RETURN(bool replaced_set_bound, + ReplaceSetBound(instruction)); changed = changed || replaced_set_size; + changed = changed || replaced_set_bound; } } @@ -1378,6 +1403,7 @@ StatusOr DynamicPadder::Run(HloModule* module) { TF_ASSIGN_OR_RETURN(changed, dce.Run(module)); VLOG(2) << "Post DynamicPadder HLO:"; XLA_VLOG_LINES(2, module->ToString()); + dynamic_padding_gauge->GetCell()->Set(changed); return changed; } diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index d5cf2ee9ac0..e9a3c6b3018 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -59,11 +59,11 @@ void ExecutionInput::SetUnownedBuffer(const ShapeIndex& index, unowned_indices_.insert(index); } -xla::StatusOr ExecutionInput::ToShapedBuffer( +StatusOr ExecutionInput::ToShapedBuffer( se::DeviceMemoryAllocator* allocator, int device_ordinal) const { const Shape& input_shape = shape(); - xla::ShapedBuffer shaped_buffer(input_shape, input_shape, - allocator->platform(), device_ordinal); + ShapedBuffer shaped_buffer(input_shape, allocator->platform(), + device_ordinal); for (const auto& index_buffer : Buffers()) { const tensorflow::se::OwningDeviceMemory* mem = index_buffer.second.AsOwningDeviceMemory(); @@ -93,8 +93,7 @@ StatusOr Executable::ExecuteOnStream( static ExecutionInput MakeMaybeOwningDeviceMemoryTree( const ShapedBuffer& shaped_buffer) { - ExecutionInput result(shaped_buffer.on_device_shape(), - shaped_buffer.on_host_shape()); + ExecutionInput result(shaped_buffer.on_device_shape()); shaped_buffer.buffers().ForEachElement( [&](const ShapeIndex& index, const se::DeviceMemoryBase& mem) { result.SetBuffer(index, MaybeOwningDeviceMemory(mem)); diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 2e3ddedfb8c..1e1b3436a3c 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -60,15 +60,24 @@ namespace xla { // with their indices absent from unowned_indices_. 
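The executable.cc changes above rely on the new single-shape constructors: the host shape is no longer passed in but derived internally via ShapeUtil::DeviceShapeToHostShape. A hedged usage sketch (variable names are illustrative, not from the patch):

  // Sketch: ExecutionInput built from a device shape only; the host shape is
  // computed inside the constructor.
  ExecutionInput input(device_shape);
  input.SetBuffer(/*index=*/{}, MaybeOwningDeviceMemory(device_memory));

  // Sketch: ExecutionOutput likewise takes just the on-device shape, as in the
  // cpu_executable.cc change earlier in this patch.
  ExecutionOutput output(/*on_device_shape=*/result_shape, allocator,
                         device_ordinal);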
class ExecutionInput { public: - explicit ExecutionInput(xla::Shape shape, xla::Shape host_shape) + explicit ExecutionInput(xla::Shape shape) : buffers_(std::move(shape)) { + SetHostShape(ShapeUtil::DeviceShapeToHostShape(buffers_.shape())); + } + // TODO(b/170310047): remove this overload. + ExecutionInput(xla::Shape shape, xla::Shape host_shape) : buffers_(std::move(shape)) { - SetHostShape(std::move(host_shape)); + SetHostShape(ShapeUtil::DeviceShapeToHostShape(buffers_.shape())); } - explicit ExecutionInput(ShapeTree buffers, - xla::Shape host_shape) + explicit ExecutionInput(ShapeTree buffers) : buffers_(std::move(buffers)) { - SetHostShape(std::move(host_shape)); + SetHostShape(ShapeUtil::DeviceShapeToHostShape(buffers_.shape())); + } + // TODO(b/170310047): remove this overload. + ExecutionInput(ShapeTree buffers, + xla::Shape host_shape) + : buffers_(std::move(buffers)) { + SetHostShape(ShapeUtil::DeviceShapeToHostShape(buffers_.shape())); } ExecutionInput(ExecutionInput&&) = default; @@ -144,10 +153,13 @@ class ExecutionOutput { std::vector to_be_released) : result_(std::move(result)), to_be_released_(std::move(to_be_released)) {} + // TODO(b/170310047): remove this overload. ExecutionOutput(Shape on_host_shape, Shape on_device_shape, se::DeviceMemoryAllocator* allocator, int device_ordinal) - : result_(std::move(on_host_shape), std::move(on_device_shape), allocator, - device_ordinal) {} + : result_(std::move(on_device_shape), allocator, device_ordinal) {} + ExecutionOutput(Shape on_device_shape, se::DeviceMemoryAllocator* allocator, + int device_ordinal) + : result_(std::move(on_device_shape), allocator, device_ordinal) {} ExecutionOutput(ExecutionOutput&&) = default; ExecutionOutput& operator=(ExecutionOutput&&) = default; diff --git a/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.cc b/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.cc index 75d39298aa3..17d3fb2b3d6 100644 --- a/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.cc +++ b/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.cc @@ -25,35 +25,36 @@ limitations under the License. namespace xla { FusionNodeIndexingEvaluation::FusionNodeIndexingEvaluation( - const HloInstruction* fusion) + const HloInstruction* fusion, int64 root_usage_count) : fusion_(fusion) { - total_emitted_instructions_ = 0; HloInstruction* root = fusion->fused_expression_root(); indexing_users_[root].insert(fusion); - index_usage_count_[fusion] = 1; + index_usage_count_[fusion] = root_usage_count; RecomputeCache(); } -bool FusionNodeIndexingEvaluation::AverageCodeDuplicationTooHigh( - const HloInstruction* producer) const { - // This constant is arbitrarily chosen. Essentially we don't want to have too - // much code duplication, because it slows down the compilation time. There is - // a tradeoff between compilation time and runtime here. - const int64 kAllowedCodeDuplication = 15; +// This constant is arbitrarily chosen. Essentially we don't want to have too +// much code duplication, because it slows down the compilation time. There is +// a tradeoff between compilation time and runtime here. +const int64 FusionNodeIndexingEvaluation::kAllowedCodeDuplication = 15; - // index_usage_count_ contains an entry for each instruction in the fusion - // computation (except parameter instructions), plus an entry for the 'fusion' - // instruction. 
So the size of this map is already one bigger than the number - // of instructions in the fusion node that are emitted, thus accounting for - // the number of instructions after 'producer' is fused. - return EvaluateTotalEmittedInstructions(producer) / - index_usage_count_.size() > - kAllowedCodeDuplication; +bool FusionNodeIndexingEvaluation::CodeDuplicationTooHigh( + const HloInstruction* producer) const { + return EvaluateEmittedInstructions(producer) > kAllowedCodeDuplication; } -int64 FusionNodeIndexingEvaluation::EvaluateTotalEmittedInstructions( +bool FusionNodeIndexingEvaluation::MaxCodeDuplicationTooHigh() const { + for (const auto& entry : index_usage_count_) { + if (entry.second > kAllowedCodeDuplication) { + return true; + } + } + return false; +} + +int64 FusionNodeIndexingEvaluation::EvaluateEmittedInstructions( const HloInstruction* producer) const { - int64 total = total_emitted_instructions_; + int64 total = 0; for (const auto* user : indexing_users_.at(producer)) { total += index_usage_count_.at(user); } @@ -96,19 +97,9 @@ void FusionNodeIndexingEvaluation::UpdateIndexUsageCount( const HloInstruction* instruction) { int64 total = 0; for (const auto* user : indexing_users_[instruction]) { - int64 weight = 1; - // Concatenate is special: the index differs for each operand, so - // in the worst case we have to deal with as many index values as - // the number of operands of Concatenate. By considering the worst - // case, we are more conservative than necessary regarding - // counting the index usage. - if (user->opcode() == HloOpcode::kConcatenate) { - weight = user->operand_count(); - } - total += index_usage_count_.at(user) * weight; + total += index_usage_count_.at(user); } CHECK(index_usage_count_.emplace(instruction, total).second); - total_emitted_instructions_ += total; } void FusionNodeIndexingEvaluation::UpdateIndexingUsersOfOperands( diff --git a/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h b/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h index 9630986d188..abe154a5149 100644 --- a/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h +++ b/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h @@ -24,19 +24,22 @@ limitations under the License. namespace xla { class FusionNodeIndexingEvaluation { public: - explicit FusionNodeIndexingEvaluation(const HloInstruction* fusion); + explicit FusionNodeIndexingEvaluation(const HloInstruction* fusion, + int64 root_usage_count = 1); - // Evaluate the average number of times an instruction is emitted inside the - // fusion node, if 'producer' is fused into 'fusion_'. If this average - // duplication is "too high" (some arbitrary chosen constant), returns - // true. - bool AverageCodeDuplicationTooHigh(const HloInstruction* producer) const; + // Evaluate the number of times 'producer' would be emitted if it is fused + // into 'fusion_'. If the duplication is "too high" (some arbitrary chosen + // constant), returns true. + bool CodeDuplicationTooHigh(const HloInstruction* producer) const; - // Evaluate the total number of times an instruction is emitted inside the - // fusion node, if 'producer' is fused into 'fusion_'. An instruction may be - // emitted several times, once for each different index value with which it is - // indexed. - int64 EvaluateTotalEmittedInstructions(const HloInstruction* producer) const; + // Evaluate the maximum code duplication inside the fusion node. 
If the + // maximum code duplication is "too high" (some arbitrary chosen constant), + // returns true. + bool MaxCodeDuplicationTooHigh() const; + + // Evaluate the number of times 'producer' would be emitted if it is fused + // into 'fusion_'. + int64 EvaluateEmittedInstructions(const HloInstruction* producer) const; // Update the evaluation cache after having fused 'producer' into 'fusion_'. // 'producer' is the cloned instruction which is now part of the fusion @@ -56,6 +59,8 @@ class FusionNodeIndexingEvaluation { HloInstruction* fusion_operand); private: + static const int64 kAllowedCodeDuplication; + // Computes the 'indexing_users_' and 'index_usage_count_' maps based on the // current instructions inside the fusion node. Also updates // 'total_emitted_instructions_' accordingly. @@ -84,9 +89,6 @@ class FusionNodeIndexingEvaluation { // The fusion instruction. const HloInstruction* fusion_; - - // The total number of emitted instructions. - int64 total_emitted_instructions_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation_test.cc b/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation_test.cc index b20f52d2d62..b00abdc9abf 100644 --- a/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation_test.cc +++ b/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation_test.cc @@ -29,7 +29,7 @@ using FusionNodeIndexingEvaluationTest = HloTestBase; // Subclass of InstructionFusion exposing the protected methods Fuse and // FuseInstruction for testing. Also adds the FusionNodeIndexingEvaluation to -// track the average code duplication due to indexing HloInstructions with +// track the code duplication due to indexing HloInstructions with // different index values. class InstructionFusionForTesting : public InstructionFusion { public: @@ -61,8 +61,8 @@ class InstructionFusionForTesting : public InstructionFusion { return InstructionFusion::Fuse(producer, consumer); } - int64 EvaluateTotalEmittedInstructions(const HloInstruction* producer, - const HloInstruction* consumer) { + int64 EvaluateEmittedInstructions(const HloInstruction* producer, + const HloInstruction* consumer) { if (consumer->opcode() != HloOpcode::kFusion) { return 0; } @@ -71,8 +71,8 @@ class InstructionFusionForTesting : public InstructionFusion { fusion_node_evaluations_.emplace(consumer, FusionNodeIndexingEvaluation(consumer)); } - return fusion_node_evaluations_.at(consumer) - .EvaluateTotalEmittedInstructions(producer); + return fusion_node_evaluations_.at(consumer).EvaluateEmittedInstructions( + producer); } private: @@ -109,8 +109,7 @@ TEST_F(FusionNodeIndexingEvaluationTest, FuseThreeInstructions) { HloInstruction* slice1 = sub->mutable_operand(0); HloInstruction* slice2 = sub->mutable_operand(1); auto fusion = instruction_fusion.Fuse(slice1, sub); - EXPECT_EQ(instruction_fusion.EvaluateTotalEmittedInstructions(slice2, fusion), - 3); + EXPECT_EQ(instruction_fusion.EvaluateEmittedInstructions(slice2, fusion), 1); instruction_fusion.Fuse(slice2, fusion); } @@ -151,37 +150,31 @@ TEST_F(FusionNodeIndexingEvaluationTest, ExponentialDuplicationPattern) { HloInstruction* slice2_1 = add2->mutable_operand(1); auto fusion = instruction_fusion.Fuse(slice2_0, add2); // So far we have fused add2 and slice2.0. So when we also fuse slice2.1, we - // expect to emit 3 instructions. - EXPECT_EQ( - instruction_fusion.EvaluateTotalEmittedInstructions(slice2_1, fusion), 3); + // expect to emit it 1 time. 
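Editor's note on the FusionNodeIndexingEvaluation hunks above: the averaged-duplication heuristic is replaced by a per-producer count. EvaluateEmittedInstructions sums index_usage_count over the fused instructions that index the producer, and CodeDuplicationTooHigh compares that sum against the fixed kAllowedCodeDuplication cap. The following standalone sketch models only that bookkeeping; the string-keyed maps and the hard-coded graph are illustrative stand-ins, not the XLA HloInstruction types.

// Minimal model of the per-producer duplication count used above.
// "Indexing users" of a producer are the already-fused instructions that
// would index into it; each carries an index_usage_count.
#include <cstdint>
#include <iostream>
#include <map>
#include <set>
#include <string>

constexpr int64_t kAllowedCodeDuplication = 15;  // same arbitrary cap as above

// For each fused instruction, how many distinct index values it is emitted for.
std::map<std::string, int64_t> index_usage_count;
// For each not-yet-fused producer, which fused instructions index into it.
std::map<std::string, std::set<std::string>> indexing_users;

int64_t EvaluateEmittedInstructions(const std::string& producer) {
  int64_t total = 0;
  for (const std::string& user : indexing_users.at(producer)) {
    total += index_usage_count.at(user);  // one emission per index value of the user
  }
  return total;
}

bool CodeDuplicationTooHigh(const std::string& producer) {
  return EvaluateEmittedInstructions(producer) > kAllowedCodeDuplication;
}

int main() {
  // Fusion root is indexed once; two already-fused slices reuse that index.
  index_usage_count = {{"fusion", 1}, {"add2", 1}, {"slice2.0", 1}, {"slice2.1", 1}};
  // 'add1' feeds both slices, so fusing it duplicates it once per slice index.
  indexing_users["add1"] = {"slice2.0", "slice2.1"};
  std::cout << "add1 emitted " << EvaluateEmittedInstructions("add1")
            << " times; too high: " << CodeDuplicationTooHigh("add1") << "\n";
}

This reproduces the expectation in the updated test: add1 is emitted twice, which is below the cap.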
+ EXPECT_EQ(instruction_fusion.EvaluateEmittedInstructions(slice2_1, fusion), + 1); instruction_fusion.Fuse(slice2_1, fusion); HloInstruction* add1 = fusion->mutable_operand(0); EXPECT_EQ(add1->opcode(), HloOpcode::kAdd); - // If we fuse add1 into 'fusion', it needs to be emitted twice, adding 2 to - // the sum. - EXPECT_EQ(instruction_fusion.EvaluateTotalEmittedInstructions(add1, fusion), - 5); + // If we fuse add1 into 'fusion', it needs to be emitted twice. + EXPECT_EQ(instruction_fusion.EvaluateEmittedInstructions(add1, fusion), 2); instruction_fusion.Fuse(add1, fusion); HloInstruction* slice1_0 = fusion->mutable_operand(0); EXPECT_EQ(slice1_0->opcode(), HloOpcode::kSlice); - // If we fuse slice1.0 into 'fusion', it needs to be emitted twice, adding 2 - // to the sum. - EXPECT_EQ( - instruction_fusion.EvaluateTotalEmittedInstructions(slice1_0, fusion), 7); + // If we fuse slice1.0 into 'fusion', it needs to be emitted twice. + EXPECT_EQ(instruction_fusion.EvaluateEmittedInstructions(slice1_0, fusion), + 2); instruction_fusion.Fuse(slice1_0, fusion); HloInstruction* slice1_1 = fusion->mutable_operand(0); EXPECT_EQ(slice1_1->opcode(), HloOpcode::kSlice); - // If we fuse slice1.1 into 'fusion', it needs to be emitted twice, adding 2 - // to the sum. - EXPECT_EQ( - instruction_fusion.EvaluateTotalEmittedInstructions(slice1_1, fusion), 9); + // If we fuse slice1.1 into 'fusion', it needs to be emitted twice. + EXPECT_EQ(instruction_fusion.EvaluateEmittedInstructions(slice1_1, fusion), + 2); instruction_fusion.Fuse(slice1_1, fusion); HloInstruction* add0 = fusion->mutable_operand(0); EXPECT_EQ(add0->opcode(), HloOpcode::kAdd); - // If we fuse add0 into 'fusion', it needs to be emitted twice, adding 4 to - // the sum. - EXPECT_EQ(instruction_fusion.EvaluateTotalEmittedInstructions(add0, fusion), - 13); + // If we fuse add0 into 'fusion', it needs to be emitted four times. + EXPECT_EQ(instruction_fusion.EvaluateEmittedInstructions(add0, fusion), 4); instruction_fusion.Fuse(add0, fusion); } @@ -212,10 +205,9 @@ ENTRY entry_computation { HloInstruction* add0 = fusion->mutable_operand(0); EXPECT_EQ(add0->opcode(), HloOpcode::kAdd); // Here, the cache for the fusion node needs to be recomputed. Make sure we - // still get the same evaluation as before when we incrementally built the + // still get the same evaluation as before when we incrementally build the // cache. - EXPECT_EQ(instruction_fusion.EvaluateTotalEmittedInstructions(add0, fusion), - 13); + EXPECT_EQ(instruction_fusion.EvaluateEmittedInstructions(add0, fusion), 4); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index c09757fe1af..d2febb5fb73 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -69,13 +69,8 @@ void GenericTransferManager::TransferLiteralFromDevice( TF_RET_CHECK(stream->parent()->device_ordinal() == device_buffer.device_ordinal()); - // The on-host and on-device shape should always be the same for the generic - // transfer manager. 
- TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(), - device_buffer.on_host_shape())); - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( - device_buffer.on_host_shape(), + device_buffer.on_device_shape(), [&](const Shape& subshape, const ShapeIndex& index) -> Status { if (subshape.IsArray()) { stream->ThenMemcpy( @@ -103,20 +98,15 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync( << ShapeUtil::HumanString(shape) << "; device buffer: " << device_buffer; - // The on-host and on-device shape should always be the same for the generic - // transfer manager. - TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(), - device_buffer.on_host_shape())); - TF_RET_CHECK( - ShapeUtil::Compatible(literal.shape(), device_buffer.on_host_shape())); + ShapeUtil::Compatible(literal.shape(), device_buffer.on_device_shape())); TF_RET_CHECK(stream->parent()->device_ordinal() == device_buffer.device_ordinal()); TF_RETURN_IF_ERROR(WriteTupleIndexTablesAsync(stream, device_buffer)); return ShapeUtil::ForEachSubshapeWithStatus( - device_buffer.on_host_shape(), + device_buffer.on_device_shape(), [&](const Shape& device_subshape, const ShapeIndex& index) -> Status { se::DeviceMemoryBase device_memory = device_buffer.buffer(index); if (device_subshape.IsArray()) { diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index d1d0827981e..9463454ae0b 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1,9 +1,10 @@ # Description: # GPU-specific components in XLA service implementation. +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow/core/platform:build_config.bzl", - "tf_proto_library_cc", + "tf_proto_library", ) load( "//tensorflow/core/platform:build_config_root.bzl", @@ -26,6 +27,14 @@ load( "//tensorflow/core/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", ) + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "filegroup") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "get_compatible_with_cloud") + +# buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "if_nccl") load("//third_party/mlir:tblgen.bzl", "gentbl") @@ -50,7 +59,7 @@ filegroup( ]), ) -tf_proto_library_cc( +tf_proto_library( name = "backend_configs", srcs = ["backend_configs.proto"], cc_api_version = 2, @@ -207,7 +216,9 @@ cc_library( deps = [ ":backend_configs_cc", ":buffer_allocations", + ":cudnn_batchnorm_runner", ":gpu_constants", + ":gpu_conv_runner", ":gpu_executable", ":ir_emission_utils", ":nccl_all_reduce_thunk", @@ -254,6 +265,8 @@ cc_library( ":target_util", ":thunk", ":thunk_emitter", + "//tensorflow/compiler/mlir:name_utils", + "//tensorflow/compiler/mlir/hlo", "//tensorflow/compiler/mlir/hlo:lhlo", "//tensorflow/compiler/mlir/xla:hlo_utils", "//tensorflow/compiler/mlir/xla:mhlo_to_lhlo_with_xla", @@ -264,6 +277,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:union_find", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto_cc", @@ -311,6 +325,7 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", 
"//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", @@ -362,7 +377,7 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor:device_memory_allocator", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", @@ -381,7 +396,7 @@ cc_library( "//tensorflow/compiler/xla/service:stream_pool", "//tensorflow/core:lib", "//tensorflow/core:ptr_util", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "@com_google_absl//absl/memory", ], ) @@ -397,7 +412,7 @@ cc_library( "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", ], ) @@ -447,7 +462,8 @@ tf_cuda_library( "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", + "//tensorflow/compiler/xla:xla_data_proto_cc", ] + if_cuda([ "//tensorflow/stream_executor/cuda:cuda_activation", "//tensorflow/stream_executor/cuda:cuda_gpu_executor", @@ -586,7 +602,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/core/profiler/lib:scoped_annotation", "//tensorflow/stream_executor", @@ -632,7 +648,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", "@llvm-project//llvm:Core", @@ -676,7 +692,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core/protobuf:autotuning_proto_cc", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/core/util/proto:proto_utils", "//tensorflow/stream_executor:blas", "//tensorflow/stream_executor:device_memory", @@ -713,7 +729,7 @@ cc_library( "//tensorflow/core/protobuf:autotuning_proto_cc", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/core/util/proto:proto_utils", "//tensorflow/stream_executor:device_memory_allocator", ] + if_cuda_is_configured([ @@ -738,7 +754,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", + "//tensorflow/stream_executor:dnn", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], @@ -760,7 +777,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", 
"//tensorflow/compiler/xla/service:hlo", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], @@ -815,7 +832,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor:blas", ] + if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", @@ -838,7 +855,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor:blas", "//tensorflow/stream_executor:device_memory_allocator", ]), @@ -896,6 +913,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:hlo_reachability", + "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -1033,6 +1051,24 @@ cc_library( ], ) +tf_cc_test( + name = "gpu_conv_padding_legalization_test", + srcs = ["gpu_conv_padding_legalization_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + ":gpu_conv_padding_legalization", + ":ir_emission_utils", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/core:test", + ], +) + cc_library( name = "cudnn_pad_for_convolutions", srcs = ["cudnn_pad_for_convolutions.cc"], @@ -1122,7 +1158,7 @@ cc_library( "//tensorflow/compiler/xla/service:generic_transfer_manager", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "@com_google_absl//absl/memory", "@llvm-project//llvm:Core", ], @@ -1151,7 +1187,8 @@ cc_library( ":gpu_layout_assignment", ":gpu_sanitize_constant_names", ":gpu_scatter_expander", - ":horizontal_fusion", + ":horizontal_input_fusion", + ":horizontal_loop_fusion", ":instruction_fusion", ":ir_emission_utils", ":ir_emitter", @@ -1201,6 +1238,8 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:logistic_expander", + "//tensorflow/compiler/xla/service:qr_expander", + "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:rng_bit_generator_expander", "//tensorflow/compiler/xla/service:rng_expander", "//tensorflow/compiler/xla/service:slice_sinker", @@ -1217,8 +1256,8 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:regexp_internal", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:regexp", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/stream_executor:stream_executor_headers", "@com_google_absl//absl/memory", @@ -1281,7 +1320,7 @@ cc_library( 
"//tensorflow/compiler/xla/service:tuple_simplifier", "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", - "//tensorflow/core:cuda_libdevice_path", + "//tensorflow/core/platform:cuda_libdevice_path", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/profiler/lib:traceme", @@ -1367,7 +1406,7 @@ cc_library( ":xfeed_queue", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:types", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", ], @@ -1405,7 +1444,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:layout_assignment", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", ], ) @@ -1492,6 +1531,8 @@ cc_library( hdrs = ["stream_executor_util.h"], copts = tf_copts(), deps = [ + ":ir_emission_utils", + ":launch_dimensions", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -1499,11 +1540,11 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_module_config", - "//tensorflow/core:cuda_libdevice_path", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:regexp_internal", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:cuda_libdevice_path", + "//tensorflow/core/platform:regexp", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/stream_executor:kernel_spec", "//tensorflow/stream_executor/gpu:gpu_asm_opts", @@ -1526,7 +1567,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo_module_config", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor:stream_executor_headers", "//tensorflow/stream_executor/gpu:asm_compiler", ]), @@ -1585,7 +1626,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:pattern_matcher", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", ], ) @@ -1661,7 +1702,7 @@ tf_cc_test( ], ) -tf_proto_library_cc( +tf_proto_library( name = "gpu_autotuning_proto", srcs = ["gpu_autotuning.proto"], cc_api_version = 2, @@ -1679,7 +1720,7 @@ cc_library( deps = [ ":gpu_autotuning_proto_cc", "//tensorflow/compiler/xla:debug_options_flags", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/core/protobuf:autotuning_proto_cc", "@com_google_absl//absl/container:flat_hash_map", ], @@ -1726,10 +1767,11 @@ tf_cc_test( ) cc_library( - name = "horizontal_fusion", - srcs = ["horizontal_fusion.cc"], - hdrs = ["horizontal_fusion.h"], + name = "horizontal_loop_fusion", + srcs = ["horizontal_loop_fusion.cc"], + hdrs = ["horizontal_loop_fusion.h"], deps = [ + ":gpu_fusible", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_creation_utils", @@ -1742,11 +1784,11 @@ cc_library( ) tf_cc_test( - name = "horizontal_fusion_test", - srcs = 
["horizontal_fusion_test.cc"], + name = "horizontal_loop_fusion_test", + srcs = ["horizontal_loop_fusion_test.cc"], deps = [ ":fusion_merger", - ":horizontal_fusion", + ":horizontal_loop_fusion", ":instruction_fusion", ":multi_output_fusion", "//tensorflow/compiler/jit:xla_gpu_jit", @@ -1766,6 +1808,45 @@ tf_cc_test( ], ) +cc_library( + name = "horizontal_input_fusion", + srcs = ["horizontal_input_fusion.cc"], + hdrs = ["horizontal_input_fusion.h"], + deps = [ + ":gpu_fusible", + ":ir_emission_utils", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_creation_utils", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "horizontal_input_fusion_test", + srcs = ["horizontal_input_fusion_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + ":horizontal_input_fusion", + ":multi_output_fusion", + "//tensorflow/compiler/jit:xla_gpu_jit", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service:hlo_pass_pipeline", + "//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test", + "//tensorflow/compiler/xla/tests:filecheck", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + cc_library( name = "reduction_degenerate_dim_remover", srcs = ["reduction_degenerate_dim_remover.cc"], @@ -1891,6 +1972,7 @@ cc_library( gentbl( name = "xla_thunks_ops_inc_gen", + compatible_with = get_compatible_with_cloud(), tbl_outs = [ ("-gen-op-decls", "ir/xla_thunks_ops.h.inc"), ("-gen-op-defs", "ir/xla_thunks_ops.cc.inc"), @@ -1921,16 +2003,3 @@ cc_library( "@llvm-project//mlir:LLVMDialect", ], ) - -# Library with XLA thunks dialect static initialization. 
-cc_library( - name = "xla_thunks_dialect_registration", - srcs = [ - "ir/dialect_registration.cc", - ], - deps = [ - ":xla_thunks_ops", - "@llvm-project//mlir:IR", - ], - alwayslink = 1, -) diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc index 9b192aaa8e1..21b4ef40d97 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc @@ -610,13 +610,21 @@ static StatusOr DeviceCompare(se::Stream* stream, executor->GetDeviceDescription().threads_per_block_limit(); gpu_device_info.threads_per_warp = executor->GetDeviceDescription().threads_per_warp(); + gpu_device_info.shared_memory_per_block = + executor->GetDeviceDescription().shared_memory_per_block(); + gpu_device_info.threads_per_core_limit = + executor->GetDeviceDescription().threads_per_core_limit(); + gpu_device_info.core_count = executor->GetDeviceDescription().core_count(); LaunchDimensions dim = CalculateLaunchDimensions(buffer_shape, gpu_device_info); - stream->ThenLaunch(se::ThreadDim(dim.threads_per_block()), - se::BlockDim(dim.block_count()), *comparison_kernel, - lhs_typed, rhs_typed, static_cast(kTolerance), - buffer_size, out_param.cref()); + LaunchDimensions::Dim3D thread_counts = dim.thread_counts_per_block(); + LaunchDimensions::Dim3D block_counts = dim.block_counts(); + stream->ThenLaunch( + se::ThreadDim(thread_counts.x, thread_counts.y, thread_counts.z), + se::BlockDim(block_counts.x, block_counts.y, block_counts.z), + *comparison_kernel, lhs_typed, rhs_typed, static_cast(kTolerance), + buffer_size, out_param.cref()); uint64 result = -1; CHECK_EQ(out_param->size(), sizeof(result)); diff --git a/tensorflow/compiler/xla/service/gpu/cholesky_thunk.cc b/tensorflow/compiler/xla/service/gpu/cholesky_thunk.cc index c34c299fea8..4ac5784e51a 100644 --- a/tensorflow/compiler/xla/service/gpu/cholesky_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/cholesky_thunk.cc @@ -45,10 +45,7 @@ CholeskyThunk::CholeskyThunk(ThunkInfo thunk_info, info_buffer_(info_buffer), type_(type), batch_size_(batch_size), - a_batch_stride_( - n * n * - ShapeUtil::ByteSizeOfPrimitiveType( - thunk_info.hlo_instruction->operand(0)->shape().element_type())), + a_batch_stride_(n * n * ShapeUtil::ByteSizeOfPrimitiveType(type)), n_(n) {} Status CholeskyThunk::ExecuteOnStream(const ExecuteParams& params) { diff --git a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc index b3b5cf7e048..88982d3c034 100644 --- a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc @@ -24,6 +24,7 @@ limitations under the License. 
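Editor's note on the buffer_comparator hunk above: the launch now reads the three-dimensional thread_counts_per_block()/block_counts() accessors and forwards all three components to the kernel launch. Below is a standalone sketch of that API shape with stand-in Dim3D/LaunchDimensions types and a trivial element-count split; the real CalculateLaunchDimensions consults the GPU device description, which is not modeled here.

// Stand-in for a 3D launch configuration: the kernel launch takes the x/y/z
// components of both the per-block thread counts and the block counts.
#include <cstdint>
#include <iostream>

struct Dim3D { int64_t x, y, z; };

class LaunchDimensions {
 public:
  LaunchDimensions(Dim3D blocks, Dim3D threads)
      : block_counts_(blocks), thread_counts_per_block_(threads) {}
  Dim3D block_counts() const { return block_counts_; }
  Dim3D thread_counts_per_block() const { return thread_counts_per_block_; }

 private:
  Dim3D block_counts_;
  Dim3D thread_counts_per_block_;
};

// Toy version: put kThreadsPerBlock threads in x and enough x-blocks to cover
// all elements; y and z stay 1.
LaunchDimensions CalculateLaunchDimensions(int64_t num_elements) {
  constexpr int64_t kThreadsPerBlock = 256;
  int64_t blocks = (num_elements + kThreadsPerBlock - 1) / kThreadsPerBlock;
  return LaunchDimensions({blocks, 1, 1}, {kThreadsPerBlock, 1, 1});
}

int main() {
  LaunchDimensions dim = CalculateLaunchDimensions(1000);
  Dim3D t = dim.thread_counts_per_block();
  Dim3D b = dim.block_counts();
  // A launch call would receive (t.x, t.y, t.z) and (b.x, b.y, b.z).
  std::cout << "threads: " << t.x << "x" << t.y << "x" << t.z
            << ", blocks: " << b.x << "x" << b.y << "x" << b.z << "\n";
}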
#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/refcounting_hash_map.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/blocking_counter.h" @@ -217,16 +218,23 @@ RefcountingHashMap& GlobalRendezvousMap() { } // anonymous namespace +CollectivePermuteConfig GetCollectivePermuteConfig( + const HloInstruction* instr) { + CollectivePermuteConfig config; + auto* collective_permute = Cast(instr); + config.source_target_pairs = collective_permute->source_target_pairs(); + return config; +} + CollectivePermuteThunk::CollectivePermuteThunk( - ThunkInfo thunk_info, const BufferAllocation::Slice& src, - const BufferAllocation::Slice& dest) + ThunkInfo thunk_info, CollectivePermuteConfig&& config, + const BufferAllocation::Slice& src, const BufferAllocation::Slice& dest) : Thunk(kCollectivePermute, thunk_info), - hlo_instruction_(thunk_info.hlo_instruction), + config_(std::move(config)), src_(src), dest_(dest) {} Status CollectivePermuteThunk::ExecuteOnStream(const ExecuteParams& params) { - auto* instr = Cast(hlo_instruction_); auto op_profiler = params.profiler->MakeScopedInstructionProfiler(profile_index()); @@ -245,7 +253,7 @@ Status CollectivePermuteThunk::ExecuteOnStream(const ExecuteParams& params) { // Figure out which replicas our data is copied to. std::vector dest_replicas; - for (const auto& src_dest : instr->source_target_pairs()) { + for (const auto& src_dest : config_.source_target_pairs) { if (src_dest.first == replica_id) { dest_replicas.push_back(src_dest.second); } @@ -260,7 +268,7 @@ Status CollectivePermuteThunk::ExecuteOnStream(const ExecuteParams& params) { // If no replica writes into us (i.e. we aren't the target of any copies), our // contract is that we zero our output. - if (absl::c_none_of(instr->source_target_pairs(), + if (absl::c_none_of(config_.source_target_pairs, [&](std::pair src_dest) { return src_dest.second == replica_id; })) { diff --git a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h index 44cc6a1c64e..bef86eec9af 100644 --- a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h @@ -19,23 +19,30 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" namespace xla { namespace gpu { +struct CollectivePermuteConfig { + std::vector> source_target_pairs; +}; + +CollectivePermuteConfig GetCollectivePermuteConfig(const HloInstruction* instr); + // Thunk that implements the collective-permute HLO. 
class CollectivePermuteThunk : public Thunk { public: - CollectivePermuteThunk(ThunkInfo thunk_info, + CollectivePermuteThunk(ThunkInfo thunk_info, CollectivePermuteConfig&& config, const BufferAllocation::Slice& src, const BufferAllocation::Slice& dest); Status ExecuteOnStream(const ExecuteParams& params) override; private: - const HloInstruction* hlo_instruction_; - BufferAllocation::Slice src_; - BufferAllocation::Slice dest_; + const CollectivePermuteConfig config_; + const BufferAllocation::Slice src_; + const BufferAllocation::Slice dest_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc index 4cff48a89da..6560c1a819c 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc @@ -17,43 +17,51 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { namespace gpu { -ConditionalThunk::ConditionalThunk( - ThunkInfo thunk_info, - const BufferAllocation::Slice& branch_index_buffer_index, - absl::Span branch_operand_buffer_indexes, - std::vector branch_thunk_sequences) - : Thunk(Kind::kConditional, thunk_info), - hlo_instruction_(thunk_info.hlo_instruction), - branch_index_is_bool_( - thunk_info.hlo_instruction->operand(0)->shape().element_type() == - PRED), - branch_index_buffer_index_(branch_index_buffer_index), - branch_operand_buffer_indexes_(branch_operand_buffer_indexes.begin(), - branch_operand_buffer_indexes.end()) { - // Pass nullptr as the HloInstruction* to the branch_thunks_ +ConditionalThunkConfig GetConditionalThunkConfig( + const HloInstruction* instr, + std::vector&& branch_thunk_sequences, + std::vector>&& branch_profile_indices) { + ConditionalThunkConfig config; + config.branch_index_is_bool = + instr->operand(0)->shape().element_type() == PRED; + config.branch_count = instr->branch_count(); + // Pass nullptr as the HloInstruction* to the branch_thunks // constructors because these SequentialThunks are logically "part of" // this ConditionalThunk, and shouldn't be profiled separately from it. 
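Editor's note on GetConditionalThunkConfig above: everything ExecuteOnStream needs (whether the branch index is a PRED, the branch count, one owned SequentialThunk per branch, and the profile indices) is folded into a plain struct before the thunk is constructed. The sketch below mirrors only that factory and ownership layout with stand-in types; SequentialThunk here is a toy, not the real thunk interface, and the profile indices are plain integers.

// Stand-in types: a config built once from "compile-time" information so the
// runtime object never has to look back at the instruction that produced it.
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

struct SequentialThunk {          // stands in for gpu::SequentialThunk
  std::vector<int> thunk_ids;     // stands in for the wrapped ThunkSequence
};

struct ConditionalThunkConfig {
  bool branch_index_is_bool = false;
  int64_t branch_count = 0;
  std::vector<std::unique_ptr<SequentialThunk>> branch_thunks;
  std::vector<int64_t> branch_profile_indices;
};

ConditionalThunkConfig GetConditionalThunkConfig(
    bool predicate_is_bool, std::vector<std::vector<int>>&& branch_sequences,
    std::vector<int64_t>&& branch_profile_indices) {
  ConditionalThunkConfig config;
  config.branch_index_is_bool = predicate_is_bool;
  config.branch_count = static_cast<int64_t>(branch_sequences.size());
  config.branch_thunks.reserve(branch_sequences.size());
  for (auto& seq : branch_sequences) {
    // Each branch sequence is wrapped in one owned SequentialThunk.
    config.branch_thunks.push_back(
        std::make_unique<SequentialThunk>(SequentialThunk{std::move(seq)}));
  }
  config.branch_profile_indices = std::move(branch_profile_indices);
  return config;
}

int main() {
  auto config = GetConditionalThunkConfig(
      /*predicate_is_bool=*/true, {{1, 2}, {3}}, {7, 8});
  return config.branch_count == 2 ? 0 : 1;
}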
- branch_thunks_.reserve(branch_thunk_sequences.size()); + config.branch_thunks.reserve(branch_thunk_sequences.size()); for (auto& branch_thunk_sequence : branch_thunk_sequences) { - branch_thunks_.emplace_back( - new SequentialThunk(ThunkInfo(), std::move(branch_thunk_sequence))); + config.branch_thunks.emplace_back(new SequentialThunk( + Thunk::ThunkInfo(), std::move(branch_thunk_sequence))); } + config.branch_profile_indices = std::move(branch_profile_indices); + return config; } +ConditionalThunk::ConditionalThunk( + ThunkInfo thunk_info, ConditionalThunkConfig&& config, + const BufferAllocation::Slice& branch_index_buffer_index, + absl::Span branch_operand_buffer_indexes) + : Thunk(Kind::kConditional, thunk_info), + config_(std::move(config)), + branch_index_buffer_index_(branch_index_buffer_index), + branch_operand_buffer_indexes_(branch_operand_buffer_indexes.begin(), + branch_operand_buffer_indexes.end()) {} + Status ConditionalThunk::Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) { - if (branch_index_is_bool_) { - TF_RET_CHECK(branch_thunks_.size() == 2); + if (config_.branch_index_is_bool) { + TF_RET_CHECK(config_.branch_thunks.size() == 2); } else { - TF_RET_CHECK(!branch_thunks_.empty()); + TF_RET_CHECK(!config_.branch_thunks.empty()); } - for (auto& branch_thunk : branch_thunks_) { + for (auto& branch_thunk : config_.branch_thunks) { TF_RETURN_IF_ERROR(branch_thunk->Initialize(executable, executor)); } return Status::OK(); @@ -69,7 +77,7 @@ Status ConditionalThunk::ExecuteOnStream(const ExecuteParams& params) { bool pred = false; se::DeviceMemoryBase branch_index_address = params.buffer_allocations->GetDeviceAddress(branch_index_buffer_index_); - if (branch_index_is_bool_) { + if (config_.branch_index_is_bool) { stream.ThenMemcpy(&pred, branch_index_address, sizeof(bool)); } else { stream.ThenMemcpy(&branch_index, branch_index_address, sizeof(int32)); @@ -81,20 +89,20 @@ Status ConditionalThunk::ExecuteOnStream(const ExecuteParams& params) { "Failed to retrieve branch_index value on stream %p: %s.", &stream, block_status.error_message()); } - if (branch_index_is_bool_) { + if (config_.branch_index_is_bool) { branch_index = pred ? 0 : 1; } else { // Handle default scenario for branch_index not in [0, num_branches). - if (branch_index < 0 || branch_index >= hlo_instruction_->branch_count()) { - branch_index = hlo_instruction_->branch_count() - 1; + if (branch_index < 0 || branch_index >= config_.branch_count) { + branch_index = config_.branch_count - 1; } } // Execute the branch computation corresponding to the value of branch_index. profiler.StartHloComputation(); - TF_RETURN_IF_ERROR(branch_thunks_[branch_index]->ExecuteOnStream(params)); - profiler.FinishHloComputation( - hlo_instruction_->branch_computation(branch_index)); + TF_RETURN_IF_ERROR( + config_.branch_thunks[branch_index]->ExecuteOnStream(params)); + profiler.FinishHloComputation(config_.branch_profile_indices[branch_index]); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h index f91f1c52146..bf4280cdb12 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h @@ -30,6 +30,18 @@ limitations under the License. 
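Editor's note on the ConditionalThunk::ExecuteOnStream hunk above: the branch-selection semantics are unchanged but now read from the config. A boolean predicate maps to branch 0 (true) or 1 (false), and an out-of-range integer index falls back to the last branch as the default. A minimal sketch of just that normalization step, with the buffer and stream plumbing omitted:

// Normalize a conditional's branch index the way the thunk above does:
// PRED operands pick branch 0/1, out-of-range integer indices pick the
// default (last) branch.
#include <cstdint>
#include <iostream>

int32_t NormalizeBranchIndex(bool branch_index_is_bool, bool pred,
                             int32_t branch_index, int64_t branch_count) {
  if (branch_index_is_bool) {
    return pred ? 0 : 1;  // true executes the first branch
  }
  if (branch_index < 0 || branch_index >= branch_count) {
    return static_cast<int32_t>(branch_count - 1);  // default branch
  }
  return branch_index;
}

int main() {
  std::cout << NormalizeBranchIndex(true, true, 0, 2) << "\n";    // 0
  std::cout << NormalizeBranchIndex(false, false, 7, 3) << "\n";  // 2 (clamped)
  std::cout << NormalizeBranchIndex(false, false, 1, 3) << "\n";  // 1
}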
namespace xla { namespace gpu { +struct ConditionalThunkConfig { + bool branch_index_is_bool; + int64 branch_count; + std::vector> branch_thunks; + std::vector> branch_profile_indices; +}; + +ConditionalThunkConfig GetConditionalThunkConfig( + const HloInstruction* instr, + std::vector&& branch_thunk_sequences, + std::vector>&& branch_profile_indices); + // ConditionalThunk implements the conditional instruction on GPU by reading the // predicate of the conditional and executing the true or the false computation // depending on the value of the predicate. @@ -43,10 +55,9 @@ namespace gpu { class ConditionalThunk : public Thunk { public: ConditionalThunk( - ThunkInfo thunk_info, + ThunkInfo thunk_info, ConditionalThunkConfig&& config, const BufferAllocation::Slice& branch_index_buffer_index, - absl::Span branch_operand_buffer_indexes, - std::vector branch_thunk_sequences); + absl::Span branch_operand_buffer_indexes); ConditionalThunk(const ConditionalThunk&) = delete; ConditionalThunk& operator=(const ConditionalThunk&) = delete; @@ -56,11 +67,9 @@ class ConditionalThunk : public Thunk { Status ExecuteOnStream(const ExecuteParams& params) override; private: - const HloInstruction* hlo_instruction_; - const bool branch_index_is_bool_; + const ConditionalThunkConfig config_; BufferAllocation::Slice branch_index_buffer_index_; std::vector branch_operand_buffer_indexes_; - std::vector> branch_thunks_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 3048db95c39..efa3a5802d6 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -31,15 +31,16 @@ namespace xla { namespace gpu { ConvolutionThunk::ConvolutionThunk( - ThunkInfo thunk_info, std::vector operand_slices, + ThunkInfo thunk_info, GpuConvConfig&& config, + std::vector operand_slices, BufferAllocation::Slice result_slice, BufferAllocation::Slice scratch_slice, BufferAllocation::Slice tuple_result_slice) : Thunk(Kind::kConvolution, thunk_info), - cudnn_call_(Cast(thunk_info.hlo_instruction)), operand_buffers_(std::move(operand_slices)), result_buffer_(result_slice), scratch_buffer_(scratch_slice), - tuple_result_buffer_(tuple_result_slice) {} + tuple_result_buffer_(tuple_result_slice), + config_(std::move(config)) {} Status ConvolutionThunk::ExecuteOnStream(const ExecuteParams& params) { const auto& buffer_allocations = *params.buffer_allocations; @@ -57,7 +58,7 @@ Status ConvolutionThunk::ExecuteOnStream(const ExecuteParams& params) { auto op_profiler = params.profiler->MakeScopedInstructionProfiler(profile_index()); - TF_RETURN_IF_ERROR(RunGpuConv(cudnn_call_, absl::MakeSpan(operand_se_buffers), + TF_RETURN_IF_ERROR(RunGpuConv(config_, absl::MakeSpan(operand_se_buffers), result_buffer, scratch, params.stream)); // Write the output tuple. diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index 03fae88c6dc..7f8377ebe4c 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -43,7 +43,7 @@ class ConvolutionThunk : public Thunk { // write a tuple (result, scratch_memory) into `tuple_result_buffer`. // // operand_slices should be in the same order as cudnn_call->operands(). 
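Editor's note on the ConvolutionThunk hunk above: the thunk now takes a GpuConvConfig by rvalue reference and stores it by value, so RunGpuConv no longer needs the HloCustomCallInstruction at execution time. The snippet below illustrates only that ownership transfer; the Config type and its fields are hypothetical stand-ins, not the real GpuConvConfig.

// Stand-in for the "build a config at emission time, move it into the thunk"
// ownership pattern used by ConvolutionThunk above.
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

struct Config {                       // stands in for GpuConvConfig
  std::string conv_kind;
  std::vector<int64_t> window_strides;
};

class ConvThunk {
 public:
  // The caller relinquishes the config; the thunk owns an immutable copy.
  explicit ConvThunk(Config&& config) : config_(std::move(config)) {}
  const Config& config() const { return config_; }

 private:
  const Config config_;
};

int main() {
  Config config{"forward", {1, 1}};
  ConvThunk thunk(std::move(config));  // config is moved, not referenced later
  return thunk.config().conv_kind == "forward" ? 0 : 1;
}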
- ConvolutionThunk(ThunkInfo thunk_info, + ConvolutionThunk(ThunkInfo thunk_info, GpuConvConfig&& config, std::vector operand_slices, BufferAllocation::Slice result_slice, BufferAllocation::Slice scratch_slice, @@ -55,11 +55,13 @@ class ConvolutionThunk : public Thunk { Status ExecuteOnStream(const ExecuteParams& params) override; private: - const HloCustomCallInstruction* cudnn_call_; std::vector operand_buffers_; BufferAllocation::Slice result_buffer_; BufferAllocation::Slice scratch_buffer_; BufferAllocation::Slice tuple_result_buffer_; + + // Convolution config + const GpuConvConfig config_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.cc index adf6b68096d..6b01151b48a 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { namespace gpu { @@ -109,26 +110,23 @@ DnnBatchDescriptors MakeBatchNormDescriptors(const Shape& shape, return batch_descs; } -void AssignCommonParams(const HloInstruction* batchnorm, +void AssignCommonParams(const CudnnBatchNormConfig& config, CudnnBatchNormParamsCommon* params, const se::DeviceMemoryBase& operand, - const se::DeviceMemory& scale, float epsilon, - int64 feature_index) { + const se::DeviceMemory& scale) { // The BatchNormTraining HLO outputs a tuple of three elements: output data, // batch mean, and batch variance. We want to make our descriptors based on // the shape of the output data. Batchnorm backward call outputs a tuple of // three elements: grad data, grad offset, and grad scale. We want to make // our descriptors based on the shape of the grad data. - const Shape& shape = batchnorm->shape().IsTuple() - ? batchnorm->shape().tuple_shapes(0) - : batchnorm->shape(); + const Shape& shape = config.output_shape; DnnBatchDescriptors batch_descs = - MakeBatchNormDescriptors(shape, feature_index); + MakeBatchNormDescriptors(shape, config.feature_index); params->operand_desc = batch_descs.input_desc; params->scale_offset_desc = batch_descs.scale_offset_desc; params->operand = operand; params->scale = scale; - params->epsilon = epsilon; + params->epsilon = config.epsilon; } template @@ -211,22 +209,33 @@ void RunCudnnBatchNormBackwardImpl(CudnnBatchNormBackwardParams* params, } // namespace +CudnnBatchNormConfig GetCudnnBatchNormConfig(const HloInstruction* instr, + float epsilon, + int64 feature_index) { + CudnnBatchNormConfig config; + + config.output_shape = instr->shape().IsTuple() + ? 
instr->shape().tuple_shapes(0) + : instr->shape(); + config.output_type = config.output_shape.element_type(); + config.epsilon = epsilon; + config.feature_index = feature_index; + return config; +} + Status RunCudnnBatchNormForwardInference( - const HloInstruction* batchnorm, se::DeviceMemoryBase operand, + const CudnnBatchNormConfig& config, se::DeviceMemoryBase operand, se::DeviceMemoryBase output, se::DeviceMemory scale, se::DeviceMemory offset, se::DeviceMemory mean, - se::DeviceMemory variance, float epsilon, int64 feature_index, - se::Stream* stream) { + se::DeviceMemory variance, se::Stream* stream) { CudnnBatchNormForwardInferenceParams inference_params; - AssignCommonParams(batchnorm, &inference_params.common, operand, scale, - epsilon, feature_index); + AssignCommonParams(config, &inference_params.common, operand, scale); inference_params.offset = offset; inference_params.mean = mean; inference_params.variance = variance; inference_params.output = output; - PrimitiveType output_primitive_type = batchnorm->shape().element_type(); - switch (output_primitive_type) { + switch (config.output_type) { case F16: RunCudnnBatchNormForwardInferenceImpl(&inference_params, stream); @@ -235,29 +244,27 @@ Status RunCudnnBatchNormForwardInference( RunCudnnBatchNormForwardInferenceImpl(&inference_params, stream); break; default: - return Unimplemented("Primitive type not implemented for \"%s\" ", - batchnorm->ToString()); + return Unimplemented( + "Primitive type %s not implemented for batchnorm forward inference", + primitive_util::LowercasePrimitiveTypeName(config.output_type) + .c_str()); } return Status::OK(); } Status RunCudnnBatchNormForwardTraining( - const HloInstruction* batchnorm, se::DeviceMemoryBase operand, + const CudnnBatchNormConfig& config, se::DeviceMemoryBase operand, se::DeviceMemoryBase output_data, se::DeviceMemory output_mean, se::DeviceMemory output_inv_stddev, se::DeviceMemory scale, - se::DeviceMemory offset, float epsilon, int64 feature_index, - se::Stream* stream) { + se::DeviceMemory offset, se::Stream* stream) { CudnnBatchNormForwardTrainingParams forward_params; - AssignCommonParams(batchnorm, &forward_params.common, operand, scale, epsilon, - feature_index); + AssignCommonParams(config, &forward_params.common, operand, scale); forward_params.offset = offset; forward_params.output_data = output_data; forward_params.output_mean = output_mean; forward_params.output_inv_stddev = output_inv_stddev; - PrimitiveType output_primitive_type = - batchnorm->shape().tuple_shapes(0).element_type(); - switch (output_primitive_type) { + switch (config.output_type) { case F16: RunCudnnBatchNormForwardTrainingImpl(&forward_params, stream); @@ -266,22 +273,23 @@ Status RunCudnnBatchNormForwardTraining( RunCudnnBatchNormForwardTrainingImpl(&forward_params, stream); break; default: - return Unimplemented("Primitive type not implemented for \"%s\" ", - batchnorm->ToString()); + return Unimplemented( + "Primitive type %s not implemented for batchnorm forward training", + primitive_util::LowercasePrimitiveTypeName(config.output_type) + .c_str()); } return Status::OK(); } Status RunCudnnBatchNormBackward( - const HloInstruction* batchnorm, se::DeviceMemoryBase operand, + const CudnnBatchNormConfig& config, se::DeviceMemoryBase operand, se::DeviceMemoryBase output_grad_data, se::DeviceMemoryBase grad_output, se::DeviceMemory output_grad_scale, se::DeviceMemory output_grad_offset, se::DeviceMemory scale, se::DeviceMemory mean, se::DeviceMemory inv_stddev, - float epsilon, int64 
feature_index, se::Stream* stream) { + se::Stream* stream) { CudnnBatchNormBackwardParams backward_params; - AssignCommonParams(batchnorm, &backward_params.common, operand, scale, - epsilon, feature_index); + AssignCommonParams(config, &backward_params.common, operand, scale); backward_params.output_grad_data = output_grad_data; backward_params.grad_output = grad_output; backward_params.output_grad_scale = output_grad_scale; @@ -289,9 +297,7 @@ Status RunCudnnBatchNormBackward( backward_params.mean = mean; backward_params.inv_stddev = inv_stddev; - PrimitiveType output_primitive_type = - batchnorm->shape().tuple_shapes(0).element_type(); - switch (output_primitive_type) { + switch (config.output_type) { case F16: RunCudnnBatchNormBackwardImpl(&backward_params, stream); break; @@ -299,8 +305,10 @@ Status RunCudnnBatchNormBackward( RunCudnnBatchNormBackwardImpl(&backward_params, stream); break; default: - return Unimplemented("Primitive type not implemented for \"%s\" ", - batchnorm->ToString()); + return Unimplemented( + "Primitive type %s not implemented for batchnorm backward", + primitive_util::LowercasePrimitiveTypeName(config.output_type) + .c_str()); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.h index 9a630d013f7..b0791b01868 100755 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.h @@ -28,27 +28,36 @@ limitations under the License. namespace xla { namespace gpu { +struct CudnnBatchNormConfig { + Shape output_shape; + PrimitiveType output_type; + float epsilon; + int64 feature_index; +}; + +CudnnBatchNormConfig GetCudnnBatchNormConfig(const HloInstruction *instr, + float epsilon, + int64 feature_index); + Status RunCudnnBatchNormForwardInference( - const HloInstruction* batchnorm, se::DeviceMemoryBase operand, + const CudnnBatchNormConfig &config, se::DeviceMemoryBase operand, se::DeviceMemoryBase output, se::DeviceMemory scale, se::DeviceMemory offset, se::DeviceMemory mean, - se::DeviceMemory variance, float epsilon, int64 feature_index, - se::Stream* stream); + se::DeviceMemory variance, se::Stream *stream); Status RunCudnnBatchNormForwardTraining( - const HloInstruction* batchnorm, se::DeviceMemoryBase operand, + const CudnnBatchNormConfig &config, se::DeviceMemoryBase operand, se::DeviceMemoryBase output_data, se::DeviceMemory output_mean, se::DeviceMemory output_inv_stddev, se::DeviceMemory scale, - se::DeviceMemory offset, float epsilon, int64 feature_index, - se::Stream* stream); + se::DeviceMemory offset, se::Stream *stream); Status RunCudnnBatchNormBackward( - const HloInstruction* batchnorm, se::DeviceMemoryBase operand, + const CudnnBatchNormConfig &config, se::DeviceMemoryBase operand, se::DeviceMemoryBase output_grad_data, se::DeviceMemoryBase grad_output, se::DeviceMemory output_grad_scale, se::DeviceMemory output_grad_offset, se::DeviceMemory scale, se::DeviceMemory mean, se::DeviceMemory inv_stddev, - float epsilon, int64 feature_index, se::Stream* stream); + se::Stream *stream); } // namespace gpu } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_BATCHNORM_RUNNER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc index e91b2c4d0d2..dae490e0d18 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc +++ 
b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc @@ -31,90 +31,21 @@ namespace gpu { namespace dnn = se::dnn; -namespace { -void CheckInputOutputPrimitivetypeAreValid(const HloInstruction* hlo) { - // All input and output statistics variables must be F32. Also, the last - // operand for CudnnBatchNormForwardInference, CudnnBatchNormForwardTraining, - // and CudnnBatchNormBackward is the feature_index which must be S64. - // The allowed types for non-statistics variables are as follows: - // CudnnBatchNormForwardInference: - // operand[0]: {half, float} - // out[0]: {half, float} - // CudnnBatchNormForwardTraining: - // operand[0]: {half, float} - // out[0]: {half, float} - // CudnnBatchNormBackward: - // operand[0]: {half, float} - // operand[4]: {half, float} - // out[0]: {half, float} - // Note non-statistics inputs and outputs mentioned above should be of the - // same type. - - // Check Inputs. - int64 num_operands = hlo->operand_count(); - PrimitiveType operand_primitive_type = - hlo->operand(0)->shape().element_type(); - CHECK(operand_primitive_type == F16 || operand_primitive_type == F32) - << "Not yet implemented"; - - for (int i = 1; i < num_operands - 2; i++) { - if (hlo->custom_call_target() == kCudnnBatchNormBackwardCallTarget && - i == 4) { - // The first operand to batchnorm grad is the input and the 4th operand is - // the grad_output, both of which can be Eigen::half. - CHECK_EQ(hlo->operand(i)->shape().element_type(), operand_primitive_type) - << "Invalid datatype"; - continue; - } - CHECK_EQ(hlo->operand(i)->shape().element_type(), F32) - << "Not yet implemented"; - } - - // The last operand is the feature index which must be int64. - CHECK_EQ(hlo->operand(num_operands - 1)->shape().element_type(), S64) - << "Not yet implemented"; - - // Check Outputs. 
- if (hlo->shape().IsTuple()) { - CHECK_EQ(hlo->shape().tuple_shapes(0).element_type(), - operand_primitive_type) - << "Invalid datatype"; - - for (int j = 1; j < hlo->shape().tuple_shapes_size(); j++) { - CHECK_EQ(hlo->shape().tuple_shapes(j).element_type(), F32) - << "Not yet implemented"; - } - } else { - CHECK_EQ(hlo->shape().element_type(), operand_primitive_type) - << "Invalid datatype"; - } -} -} // namespace - CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk( - ThunkInfo thunk_info, const BufferAllocation::Slice& operand, + ThunkInfo thunk_info, CudnnBatchNormConfig&& config, + const BufferAllocation::Slice& operand, const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset, const BufferAllocation::Slice& mean, - const BufferAllocation::Slice& variance, float epsilon, int64 feature_index, + const BufferAllocation::Slice& variance, const BufferAllocation::Slice& output) : Thunk(Thunk::Kind::kCudnnBatchNormForwardInference, thunk_info), - hlo_instruction_(thunk_info.hlo_instruction), + config_(std::move(config)), operand_(operand), scale_(scale), offset_(offset), mean_(mean), variance_(variance), - epsilon_(epsilon), - feature_index_(feature_index), - output_(output) { - const auto* hlo = hlo_instruction_; - CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall); - CHECK_EQ(hlo->custom_call_target(), - kCudnnBatchNormForwardInferenceCallTarget); - CHECK( - LayoutUtil::LayoutsInShapesEqual(hlo->shape(), hlo->operand(0)->shape())); - CheckInputOutputPrimitivetypeAreValid(hlo); -} + output_(output) {} Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream( const ExecuteParams& params) { @@ -131,8 +62,7 @@ Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream( buffer_allocations.GetDeviceAddress(variance_)); auto& stream = *params.stream; TF_RETURN_IF_ERROR(RunCudnnBatchNormForwardInference( - hlo_instruction_, operand, output_base, scale, offset, mean, variance, - epsilon_, feature_index_, &stream)); + config_, operand, output_base, scale, offset, mean, variance, &stream)); if (!stream.ok()) { return InternalError("BatchNormalizationForward call failed."); @@ -141,32 +71,22 @@ Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream( } CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk( - ThunkInfo thunk_info, const BufferAllocation::Slice& operand, + ThunkInfo thunk_info, CudnnBatchNormConfig&& config, + const BufferAllocation::Slice& operand, const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset, - float epsilon, int64 feature_index, const BufferAllocation::Slice& output_data, const BufferAllocation::Slice& output_mean, const BufferAllocation::Slice& output_inv_stddev, const BufferAllocation::Slice& output_tuple) : Thunk(Thunk::Kind::kCudnnBatchNormForwardTraining, thunk_info), - hlo_instruction_(thunk_info.hlo_instruction), + config_(std::move(config)), operand_(operand), scale_(scale), offset_(offset), - epsilon_(epsilon), - feature_index_(feature_index), output_data_(output_data), output_mean_(output_mean), output_inv_stddev_(output_inv_stddev), - output_tuple_(output_tuple) { - const auto* hlo = hlo_instruction_; - CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall); - CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormForwardTrainingCallTarget); - CHECK_EQ(hlo->shape().tuple_shapes_size(), 3); - CHECK(LayoutUtil::LayoutsInShapesEqual(hlo->shape().tuple_shapes(0), - hlo->operand(0)->shape())); - CheckInputOutputPrimitivetypeAreValid(hlo); -} + output_tuple_(output_tuple) {} Status 
CudnnBatchNormForwardTrainingThunk::ExecuteOnStream( const ExecuteParams& params) { @@ -185,10 +105,10 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream( params.profiler->MakeScopedInstructionProfiler(profile_index()); auto& stream = *params.stream; TF_RETURN_IF_ERROR(RunCudnnBatchNormForwardTraining( - hlo_instruction_, operand, output_data, output_mean, output_inv_stddev, + config_, operand, output_data, output_mean, output_inv_stddev, se::DeviceMemory(buffer_allocations.GetDeviceAddress(scale_)), se::DeviceMemory(buffer_allocations.GetDeviceAddress(offset_)), - epsilon_, feature_index_, &stream)); + &stream)); // Write the output tuple. const int kNumOutputs = 3; @@ -207,37 +127,26 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream( } CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk( - ThunkInfo thunk_info, const BufferAllocation::Slice& operand, + ThunkInfo thunk_info, CudnnBatchNormConfig&& config, + const BufferAllocation::Slice& operand, const BufferAllocation::Slice& scale, const BufferAllocation::Slice& mean, const BufferAllocation::Slice& inv_stddev, - const BufferAllocation::Slice& grad_output, float epsilon, - int64 feature_index, const BufferAllocation::Slice& output_grad_data, + const BufferAllocation::Slice& grad_output, + const BufferAllocation::Slice& output_grad_data, const BufferAllocation::Slice& output_grad_scale, const BufferAllocation::Slice& output_grad_offset, const BufferAllocation::Slice& output_tuple) : Thunk(Thunk::Kind::kCudnnBatchNormBackward, thunk_info), - hlo_instruction_(thunk_info.hlo_instruction), + config_(std::move(config)), operand_(operand), scale_(scale), mean_(mean), inv_stddev_(inv_stddev), grad_output_(grad_output), - epsilon_(epsilon), - feature_index_(feature_index), output_grad_data_(output_grad_data), output_grad_scale_(output_grad_scale), output_grad_offset_(output_grad_offset), - output_tuple_(output_tuple) { - const auto* hlo = hlo_instruction_; - CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall); - CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormBackwardCallTarget); - CHECK_EQ(hlo->shape().tuple_shapes_size(), 3); - CHECK(LayoutUtil::LayoutsInShapesEqual(hlo->shape().tuple_shapes(0), - hlo->operand(0)->shape())); - CHECK(LayoutUtil::LayoutsInShapesEqual(hlo->shape().tuple_shapes(0), - hlo->operand(4)->shape())); - CheckInputOutputPrimitivetypeAreValid(hlo); -} + output_tuple_(output_tuple) {} Status CudnnBatchNormBackwardThunk::ExecuteOnStream( const ExecuteParams& params) { @@ -256,12 +165,12 @@ Status CudnnBatchNormBackwardThunk::ExecuteOnStream( params.profiler->MakeScopedInstructionProfiler(profile_index()); se::Stream* stream = params.stream; TF_RETURN_IF_ERROR(RunCudnnBatchNormBackward( - hlo_instruction_, operand, output_grad_data, grad_output, - output_grad_scale, output_grad_offset, + config_, operand, output_grad_data, grad_output, output_grad_scale, + output_grad_offset, se::DeviceMemory(buffer_allocations.GetDeviceAddress(scale_)), se::DeviceMemory(buffer_allocations.GetDeviceAddress(mean_)), se::DeviceMemory(buffer_allocations.GetDeviceAddress(inv_stddev_)), - epsilon_, feature_index_, stream)); + stream)); // Write the output tuple. 
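Editor's note on the batchnorm runner hunks above: a CudnnBatchNormConfig (output shape, element type, epsilon, feature index) is derived once, and the runner switches on config.output_type to pick the F16 or F32 implementation. The sketch below mirrors that dispatch with a stand-in element-type enum and templated impls; the 16-bit case is instantiated with uint16_t purely as a size stand-in for Eigen::half, and an error string stands in for xla::Status.

// Stand-in for dispatching a batchnorm-style runner on a stored element type
// instead of re-inspecting the instruction's shape at execution time.
#include <cstdint>
#include <iostream>
#include <string>

enum class PrimitiveType { F16, F32 };  // stand-in for xla::PrimitiveType

struct CudnnBatchNormConfig {
  PrimitiveType output_type;
  float epsilon;
  int64_t feature_index;
};

template <typename ElemType>
void RunForwardInferenceImpl(const CudnnBatchNormConfig& config) {
  std::cout << "running with " << sizeof(ElemType) * 8 << "-bit elements, eps="
            << config.epsilon << ", feature_index=" << config.feature_index
            << "\n";
}

// Returns an error message instead of a Status to stay self-contained.
std::string RunForwardInference(const CudnnBatchNormConfig& config) {
  switch (config.output_type) {
    case PrimitiveType::F16:
      RunForwardInferenceImpl<uint16_t>(config);  // half is 16 bits wide
      return "";
    case PrimitiveType::F32:
      RunForwardInferenceImpl<float>(config);
      return "";
  }
  return "unimplemented primitive type";
}

int main() {
  CudnnBatchNormConfig config{PrimitiveType::F32, 1e-3f, /*feature_index=*/3};
  std::string error = RunForwardInference(config);
  return error.empty() ? 0 : 1;
}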
const int kNumOutputs = 3; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h index bb46017b8fb..d45e284ea2c 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -47,12 +48,12 @@ namespace gpu { class CudnnBatchNormForwardInferenceThunk : public Thunk { public: CudnnBatchNormForwardInferenceThunk(ThunkInfo thunk_info, + CudnnBatchNormConfig&& config, const BufferAllocation::Slice& operand, const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset, const BufferAllocation::Slice& mean, const BufferAllocation::Slice& variance, - float epsilon, int64 feature_index, const BufferAllocation::Slice& output); CudnnBatchNormForwardInferenceThunk( @@ -63,23 +64,22 @@ class CudnnBatchNormForwardInferenceThunk : public Thunk { Status ExecuteOnStream(const ExecuteParams& params) override; private: - const HloInstruction* hlo_instruction_; + CudnnBatchNormConfig config_; BufferAllocation::Slice operand_; BufferAllocation::Slice scale_; BufferAllocation::Slice offset_; BufferAllocation::Slice mean_; BufferAllocation::Slice variance_; - float epsilon_; - int64 feature_index_; BufferAllocation::Slice output_; }; class CudnnBatchNormForwardTrainingThunk : public Thunk { public: CudnnBatchNormForwardTrainingThunk( - ThunkInfo thunk_info, const BufferAllocation::Slice& operand, + ThunkInfo thunk_info, CudnnBatchNormConfig&& config, + const BufferAllocation::Slice& operand, const BufferAllocation::Slice& scale, - const BufferAllocation::Slice& offset, float epsilon, int64 feature_index, + const BufferAllocation::Slice& offset, const BufferAllocation::Slice& output_data, const BufferAllocation::Slice& output_mean, const BufferAllocation::Slice& output_inv_stddev, @@ -93,12 +93,10 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk { Status ExecuteOnStream(const ExecuteParams& params) override; private: - const HloInstruction* hlo_instruction_; + CudnnBatchNormConfig config_; BufferAllocation::Slice operand_; BufferAllocation::Slice scale_; BufferAllocation::Slice offset_; - float epsilon_; - int64 feature_index_; BufferAllocation::Slice output_data_; BufferAllocation::Slice output_mean_; BufferAllocation::Slice output_inv_stddev_; @@ -108,12 +106,12 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk { class CudnnBatchNormBackwardThunk : public Thunk { public: CudnnBatchNormBackwardThunk(ThunkInfo thunk_info, + CudnnBatchNormConfig&& config, const BufferAllocation::Slice& operand, const BufferAllocation::Slice& scale, const BufferAllocation::Slice& mean, const BufferAllocation::Slice& inv_stddev, const BufferAllocation::Slice& grad_output, - float epsilon, int64 feature_index, const BufferAllocation::Slice& output_grad_data, const BufferAllocation::Slice& output_grad_scale, const BufferAllocation::Slice& output_grad_offset, @@ -126,14 +124,12 @@ class CudnnBatchNormBackwardThunk : public Thunk { Status ExecuteOnStream(const ExecuteParams& params) override; private: - const 
HloInstruction* hlo_instruction_; + const CudnnBatchNormConfig config_; BufferAllocation::Slice operand_; BufferAllocation::Slice scale_; BufferAllocation::Slice mean_; BufferAllocation::Slice inv_stddev_; BufferAllocation::Slice grad_output_; - float epsilon_; - int64 feature_index_; BufferAllocation::Slice output_grad_data_; BufferAllocation::Slice output_grad_scale_; BufferAllocation::Slice output_grad_offset_; diff --git a/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc index dae15659402..c9b2318af79 100644 --- a/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc @@ -24,29 +24,12 @@ namespace gpu { CustomCallThunk::CustomCallThunk( ThunkInfo thunk_info, void* call_target, std::vector> operand_slices, - ShapeTree result_slices, std::string opaque) + ShapeTree result_slices, const std::string& opaque) : Thunk(Thunk::kCustomCall, thunk_info), - hlo_instruction_(thunk_info.hlo_instruction), call_target_(call_target), operand_slices_(std::move(operand_slices)), result_slices_(std::move(result_slices)), - opaque_(std::move(opaque)) { - const HloInstruction* instr = hlo_instruction_; - CHECK_EQ(instr->operand_count(), operand_slices_.size()); - for (int64 i = 0; i < instr->operand_count(); ++i) { - const auto& s1 = operand_slices_[i].shape(); - const auto& s2 = instr->operand(i)->shape(); - CHECK(ShapeUtil::Equal(s1, s2)) << absl::StreamFormat( - "Shape mismatch between instr->operand(%d) and " - "operand_slices[%d].shape(): %s vs %s", - i, i, s1.ToString(), s2.ToString()); - } - CHECK(ShapeUtil::Equal(instr->shape(), result_slices.shape())) - << absl::StreamFormat( - "Shape mismatch between instr->shape() and result_slices.shape(): " - "%s vs %s.", - instr->shape().ToString(), result_slices.shape().ToString()); -} + opaque_(opaque) {} // For each leaf in a preorder traversal of `slices`, appends its device address // to `buffers`. diff --git a/tensorflow/compiler/xla/service/gpu/custom_call_thunk.h b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.h index 31c03f5252f..f36eaa9cef2 100644 --- a/tensorflow/compiler/xla/service/gpu/custom_call_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.h @@ -41,16 +41,16 @@ class CustomCallThunk : public Thunk { CustomCallThunk( ThunkInfo thunk_info, void* call_target, std::vector> operand_slices, - ShapeTree result_slices, std::string opaque); + ShapeTree result_slices, + const std::string& opaque); Status ExecuteOnStream(const ExecuteParams& params) override; private: - const HloInstruction* hlo_instruction_; void* call_target_; std::vector> operand_slices_; ShapeTree result_slices_; - std::string opaque_; + const std::string opaque_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc index 318b8aff176..4cc19a23201 100644 --- a/tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc @@ -14,10 +14,22 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" namespace xla { namespace gpu { +struct NcclAllReduceConfig::AuxData {}; + +NcclAllReduceConfig::NcclAllReduceConfig(NcclAllReduceConfig &&) = default; +NcclAllReduceConfig::~NcclAllReduceConfig() = default; + +NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction *instr, + int64 replica_count) { + NcclAllReduceConfig config = {}; + return config; +} + /* static */ bool NcclAllReduceThunk::NcclIsEnabled() { return false; // Skylark selects this source file if NCCL is disabled. } @@ -32,20 +44,16 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) { "compiler, which is necessary to build the NCCL source library."); } -NcclAllReduceThunk::~NcclAllReduceThunk() = default; - /*static*/ absl::flat_hash_set NcclAllReduceThunk::DevicesWithOpenNcclChannels() { return {}; } -struct NcclAllReduceThunk::AuxData {}; - NcclAllReduceThunk::NcclAllReduceThunk( - ThunkInfo thunk_info, int64 replica_count, + ThunkInfo thunk_info, NcclAllReduceConfig &&config, std::vector buffers) : Thunk(Thunk::kNcclAllReduce, thunk_info), - replica_count_(replica_count), + config_(std::move(config)), buffers_(std::move(buffers)) {} } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc index ccd661d8ade..a9e6cd05c31 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc @@ -24,15 +24,16 @@ namespace xla { namespace gpu { ForThunk::ForThunk(ThunkInfo thunk_info, const int64 loop_limit, - std::unique_ptr body_thunk_sequence) + std::unique_ptr body_thunk_sequence, + absl::optional body_profile_index) : Thunk(Kind::kWhile, thunk_info), - hlo_instruction_(thunk_info.hlo_instruction), loop_limit_(loop_limit), body_thunk_sequence_(absl::make_unique( // Pass nullptr as the HloInstruction* to the body_thunk_sequence_ // constructor because this SequentialThunk is logically "part of" // this ForThunk, and shouldn't be profiled separately from it. - ThunkInfo(), std::move(*body_thunk_sequence))) {} + ThunkInfo(), std::move(*body_thunk_sequence))), + body_profile_index_(body_profile_index) {} Status ForThunk::Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) { @@ -41,15 +42,14 @@ Status ForThunk::Initialize(const GpuExecutable& executable, } Status ForThunk::ExecuteOnStream(const ExecuteParams& params) { - VLOG(2) << "Executing ForThunk with " << loop_limit_ << " iters for " - << (hlo_instruction_ ? hlo_instruction_->ToString() : ""); + VLOG(2) << "Executing ForThunk with " << loop_limit_ << " iters"; auto op_profiler = params.profiler->MakeScopedInstructionProfiler(profile_index()); for (int64 i = 0; i < loop_limit_; ++i) { params.profiler->StartHloComputation(); // Invoke loop body thunk sequence. 
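In the ForThunk change here, the profile index of the loop body is resolved when the thunk is built and passed in as an optional index, so the loop can finish a body profile without consulting hlo_instruction_->while_body(). A reduced, hypothetical illustration of that flow (std::optional stands in for absl::optional to keep the sketch self-contained):

#include <cstdint>
#include <optional>

// Hypothetical profiler stand-in: a disengaged index means "skip profiling".
struct LoopProfilerSketch {
  void StartHloComputation() {}
  void FinishHloComputation(std::optional<int64_t> profile_index) {
    (void)profile_index;  // a real profiler would record against this index
  }
};

// The loop body is profiled by index; no HloComputation pointer is needed.
void RunCountedLoopSketch(int64_t loop_limit,
                          std::optional<int64_t> body_profile_index,
                          LoopProfilerSketch& profiler) {
  for (int64_t i = 0; i < loop_limit; ++i) {
    profiler.StartHloComputation();
    // ... the body thunk sequence would execute here ...
    profiler.FinishHloComputation(body_profile_index);
  }
}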
TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(params)); - params.profiler->FinishHloComputation(hlo_instruction_->while_body()); + params.profiler->FinishHloComputation(body_profile_index_); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.h b/tensorflow/compiler/xla/service/gpu/for_thunk.h index b6ee950737e..9a8bd069290 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.h @@ -32,7 +32,8 @@ namespace gpu { class ForThunk : public Thunk { public: ForThunk(ThunkInfo thunk_info, const int64 loop_limit, - std::unique_ptr body_thunk_sequence); + std::unique_ptr body_thunk_sequence, + absl::optional body_profile_index_); ForThunk(const ForThunk&) = delete; ForThunk& operator=(const ForThunk&) = delete; @@ -41,9 +42,9 @@ class ForThunk : public Thunk { Status ExecuteOnStream(const ExecuteParams& params) override; private: - const HloInstruction* hlo_instruction_; const int64 loop_limit_; std::unique_ptr body_thunk_sequence_; + const absl::optional body_profile_index_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc index a499dc70e23..23706cb9973 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc @@ -201,8 +201,10 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { // Merging into all users enables the removal of 'fusion' from the // computation. if (!absl::c_all_of(fusion->users(), [&](const HloInstruction* user) { - return user->opcode() == HloOpcode::kFusion && - IsProducerConsumerFusible(*fusion, *user); + return IsProducerConsumerFusible(*fusion, *user) && + // Do not fuse into bitcast ops, which are no-ops and do not + // generate any GPU code. + user->opcode() != HloOpcode::kBitcast; })) { VLOG(3) << "Not merging " << fusion->name() << ": Some of its users are not loop/input fusion kernels."; @@ -283,7 +285,15 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { // Merge fused instructions from 'fusion' into each user. std::vector users = fusion->users(); for (HloInstruction* user : users) { - user->MergeFusionInstruction(fusion); + if (user->opcode() == HloOpcode::kFusion) { + user->MergeFusionInstruction(fusion); + } else { + HloInstruction* fused_user = + computation_->AddInstruction(HloInstruction::CreateFusion( + user->shape(), ChooseFusionKind(*fusion, *user), user)); + TF_CHECK_OK(computation_->ReplaceInstruction(user, fused_user)); + fused_user->MergeFusionInstruction(fusion); + } changed_ = true; } ++total_merged_; @@ -296,7 +306,7 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { }) << " }"; // Remove 'fusion' instruction. 
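The FusionMerger hunk above relaxes HandleFusion in two ways: a producer fusion may now be merged into consumers that are not themselves fusion instructions (such a consumer is first wrapped in a freshly created fusion, then merged into), and bitcast consumers are excluded because they are no-ops that emit no GPU code. A hypothetical reduction of the new gating rule:

#include <vector>

// Stand-in for the per-consumer facts the merge decision looks at.
struct ConsumerSketch {
  bool fusible_with_producer = true;
  bool is_bitcast = false;
  bool is_fusion = false;  // non-fusion consumers get wrapped before merging
};

// Merge only when every consumer can fuse with the producer and none is a
// bitcast; after the merge loop the producer must have no remaining users,
// which the patch checks with the instruction dump attached for debugging.
bool ShouldMergeProducerSketch(const std::vector<ConsumerSketch>& consumers) {
  for (const ConsumerSketch& c : consumers) {
    if (!c.fusible_with_producer || c.is_bitcast) return false;
  }
  return true;
}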
- CHECK_EQ(0, fusion->user_count()); + CHECK_EQ(0, fusion->user_count()) << fusion->ToString(); return computation_->RemoveInstruction(fusion); } diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc index cc4894f4c00..7468114516d 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc @@ -234,6 +234,54 @@ TEST_F(FusionMergerTest, WillMergeIntoInputFusion) { op::Fusion(op::Parameter())); } +TEST_F(FusionMergerTest, WillMergeIntoUnfusedConsumer) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule jit_matmul.36 + + max (parameter.13: f32[], parameter.14: f32[]) -> f32[] { + parameter.13 = f32[] parameter(0) + parameter.14 = f32[] parameter(1) + ROOT maximum.15 = f32[] maximum(f32[] parameter.13, f32[] parameter.14) + } + + add (parameter.29: f32[], parameter.30: f32[]) -> f32[] { + parameter.29 = f32[] parameter(0) + parameter.30 = f32[] parameter(1) + ROOT add.31 = f32[] add(f32[] parameter.29, f32[] parameter.30) + } + + fused_computation.1 (param_1.4: f32[200,200,200], param_2.1: f32[200,200]) -> f32[200,200] { + param_1.4 = f32[200,200,200]{2,1,0} parameter(0) + param_2.1 = f32[200,200]{1,0} parameter(1) + broadcast.3 = f32[200,200,200]{2,1,0} broadcast(f32[200,200]{1,0} param_2.1), dimensions={0,2} + subtract.0 = f32[200,200,200]{2,1,0} subtract(f32[200,200,200]{2,1,0} param_1.4, f32[200,200,200]{2,1,0} broadcast.3) + exponential.0 = f32[200,200,200]{2,1,0} exponential(f32[200,200,200]{2,1,0} subtract.0) + constant.27 = f32[] constant(0) + ROOT reduce.0 = f32[200,200]{1,0} reduce(f32[200,200,200]{2,1,0} exponential.0, f32[] constant.27), dimensions={1}, to_apply=add + } + + fused_computation.3 (param_0.7: f32[200,200], param_1.9: f32[200,200]) -> f32[200,200,200] { + param_1.9 = f32[200,200]{1,0} parameter(1) + broadcast.10 = f32[200,200,200]{2,1,0} broadcast(f32[200,200]{1,0} param_1.9), dimensions={0,1} + param_0.7 = f32[200,200]{1,0} parameter(0) + broadcast.8 = f32[200,200,200]{2,1,0} broadcast(f32[200,200]{1,0} param_0.7), dimensions={1,2} + ROOT add.1 = f32[200,200,200]{2,1,0} add(f32[200,200,200]{2,1,0} broadcast.10, f32[200,200,200]{2,1,0} broadcast.8) + } + + ENTRY entry (parameter.1: f32[200,200], parameter.2: f32[200,200]) -> f32[200,200] { + parameter.2 = f32[200,200]{1,0} parameter(1) + parameter.1 = f32[200,200]{1,0} parameter(0) + fusion.3 = f32[200,200,200]{2,1,0} fusion(f32[200,200]{1,0} parameter.2, f32[200,200]{1,0} parameter.1), kind=kLoop, calls=fused_computation.3 + constant.11 = f32[] constant(-inf) + reduce.16 = f32[200,200]{1,0} reduce(f32[200,200,200]{2,1,0} fusion.3, f32[] constant.11), dimensions={1}, to_apply=max + ROOT fusion.1 = f32[200,200]{1,0} fusion(f32[200,200,200]{2,1,0} fusion.3, f32[200,200]{1,0} reduce.16), kind=kInput, calls=fused_computation.1 + })") + .ValueOrDie(); + EXPECT_TRUE(FusionMerger().Run(module.get()).ValueOrDie()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Fusion(op::Fusion(), op::Parameter(), op::Parameter())); +} + TEST_F(FusionMergerTest, WillNotMergeReduceUnfriendlyLayouts) { auto module = ParseAndReturnVerifiedModule(R"( HloModule m @@ -421,6 +469,165 @@ TEST_F(FusionMergerTest, WillMergeExpensiveFusionsWithSingleConsumer) { EXPECT_TRUE(FusionMerger().Run(module.get()).ValueOrDie()); } +TEST_F(FusionMergerTest, NoMergeBecauseCodeDuplication) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule module + 
+and.reduce_sub_computation { + x = pred[] parameter(0) + y = pred[] parameter(1) + ROOT and = pred[] and(x, y) +} + +fused_computation.1 { + param_4.658 = f32[2,20,256]{2,0,1} parameter(4) + slice.1385 = f32[2,1,256]{2,0,1} slice(param_4.658), slice={[0:2], [11:12], [0:256]} + constant.6847 = s32[] constant(0) + broadcast.4823 = s32[3]{0} broadcast(constant.6847), dimensions={} + param_9.415 = s32[3]{0} parameter(9) + compare.700 = pred[3]{0} compare(broadcast.4823, param_9.415), direction=LE + constant.6846 = pred[] constant(true) + reduce.221 = pred[] reduce(compare.700, constant.6846), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2933 = pred[2,1,256]{2,0,1} broadcast(reduce.221), dimensions={} + param_5.528 = f32[2,512]{1,0} parameter(5) + slice.1384 = f32[2,256]{1,0} slice(param_5.528), slice={[0:2], [0:256]} + bitcast.341 = f32[2,1,256]{2,0,1} bitcast(slice.1384) + constant.5418 = f32[] constant(0) + broadcast.3227 = f32[2,1,256]{2,0,1} broadcast(constant.5418), dimensions={} + select.173 = f32[2,1,256]{2,0,1} select(broadcast.2933, bitcast.341, broadcast.3227) + add.573 = f32[2,1,256]{2,0,1} add(slice.1385, select.173) + param_0.299 = s32[] parameter(0) + constant.5157 = s32[] constant(11) + dynamic-update-slice.189 = f32[2,20,256]{2,0,1} dynamic-update-slice(param_4.658, add.573, param_0.299, constant.5157, param_0.299) + slice.1383 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.189), slice={[0:2], [10:11], [0:256]} + constant.6800 = s32[] constant(0) + broadcast.4803 = s32[3]{0} broadcast(constant.6800), dimensions={} + param_8.484 = s32[3]{0} parameter(8) + compare.681 = pred[3]{0} compare(broadcast.4803, param_8.484), direction=LE + constant.6798 = pred[] constant(true) + reduce.203 = pred[] reduce(compare.681, constant.6798), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2932 = pred[2,1,256]{2,0,1} broadcast(reduce.203), dimensions={} + param_3.1169 = f32[2,512]{1,0} parameter(3) + slice.1382 = f32[2,256]{1,0} slice(param_3.1169), slice={[0:2], [0:256]} + bitcast.340 = f32[2,1,256]{2,0,1} bitcast(slice.1382) + select.172 = f32[2,1,256]{2,0,1} select(broadcast.2932, bitcast.340, broadcast.3227) + add.572 = f32[2,1,256]{2,0,1} add(slice.1383, select.172) + constant.5154 = s32[] constant(10) + dynamic-update-slice.188 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.189, add.572, param_0.299, constant.5154, param_0.299) + slice.1381 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.188), slice={[0:2], [9:10], [0:256]} + constant.6794 = s32[] constant(0) + broadcast.4801 = s32[3]{0} broadcast(constant.6794), dimensions={} + param_7.478 = s32[3]{0} parameter(7) + compare.679 = pred[3]{0} compare(broadcast.4801, param_7.478), direction=LE + constant.6793 = pred[] constant(true) + reduce.201 = pred[] reduce(compare.679, constant.6793), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2930 = pred[2,1,256]{2,0,1} broadcast(reduce.201), dimensions={} + param_2.1685 = f32[2,512]{1,0} parameter(2) + slice.1380 = f32[2,256]{1,0} slice(param_2.1685), slice={[0:2], [0:256]} + bitcast.339 = f32[2,1,256]{2,0,1} bitcast(slice.1380) + select.171 = f32[2,1,256]{2,0,1} select(broadcast.2930, bitcast.339, broadcast.3227) + add.571 = f32[2,1,256]{2,0,1} add(slice.1381, select.171) + constant.5153 = s32[] constant(9) + dynamic-update-slice.187 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.188, add.571, param_0.299, constant.5153, param_0.299) + slice.1379 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.187), 
slice={[0:2], [8:9], [0:256]} + constant.6788 = s32[] constant(0) + broadcast.4799 = s32[3]{0} broadcast(constant.6788), dimensions={} + param_6.495 = s32[3]{0} parameter(6) + compare.677 = pred[3]{0} compare(broadcast.4799, param_6.495), direction=LE + constant.6786 = pred[] constant(true) + reduce.199 = pred[] reduce(compare.677, constant.6786), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2929 = pred[2,1,256]{2,0,1} broadcast(reduce.199), dimensions={} + param_1.1408 = f32[2,512]{1,0} parameter(1) + slice.1378 = f32[2,256]{1,0} slice(param_1.1408), slice={[0:2], [0:256]} + bitcast.338 = f32[2,1,256]{2,0,1} bitcast(slice.1378) + select.170 = f32[2,1,256]{2,0,1} select(broadcast.2929, bitcast.338, broadcast.3227) + add.570 = f32[2,1,256]{2,0,1} add(slice.1379, select.170) + constant.5152 = s32[] constant(8) + ROOT dynamic-update-slice.186 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.187, add.570, param_0.299, constant.5152, param_0.299) +} + +fused_computation.2 { + param_4.655 = f32[2,20,256]{2,0,1} parameter(4) + slice.1369 = f32[2,1,256]{2,0,1} slice(param_4.655), slice={[0:2], [7:8], [0:256]} + param_6.483 = pred[] parameter(6) + broadcast.2927 = pred[2,1,256]{2,0,1} broadcast(param_6.483), dimensions={} + param_5.525 = f32[2,512]{1,0} parameter(5) + slice.1368 = f32[2,256]{1,0} slice(param_5.525), slice={[0:2], [0:256]} + bitcast.333 = f32[2,1,256]{2,0,1} bitcast(slice.1368) + constant.5415 = f32[] constant(0) + broadcast.3225 = f32[2,1,256]{2,0,1} broadcast(constant.5415), dimensions={} + select.161 = f32[2,1,256]{2,0,1} select(broadcast.2927, bitcast.333, broadcast.3225) + add.549 = f32[2,1,256]{2,0,1} add(slice.1369, select.161) + param_0.265 = s32[] parameter(0) + constant.5151 = s32[] constant(7) + dynamic-update-slice.185 = f32[2,20,256]{2,0,1} dynamic-update-slice(param_4.655, add.549, param_0.265, constant.5151, param_0.265) + slice.1367 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.185), slice={[0:2], [6:7], [0:256]} + constant.6782 = s32[] constant(0) + broadcast.4797 = s32[3]{0} broadcast(constant.6782), dimensions={} + param_9.391 = s32[3]{0} parameter(9) + compare.675 = pred[3]{0} compare(broadcast.4797, param_9.391), direction=LE + constant.6781 = pred[] constant(true) + reduce.197 = pred[] reduce(compare.675, constant.6781), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2926 = pred[2,1,256]{2,0,1} broadcast(reduce.197), dimensions={} + param_3.1167 = f32[2,512]{1,0} parameter(3) + slice.1366 = f32[2,256]{1,0} slice(param_3.1167), slice={[0:2], [0:256]} + bitcast.332 = f32[2,1,256]{2,0,1} bitcast(slice.1366) + select.160 = f32[2,1,256]{2,0,1} select(broadcast.2926, bitcast.332, broadcast.3225) + add.548 = f32[2,1,256]{2,0,1} add(slice.1367, select.160) + constant.5150 = s32[] constant(6) + dynamic-update-slice.184 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.185, add.548, param_0.265, constant.5150, param_0.265) + slice.1365 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.184), slice={[0:2], [5:6], [0:256]} + constant.6776 = s32[] constant(0) + broadcast.4794 = s32[3]{0} broadcast(constant.6776), dimensions={} + param_8.464 = s32[3]{0} parameter(8) + compare.673 = pred[3]{0} compare(broadcast.4794, param_8.464), direction=LE + constant.6775 = pred[] constant(true) + reduce.195 = pred[] reduce(compare.673, constant.6775), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2925 = pred[2,1,256]{2,0,1} broadcast(reduce.195), dimensions={} + param_2.1684 = f32[2,512]{1,0} 
parameter(2) + slice.1364 = f32[2,256]{1,0} slice(param_2.1684), slice={[0:2], [0:256]} + bitcast.331 = f32[2,1,256]{2,0,1} bitcast(slice.1364) + select.159 = f32[2,1,256]{2,0,1} select(broadcast.2925, bitcast.331, broadcast.3225) + add.547 = f32[2,1,256]{2,0,1} add(slice.1365, select.159) + constant.5149 = s32[] constant(5) + dynamic-update-slice.183 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.184, add.547, param_0.265, constant.5149, param_0.265) + slice.1363 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.183), slice={[0:2], [4:5], [0:256]} + constant.6770 = s32[] constant(0) + broadcast.4792 = s32[3]{0} broadcast(constant.6770), dimensions={} + param_7.458 = s32[3]{0} parameter(7) + compare.671 = pred[3]{0} compare(broadcast.4792, param_7.458), direction=LE + constant.6769 = pred[] constant(true) + reduce.193 = pred[] reduce(compare.671, constant.6769), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2924 = pred[2,1,256]{2,0,1} broadcast(reduce.193), dimensions={} + param_1.1405 = f32[2,512]{1,0} parameter(1) + slice.1362 = f32[2,256]{1,0} slice(param_1.1405), slice={[0:2], [0:256]} + bitcast.330 = f32[2,1,256]{2,0,1} bitcast(slice.1362) + select.158 = f32[2,1,256]{2,0,1} select(broadcast.2924, bitcast.330, broadcast.3225) + add.546 = f32[2,1,256]{2,0,1} add(slice.1363, select.158) + constant.5148 = s32[] constant(4) + ROOT dynamic-update-slice.182 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.183, add.546, param_0.265, constant.5148, param_0.265) +} + +ENTRY main { + param_0.0 = s32[] parameter(0) + param_1.0 = f32[2,512]{1,0} parameter(1) + param_2.0 = f32[2,512]{1,0} parameter(2) + param_3.0 = f32[2,512]{1,0} parameter(3) + param_4.0 = f32[2,20,256]{2,1,0} parameter(4) + param_5.0 = f32[2,512]{1,0} parameter(5) + param_6.0 = s32[3]{0} parameter(6) + param_7.0 = s32[3]{0} parameter(7) + param_8.0 = s32[3]{0} parameter(8) + param_9.0 = s32[3]{0} parameter(9) + fusion.1 = f32[2,20,256]{2,0,1} fusion(param_0.0, param_1.0, param_2.0, param_3.0, param_4.0, param_5.0, param_6.0, param_7.0, param_8.0, param_9.0), kind=kLoop, calls=fused_computation.1 + param_10 = pred[] parameter(10) + ROOT fusion.2 = f32[2,20,256]{2,0,1} fusion(param_0.0, param_1.0, param_2.0, param_3.0, fusion.1, param_5.0, param_10, param_7.0, param_8.0, param_9.0), kind=kLoop, calls=fused_computation.2 +} + )") + .ValueOrDie(); + EXPECT_FALSE(FusionMerger().Run(module.get()).ValueOrDie()); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc index 0320496ea98..5a8265a53a6 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc @@ -115,6 +115,8 @@ static StatusOr> DoUncachedGemmAutotune( absl::optional first_algorithm; std::vector profile_results; + GpuGemmConfig config = GetGpuGemmConfig(gemm); + for (se::blas::AlgorithmType algorithm : algorithms) { // Make sure the output buffer always has the same value if we use // the bias parameter. @@ -129,8 +131,7 @@ static StatusOr> DoUncachedGemmAutotune( // for all algorithms if we're targeting < sm_50. But because we pass a // non-null ProfileResult, DoGemmWithAlgorithm should always return true, // and the actual success-ness is returned in ProfileResult::is_valid. 
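In the autotuning hunk here, the GEMM configuration is extracted from the instruction once, before the algorithm loop, and each per-algorithm run receives that config instead of the HloInstruction plus backend config. A hypothetical outline of the hoisting (names are stand-ins; the real GetGpuGemmConfig and RunGemm live in gemm_thunk.cc/h):

#include <cstdint>
#include <vector>

// Stand-in for the config built once from the instruction: shapes plus the
// backend config, all immutable across autotuning candidates.
struct GemmConfigSketch {};

bool RunGemmSketch(const GemmConfigSketch& config, int64_t algorithm) {
  (void)config;
  (void)algorithm;
  return true;  // real code launches the GEMM and fills a ProfileResult
}

void AutotuneSketch(const GemmConfigSketch& config,
                    const std::vector<int64_t>& algorithms) {
  // The config is hoisted out of the loop; only the algorithm override
  // changes per iteration.
  for (int64_t algorithm : algorithms) {
    RunGemmSketch(config, algorithm);
  }
}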
- CHECK(RunGemm(gemm, backend_config, lhs_buffer, rhs_buffer, output_buffer, - stream, + CHECK(RunGemm(config, lhs_buffer, rhs_buffer, output_buffer, stream, /*implements_whole_instruction=*/true, /*profile_index=*/-1, /*profiler=*/nullptr, diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index e55df0bb230..ea4f3951a3d 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -33,32 +33,40 @@ limitations under the License. namespace xla { namespace gpu { -GemmThunk::GemmThunk(ThunkInfo thunk_info, +GpuGemmConfig GetGpuGemmConfig(const HloInstruction *gemm) { + GpuGemmConfig config; + config.output_shape = gemm->shape(); + config.lhs_shape = gemm->operand(0)->shape(); + config.rhs_shape = gemm->operand(1)->shape(); + auto backend_config_or = gemm->backend_config(); + config.backend_config = std::move(backend_config_or.ValueOrDie()); + return config; +} + +GemmThunk::GemmThunk(ThunkInfo thunk_info, GpuGemmConfig &&config, const BufferAllocation::Slice &lhs_buffer, const BufferAllocation::Slice &rhs_buffer, const BufferAllocation::Slice &output_buffer, - bool implements_whole_instruction, - const GemmBackendConfig &backend_config) + bool implements_whole_instruction) : Thunk(Kind::kGemm, thunk_info), - hlo_instruction_(thunk_info.hlo_instruction), + config_(std::move(config)), lhs_buffer_(lhs_buffer), rhs_buffer_(rhs_buffer), output_buffer_(output_buffer), - implements_whole_instruction_(implements_whole_instruction), - backend_config_(backend_config) {} + implements_whole_instruction_(implements_whole_instruction) {} Status GemmThunk::ExecuteOnStream(const ExecuteParams ¶ms) { auto get_device_address = [&](const BufferAllocation::Slice &slice) { return params.buffer_allocations->GetDeviceAddress(slice); }; - VLOG(3) << "Running GEMM thunk on instruction: " << hlo_instruction_; + VLOG(3) << "Running GEMM thunk"; se::DeviceMemoryBase lhs_data = get_device_address(lhs_buffer_); se::DeviceMemoryBase rhs_data = get_device_address(rhs_buffer_); se::DeviceMemoryBase output_data = get_device_address(output_buffer_); - return RunGemm(hlo_instruction_, backend_config_, lhs_data, rhs_data, - output_data, params.stream, implements_whole_instruction_, - profile_index(), params.profiler); + return RunGemm(config_, lhs_data, rhs_data, output_data, params.stream, + implements_whole_instruction_, profile_index(), + params.profiler); } // This struct contains the metadata of a matrix, e.g., its base address and @@ -160,8 +168,7 @@ static bool DoGemmWithAlgorithm( .ok(); } -Status RunGemm(const HloInstruction *gemm, - const GemmBackendConfig &backend_config, +Status RunGemm(const GpuGemmConfig &gemm_config, se::DeviceMemoryBase lhs_buffer, se::DeviceMemoryBase rhs_buffer, se::DeviceMemoryBase output_buffer, se::Stream *stream, bool implements_whole_instruction, @@ -170,14 +177,11 @@ Status RunGemm(const HloInstruction *gemm, se::blas::ProfileResult *profile_result, absl::optional algorithm) { VLOG(2) << "Executing a GemmThunk"; - CHECK(IsCublasGemm(*gemm)); - const Shape &output_shape = gemm->shape(); - const HloInstruction *lhs = gemm->operand(0); - const HloInstruction *rhs = gemm->operand(1); - - const Shape &lhs_shape = lhs->shape(); - const Shape &rhs_shape = rhs->shape(); + const Shape &output_shape = gemm_config.output_shape; + const Shape &lhs_shape = gemm_config.lhs_shape; + const Shape &rhs_shape = gemm_config.rhs_shape; + const GemmBackendConfig &backend_config = 
gemm_config.backend_config; const DotDimensionNumbers &dim_nums = backend_config.dot_dimension_numbers(); CHECK_EQ(dim_nums.lhs_batch_dimensions_size(), diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h index 1a51a7d4e0c..9d6613dbe77 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h @@ -33,17 +33,26 @@ namespace gpu { // This class stores everything that StreamExecutor needs to launch a BLAS gemm. // It is generated by IrEmitter. -// + +struct GpuGemmConfig { + Shape lhs_shape; + Shape rhs_shape; + Shape output_shape; + GemmBackendConfig backend_config; +}; + +GpuGemmConfig GetGpuGemmConfig(const HloInstruction* gemm); + // This is thread-compatible. class GemmThunk : public Thunk { public: // Constructs a thunk that computes "output = (lhs rhs) * alpha" using // BLAS gemm (alpha is stored in the instruction GemmBackendConfig). - GemmThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& lhs_buffer, + GemmThunk(ThunkInfo thunk_info, GpuGemmConfig&& config, + const BufferAllocation::Slice& lhs_buffer, const BufferAllocation::Slice& rhs_buffer, const BufferAllocation::Slice& output_buffer, - bool implements_whole_instruction, - const GemmBackendConfig& backend_config); + bool implements_whole_instruction); GemmThunk(const GemmThunk&) = delete; GemmThunk& operator=(const GemmThunk&) = delete; @@ -51,28 +60,27 @@ class GemmThunk : public Thunk { Status ExecuteOnStream(const ExecuteParams& params) override; private: - const HloInstruction* hlo_instruction_; + const GpuGemmConfig config_; const BufferAllocation::Slice lhs_buffer_; const BufferAllocation::Slice rhs_buffer_; const BufferAllocation::Slice output_buffer_; - bool implements_whole_instruction_; - GemmBackendConfig backend_config_; + const bool implements_whole_instruction_; }; // Run the given GEMM instruction `gemm` subject to the configuration -// in `backend_config` and the passed buffers. +// in `gemm_config` and the passed buffers. // // `implements_whole_instruction` is used for the default profiler creation // if the `profiler` is not supplied. False value indicates that the created // profiler will not specifically profile the `gemm` instruction. // // If `algorithm` is provided, it overrides the one specified in -// `backend_config`. +// `gemm_config.backend_config`. Status RunGemm( - const HloInstruction* gemm, const GemmBackendConfig& backend_config, - se::DeviceMemoryBase lhs_buffer, se::DeviceMemoryBase rhs_buffer, - se::DeviceMemoryBase output_buffer, se::Stream* stream, - bool implements_whole_instruction, absl::optional profile_index, + const GpuGemmConfig& gemm_config, se::DeviceMemoryBase lhs_buffer, + se::DeviceMemoryBase rhs_buffer, se::DeviceMemoryBase output_buffer, + se::Stream* stream, bool implements_whole_instruction, + absl::optional profile_index, HloExecutionProfiler* profiler = nullptr, se::blas::ProfileResult* profile_result = nullptr, absl::optional algorithm = absl::nullopt); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 77fcf2c59f7..feedff0e0b3 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -59,7 +59,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" #include "tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h" #include "tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.h" -#include "tensorflow/compiler/xla/service/gpu/horizontal_fusion.h" +#include "tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h" +#include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" @@ -91,6 +92,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/logistic_expander.h" +#include "tensorflow/compiler/xla/service/qr_expander.h" +#include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h" #include "tensorflow/compiler/xla/service/rng_expander.h" #include "tensorflow/compiler/xla/service/slice_sinker.h" @@ -150,6 +153,8 @@ Status GpuCompiler::OptimizeHloModule( pipeline.AddPass(); pipeline.AddPass(); + // TODO(phawkins): replace QR decompositions with calls to cuSOLVER. + pipeline.AddPass(); pipeline.AddPass(); @@ -226,6 +231,7 @@ Status GpuCompiler::OptimizeHloModule( // pass.AddPass(); pass.AddPass(); + pass.AddPass(); pass.AddPass(); pass.AddPass(); } @@ -236,8 +242,7 @@ Status GpuCompiler::OptimizeHloModule( return IsMatrixMultiplication(dot) ? candidate_operands : TransposeFolding::OperandIndices{}; - }, - TransposeFolding::NeverFoldTranspose); + }); pipeline.AddPass(/*is_layout_sensitive=*/false); pipeline.AddPass(); @@ -301,12 +306,14 @@ Status GpuCompiler::OptimizeHloModule( TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status()); HloPassPipeline horizontal_fusion("horizontal_fusion"); - horizontal_fusion.AddPass(); + horizontal_fusion.AddPass(); + horizontal_fusion.AddPass(); horizontal_fusion.AddPass(/*is_layout_sensitive=*/true, /*only_fusion_computations=*/true); horizontal_fusion.AddPass(); TF_RETURN_IF_ERROR(horizontal_fusion.Run(hlo_module).status()); } + { HloPassPipeline pipeline("all_reduce_combiner"); pipeline.AddPass( @@ -483,7 +490,8 @@ static Status CompileModuleToLlvmIrImpl( int pointer_size, const HloProfileIndexMap* profile_index_map, std::unique_ptr* llvm_module, std::unique_ptr* buffer_assignment, - std::unique_ptr* thunk_schedule) { + std::unique_ptr* thunk_schedule, + std::vector* constants) { *llvm_module = absl::make_unique("", *llvm_context); (*llvm_module)->setTargetTriple(target_triple); @@ -516,7 +524,6 @@ static Status CompileModuleToLlvmIrImpl( DumpHloModuleIfEnabled(*hlo_module, **buffer_assignment, "after_optimizations"); - mlir::registerAllDialects(); mlir::MLIRContext mlir_context; IrEmitterContext ir_emitter_context( @@ -531,8 +538,6 @@ static Status CompileModuleToLlvmIrImpl( IrEmitterUnnested::Create(hlo_module->config(), entry_computation, &ir_emitter_context)); - TF_RETURN_IF_ERROR(ir_emitter->EmitConstantGlobals()); - { XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - IR emission"); @@ -581,6 +586,10 @@ static Status CompileModuleToLlvmIrImpl( *thunk_schedule = absl::make_unique( std::make_unique(std::move(thunk_sequence)), std::move(stream_assignment), std::move(thunk_to_hlo)); + + if (constants) { + *constants = std::move(ir_emitter_context.constants()); + } } return Status::OK(); @@ 
-612,6 +621,9 @@ StatusOr> GpuCompiler::RunBackend( stream_exec->GetDeviceDescription().threads_per_warp(); gpu_device_info.shared_memory_per_block = stream_exec->GetDeviceDescription().shared_memory_per_block(); + gpu_device_info.threads_per_core_limit = + stream_exec->GetDeviceDescription().threads_per_core_limit(); + gpu_device_info.core_count = stream_exec->GetDeviceDescription().core_count(); absl::optional cuda_compute_capability = [&]() -> absl::optional { @@ -646,12 +658,13 @@ StatusOr> GpuCompiler::RunBackend( std::unique_ptr llvm_module; std::unique_ptr buffer_assignment; std::unique_ptr thunk_schedule; + std::vector constants; TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl( module.get(), &llvm_context, target_triple_, data_layout_, stream_exec->platform()->Name(), gpu_device_info, cuda_compute_capability, GetCanShareBuffer(), pointer_size_, profile_index_map.get(), &llvm_module, - &buffer_assignment, &thunk_schedule)); + &buffer_assignment, &thunk_schedule, &constants)); if (user_pre_optimization_hook_) { user_pre_optimization_hook_(*llvm_module); @@ -697,7 +710,7 @@ StatusOr> GpuCompiler::RunBackend( backend_result.first, backend_result.second, gpu_version, std::move(thunk_schedule), std::move(module), std::move(buffer_assignment), std::move(profile_printer), - std::move(profile_index_map)); + std::move(profile_index_map), std::move(constants)); if (embed_ir_in_executable) { DCHECK_NE("", ir_module_string_before_opt); gpu_executable->set_ir_module_string(ir_module_string_before_opt); @@ -731,7 +744,7 @@ StatusOr> CompileModuleToLlvmIr( hlo_module, llvm_context, target_triple, data_layout, platform_name, gpu_device_info, cuda_compute_capability, DummyCanShareBufferFunction, pointer_size, /*profile_index_map=*/nullptr, &llvm_module, - &buffer_assignment, &thunk_schedule)); + &buffer_assignment, &thunk_schedule, nullptr)); return llvm_module; } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc index 8fb741323f3..925caadbb97 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc @@ -122,24 +122,28 @@ std::vector GetAlgorithms(CudnnConvKind kind, } StatusOr> GetMIOpenAlgorithms( - const HloCustomCallInstruction* conv, + const HloCustomCallInstruction* instr, absl::Span operand_buffers, se::DeviceMemoryBase result_buffer, se::StreamExecutor* stream_exec, ScratchAllocator* scratch_allocator, se::Stream* stream) { std::vector algorithms; - TF_ASSIGN_OR_RETURN(se::dnn::ConvolutionKind kind, - GetDnnConvolutionKind(conv)); + TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(instr)); - TF_ASSIGN_OR_RETURN(se::dnn::DataType dtype, GetDnnDataType(conv)); + TF_ASSIGN_OR_RETURN(se::dnn::ConvolutionKind kind, + GetDNNConvKindFromCudnnConvKind(config.kind)); + + TF_ASSIGN_OR_RETURN(se::dnn::DataType dtype, + GetDNNDataTypeFromPrimitiveType(config.output_type)); TF_ASSIGN_OR_RETURN(GpuConvParams params, - GetGpuConvParams(conv, operand_buffers, result_buffer)); + GetGpuConvParams(config, operand_buffers, result_buffer)); bool succ = stream_exec->GetMIOpenConvolveAlgorithms( - kind, dtype, stream, params.input_descriptor, params.input_buf, - params.filter_descriptor, params.filter_buf, params.output_descriptor, - params.output_buf, params.conv_desc, scratch_allocator, &algorithms); + kind, dtype, stream, params.config.input_descriptor, params.input_buf, + 
params.config.filter_descriptor, params.filter_buf, + params.config.output_descriptor, params.output_buf, + params.config.conv_desc, scratch_allocator, &algorithms); DCHECK(succ); return algorithms; @@ -442,6 +446,8 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( GetComputeCapability(stream_exec_), GetCudnnVersion(stream_exec_), blas_version, canonical_hlo); + TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(instr)); + for (const AlgorithmDesc& alg : GetAlgorithms(kind, stream_exec_)) { XLA_SCOPED_LOGGING_TIMER_LEVEL( absl::StrCat("CudnnConvAlgorithmPicker::PickBestAlgorithm algo ", @@ -465,7 +471,7 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( options.profile_result = &profile_result; options.algo_override = alg; Status launch_status = - RunGpuConv(instr, absl::MakeSpan(operand_buffers), result_buffer, + RunGpuConv(config, absl::MakeSpan(operand_buffers), result_buffer, &scratch_allocator, stream, options); if (!launch_status.ok()) { @@ -700,6 +706,7 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm( *result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto( absl::Milliseconds(profile_result.elapsed_time_in_ms())); } else { + TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(instr)); for (const auto& miopen_alg : algorithms) { const auto& alg = miopen_alg.algorithm(); XLA_SCOPED_LOGGING_TIMER_LEVEL( @@ -717,7 +724,7 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm( options.algo_override = alg; options.scratch_size_override = miopen_alg.scratch_size(); Status launch_status = - RunGpuConv(instr, absl::MakeSpan(operand_buffers), result_buffer, + RunGpuConv(config, absl::MakeSpan(operand_buffers), result_buffer, &scratch_allocator, stream, options); if (!launch_status.ok()) { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization.cc index 5fa102ac785..94f9a96c0fe 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization.cc @@ -313,7 +313,11 @@ bool GpuConvPaddingLegalization::CanonicalizeBackwardInputConvolution( new_backward_conv_window.mutable_dimensions(i)); } // Decreasing the padding by X *increases* the size of our output by X. - int64 dim = backward_conv_dnums.output_spatial_dimensions(i); + // Note that we have swapped input spatial dimensions with output spatial + // dimensions to be compatible with the cuDNN API, so + // input_spatial_dimensions(i) gives the i-th spatial dimension of the + // output. + int64 dim = backward_conv_dnums.input_spatial_dimensions(i); new_backward_conv_shape.set_dimensions( dim, new_backward_conv_shape.dimensions(dim) + std::abs(padding_low - padding_high)); @@ -353,7 +357,11 @@ bool GpuConvPaddingLegalization::CanonicalizeBackwardInputConvolution( for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) { int64 padding_low = backward_conv->window().dimensions(i).padding_low(); int64 padding_high = backward_conv->window().dimensions(i).padding_high(); - int64 dim = backward_conv_dnums.output_spatial_dimensions(i); + // Note that we have swapped input spatial dimensions with output spatial + // dimensions to be compatible with the cuDNN API, so + // input_spatial_dimensions(i) gives the i-th spatial dimension of the + // output. 
+ int64 dim = backward_conv_dnums.input_spatial_dimensions(i); if (padding_low > padding_high) { // If the amount of low padding (of the old backward convolution) is // larger, we internally pad the low end of the activations and slice diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization_test.cc new file mode 100644 index 00000000000..c214486e18f --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization_test.cc @@ -0,0 +1,97 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization.h" + +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +namespace op = xla::testing::opcode_matchers; +using ::testing::_; + +using GpuConvPaddingLegalizationTest = HloTestBase; + +TEST_F(GpuConvPaddingLegalizationTest, BackwardInputConvolve) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule convolution_module +ENTRY %convolution (operand f64[2,2,2,3]{3,2,1,0}) -> (f64[2,2,4,4]{3,2,1,0}, u8[0]) { + %operand = f64[2,2,2,3]{3,2,1,0} parameter(0) + %kernel = f64[2,3,2,3]{3,2,1,0} constant( + { + { /*i0=0*/ + { /*i1=0*/ + { 0.29629629629629628, 0.30246913580246915, 0.30864197530864196 }, + { 0.31481481481481483, 0.32098765432098764, 0.3271604938271605 } + }, + { /*i1=1*/ + { 0.25925925925925924, 0.26543209876543211, 0.27160493827160492 }, + { 0.27777777777777779, 0.2839506172839506, 0.29012345679012347 } + }, + { /*i1=2*/ + { 0.22222222222222221, 0.22839506172839505, 0.23456790123456789 }, + { 0.24074074074074073, 0.24691358024691357, 0.25308641975308643 } + } + }, + { /*i0=1*/ + { /*i1=0*/ + { 0.18518518518518517, 0.19135802469135801, 0.19753086419753085 }, + { 0.20370370370370369, 0.20987654320987653, 0.21604938271604937 } + }, + { /*i1=1*/ + { 0.14814814814814814, 0.15432098765432098, 0.16049382716049382 }, + { 0.16666666666666666, 0.1728395061728395, 0.17901234567901234 } + }, + { /*i2=2*/ + { 0.1111111111111111, 0.11728395061728394, 0.12345679012345678 }, + { 0.12962962962962962, 0.13580246913580246, 0.1419753086419753 } + } + } + }) + %reverse = f64[2,3,2,3]{3,2,1,0} reverse(%kernel), dimensions={0,1} + ROOT %custom-call = (f64[2,2,4,4]{3,2,1,0}, u8[0]{0}) custom-call(f64[2,2,2,3]{3,2,1,0} %operand, f64[2,3,2,3]{3,2,1,0} %reverse), 
window={size=2x3 stride=2x2 pad=0_0x0_1}, dim_labels=bf01_01io->b01f, custom_call_target="__cudnn$convBackwardInput", backend_config="{\"algorithm\":\"0\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}" +} + )") + .ValueOrDie(); + ASSERT_TRUE(GpuConvPaddingLegalization().Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::Tuple(op::Slice(op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardInputCallTarget, _, + op::Reverse(op::Constant())), + 0)), + op::GetTupleElement())); + auto slice = root->operand(0); + Shape expected_slice_shape = ShapeUtil::MakeShape(F64, {2, 2, 4, 4}); + EXPECT_TRUE(ShapeUtil::Equal(slice->shape(), expected_slice_shape)); + auto conv = slice->operand(0); + Shape expected_conv_shape = ShapeUtil::MakeShape(F64, {2, 2, 4, 5}); + EXPECT_TRUE(ShapeUtil::Equal(conv->shape(), expected_conv_shape)); +} + +} // anonymous namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc index 5cc5fa7d16d..e0ccbad3a01 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc @@ -79,16 +79,16 @@ Status RunGpuConvForward(GpuConvParams params, DeviceMemory filter_buf, DeviceMemory output_buf, AlgorithmConfig algorithm) { - if (params.conv_result_scale != 1) { + if (params.config.conv_result_scale != 1) { return InternalError( "StreamExecutor doesn't support scaled convolution: %lf.", - params.conv_result_scale); + params.config.conv_result_scale); } - stream->ThenConvolveWithAlgorithm( - params.input_descriptor, input_buf, params.filter_descriptor, filter_buf, - params.conv_desc, params.output_descriptor, &output_buf, - scratch_allocator, algorithm, options.profile_result); - return Status::OK(); + return stream->ConvolveWithAlgorithm( + params.config.input_descriptor, input_buf, + params.config.filter_descriptor, filter_buf, params.config.conv_desc, + params.config.output_descriptor, &output_buf, scratch_allocator, + algorithm, options.profile_result); } template @@ -103,13 +103,14 @@ Status RunGpuConvForwardActivation(GpuConvParams params, bias_desc.set_count(1) .set_height(1) .set_width(1) - .set_feature_map_count(params.output_descriptor.feature_map_count()) - .set_layout(params.output_descriptor.layout()); + .set_feature_map_count( + params.config.output_descriptor.feature_map_count()) + .set_layout(params.config.output_descriptor.layout()); se::DeviceMemory side_input(params.fusion->side_input_buf); // If there is no side input, use output as the side input. 
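The fused-convolution path below tolerates a missing side input: when no side-input buffer was provided, the side-input scale must be zero, and the output buffer is reused as the side input so the fused kernel still receives a valid operand. A self-contained sketch of that fallback (the buffer type is a placeholder; the real code uses se::DeviceMemory):

#include <stdexcept>

// Hypothetical buffer handle.
struct BufferSketch {
  void* ptr = nullptr;
  bool is_null() const { return ptr == nullptr; }
};

// Returns the side input to hand to the fused kernel. With no side-input
// buffer, the scale must be 0 and the output buffer is aliased instead,
// since scale * side_input then contributes nothing to the result.
BufferSketch ResolveSideInputSketch(BufferSketch side_input,
                                    BufferSketch output,
                                    double side_input_scale) {
  if (side_input.is_null()) {
    if (side_input_scale != 0.0) {
      throw std::runtime_error(
          "side input scale is nonzero but no side input buffer was given");
    }
    return output;
  }
  return side_input;
}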
if (side_input.is_null()) { - if (params.fusion->side_input_scale != 0) { + if (params.config.fusion->side_input_scale != 0) { return InternalError( "Side input scale is not 0, yet no side input buffer is " "provided"); @@ -123,15 +124,14 @@ Status RunGpuConvForwardActivation(GpuConvParams params, side_input = output_buf; } - stream->ThenFusedConvolveWithAlgorithm( - params.input_descriptor, input_buf, params.conv_result_scale, - params.filter_descriptor, filter_buf, params.conv_desc, side_input, - params.fusion->side_input_scale, bias_desc, - DeviceMemory(params.fusion->bias_buf), params.fusion->mode, - params.output_descriptor, &output_buf, scratch_allocator, algorithm, - options.profile_result); - - return Status::OK(); + return stream->FusedConvolveWithAlgorithm( + params.config.input_descriptor, input_buf, + params.config.conv_result_scale, params.config.filter_descriptor, + filter_buf, params.config.conv_desc, side_input, + params.config.fusion->side_input_scale, bias_desc, + DeviceMemory(params.fusion->bias_buf), + params.config.fusion->mode, params.config.output_descriptor, &output_buf, + scratch_allocator, algorithm, options.profile_result); } // StreamExecutor supports various data types via overloading, and the support @@ -152,31 +152,33 @@ Status RunGpuConvInternalImpl(GpuConvParams params, DeviceMemory filter_buf, DeviceMemory output_buf, AlgorithmConfig algorithm) { - switch (params.kind) { + switch (params.config.kind) { case CudnnConvKind::kForward: return RunGpuConvForward(params, scratch_allocator, stream, options, input_buf, filter_buf, output_buf, algorithm); case CudnnConvKind::kBackwardInput: - if (params.conv_result_scale != 1) { + if (params.config.conv_result_scale != 1) { return InternalError( "StreamExecutor doesn't support scaled convolution: %lf.", - params.conv_result_scale); + params.config.conv_result_scale); } - stream->ThenConvolveBackwardDataWithAlgorithm( - params.filter_descriptor, filter_buf, params.output_descriptor, - output_buf, params.conv_desc, params.input_descriptor, &input_buf, - scratch_allocator, algorithm, options.profile_result); + return stream->ConvolveBackwardDataWithAlgorithm( + params.config.filter_descriptor, filter_buf, + params.config.output_descriptor, output_buf, params.config.conv_desc, + params.config.input_descriptor, &input_buf, scratch_allocator, + algorithm, options.profile_result); break; case CudnnConvKind::kBackwardFilter: - if (params.conv_result_scale != 1) { + if (params.config.conv_result_scale != 1) { return InternalError( "StreamExecutor doesn't support scaled convolution: %lf.", - params.conv_result_scale); + params.config.conv_result_scale); } - stream->ThenConvolveBackwardFilterWithAlgorithm( - params.input_descriptor, input_buf, params.output_descriptor, - output_buf, params.conv_desc, params.filter_descriptor, &filter_buf, - scratch_allocator, algorithm, options.profile_result); + return stream->ConvolveBackwardFilterWithAlgorithm( + params.config.input_descriptor, input_buf, + params.config.output_descriptor, output_buf, params.config.conv_desc, + params.config.filter_descriptor, &filter_buf, scratch_allocator, + algorithm, options.profile_result); break; case CudnnConvKind::kForwardActivation: { return RunGpuConvForwardActivation( @@ -198,7 +200,7 @@ Status RunGpuConvInternalImpl(GpuConvParams params, DeviceMemory filter_buf, DeviceMemory output_buf, AlgorithmConfig algorithm) { - switch (params.kind) { + switch (params.config.kind) { case CudnnConvKind::kForward: return RunGpuConvForward(params, 
scratch_allocator, stream, options, input_buf, filter_buf, output_buf, algorithm); @@ -221,7 +223,7 @@ Status RunGpuConvImpl(const GpuConvParams& params, auto input_buf = se::DeviceMemory(params.input_buf); auto filter_buf = se::DeviceMemory(params.filter_buf); auto output_buf = se::DeviceMemory(params.output_buf); - AlgorithmConfig algorithm = params.algorithm; + AlgorithmConfig algorithm = params.config.algorithm; if (options.algo_override.has_value()) { algorithm = AlgorithmConfig(*options.algo_override); @@ -241,7 +243,8 @@ Status RunGpuConvImpl(const GpuConvParams& params, if (!stream->ok()) { return InternalError( "Unable to launch convolution with type %s and algorithm (%d, %s)", - CudnnConvKindToString(params.kind), algorithm.algorithm()->algo_id(), + CudnnConvKindToString(params.config.kind), + algorithm.algorithm()->algo_id(), algorithm.algorithm_no_scratch().has_value() ? absl::StrCat(algorithm.algorithm_no_scratch()->algo_id()) : "none"); @@ -251,95 +254,83 @@ Status RunGpuConvImpl(const GpuConvParams& params, } // anonymous namespace -StatusOr GetGpuConvParams( - const HloCustomCallInstruction* conv, - absl::Span operand_buffers, - se::DeviceMemoryBase result_buffer) { - GpuConvParams params; +StatusOr GetGpuConvConfig( + const HloCustomCallInstruction* cudnn_call) { + GpuConvConfig config; + + config.input_type = cudnn_call->operand(0)->shape().element_type(); + config.output_type = cudnn_call->shape().tuple_shapes(0).element_type(); TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, - conv->backend_config()); - TF_ASSIGN_OR_RETURN(params.kind, GetCudnnConvKind(conv)); - const Shape* input_shape; - const Shape* filter_shape; - const Shape* output_shape; + cudnn_call->backend_config()); + TF_ASSIGN_OR_RETURN(config.kind, GetCudnnConvKind(cudnn_call)); // The third field is scratch size stored from conv_algorithm_picker // The operand is added to the shape field of the conv instruction // in GpuConvAlgorithmPicker::RunOnInstruction() call. 
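The rewrite in this hunk splits the old GetGpuConvParams into two stages: GetGpuConvConfig derives everything knowable from the custom-call instruction itself (kind, shapes, descriptors, algorithm, scales), and GetGpuConvParams later pairs that config with the runtime buffer addresses for a particular launch. A hypothetical outline of the two-stage shape (types and fields are stand-ins, not the real declarations):

#include <cstdint>

// Compile-time view: filled from the HLO custom call, no device pointers.
struct ConvConfigSketch {
  int kind = 0;            // forward / backward-input / backward-filter / ...
  int64_t algorithm = 0;   // chosen by the algorithm picker
  double result_scale = 1.0;
};

// Runtime view: the config plus the buffers for one launch.
struct ConvParamsSketch {
  ConvConfigSketch config;
  void* input_buf;
  void* filter_buf;
  void* output_buf;
};

ConvParamsSketch BindBuffersSketch(const ConvConfigSketch& config, void* input,
                                   void* filter, void* output) {
  // Which operand maps to which cuDNN tensor depends on config.kind in the
  // real code; the sketch keeps only the forward-convolution mapping.
  return ConvParamsSketch{config, input, filter, output};
}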
- params.algorithm = se::dnn::AlgorithmConfig( + config.algorithm = se::dnn::AlgorithmConfig( se::dnn::AlgorithmDesc(backend_config.algorithm(), backend_config.tensor_ops_enabled()), - conv->shape().tuple_shapes(1).dimensions(0)); - params.conv_result_scale = backend_config.conv_result_scale(); + cudnn_call->shape().tuple_shapes(1).dimensions(0)); + config.conv_result_scale = backend_config.conv_result_scale(); - switch (params.kind) { + Shape operand0_shape = cudnn_call->operand(0)->shape(); + Shape operand1_shape = cudnn_call->operand(1)->shape(); + Shape result_shape = cudnn_call->shape().tuple_shapes(0); + + switch (config.kind) { case CudnnConvKind::kForward: - input_shape = &conv->operand(0)->shape(); - filter_shape = &conv->operand(1)->shape(); - output_shape = &conv->shape().tuple_shapes(0); - params.input_buf = operand_buffers[0]; - params.filter_buf = operand_buffers[1]; - params.output_buf = result_buffer; + case CudnnConvKind::kForwardActivation: + config.input_shape = operand0_shape; + config.filter_shape = operand1_shape; + config.output_shape = result_shape; break; case CudnnConvKind::kBackwardInput: - input_shape = &conv->shape().tuple_shapes(0); - filter_shape = &conv->operand(1)->shape(); - output_shape = &conv->operand(0)->shape(); - params.input_buf = result_buffer; - params.filter_buf = operand_buffers[1]; - params.output_buf = operand_buffers[0]; + config.input_shape = result_shape; + config.filter_shape = operand1_shape; + config.output_shape = operand0_shape; break; case CudnnConvKind::kBackwardFilter: - input_shape = &conv->operand(0)->shape(); - filter_shape = &conv->shape().tuple_shapes(0); - output_shape = &conv->operand(1)->shape(); - params.input_buf = operand_buffers[0]; - params.filter_buf = result_buffer; - params.output_buf = operand_buffers[1]; + config.input_shape = operand0_shape; + config.filter_shape = result_shape; + config.output_shape = operand1_shape; break; - case CudnnConvKind::kForwardActivation: { - input_shape = &conv->operand(0)->shape(); - filter_shape = &conv->operand(1)->shape(); - output_shape = &conv->shape().tuple_shapes(0); - params.fusion.emplace(); - GpuConvParams::FusionParams& fusion = *params.fusion; - if (!se::dnn::ActivationMode_IsValid(backend_config.activation_mode())) { - return InternalError("Bad activation mode: %s", - backend_config.ShortDebugString()); - } - fusion.mode = static_cast( - backend_config.activation_mode()); - fusion.side_input_scale = backend_config.side_input_scale(); - params.input_buf = operand_buffers[0]; - params.filter_buf = operand_buffers[1]; - params.output_buf = result_buffer; - params.fusion->bias_buf = operand_buffers[2]; - if (operand_buffers.size() >= 4) { - params.fusion->side_input_buf = operand_buffers[3]; - } - } + default: + return InternalError("Unknown convolution kind"); } - const Window& window = conv->window(); + if (config.kind == CudnnConvKind::kForwardActivation) { + config.fusion.emplace(); + GpuConvConfig::FusionConfig& fusion = *config.fusion; + if (!se::dnn::ActivationMode_IsValid(backend_config.activation_mode())) { + return InternalError("Bad activation mode: %s", + backend_config.ShortDebugString()); + } + fusion.mode = + static_cast(backend_config.activation_mode()); + fusion.side_input_scale = backend_config.side_input_scale(); + } + + const Window& window = cudnn_call->window(); const ConvolutionDimensionNumbers& dnums = - conv->convolution_dimension_numbers(); + cudnn_call->convolution_dimension_numbers(); VLOG(3) << "Convolution Algorithm: " - << 
params.algorithm.algorithm()->algo_id(); + << config.algorithm.algorithm()->algo_id(); VLOG(3) << "tensor_ops_enabled: " - << params.algorithm.algorithm()->tensor_ops_enabled(); - VLOG(3) << "Convolution kind: " << CudnnConvKindToString(params.kind); - VLOG(3) << "input shape: " << ShapeUtil::HumanStringWithLayout(*input_shape); + << config.algorithm.algorithm()->tensor_ops_enabled(); + VLOG(3) << "Convolution kind: " << CudnnConvKindToString(config.kind); + VLOG(3) << "input shape: " + << ShapeUtil::HumanStringWithLayout(config.input_shape); VLOG(3) << "filter shape: " - << ShapeUtil::HumanStringWithLayout(*filter_shape); + << ShapeUtil::HumanStringWithLayout(config.filter_shape); VLOG(3) << "Output shape: " - << ShapeUtil::HumanStringWithLayout(*output_shape); + << ShapeUtil::HumanStringWithLayout(config.output_shape); VLOG(3) << "Window: { " << window.ShortDebugString() << " }"; VLOG(3) << "Dim nums: { " << dnums.ShortDebugString() << " }"; const int num_dimensions = window.dimensions_size(); - CHECK_LE(num_dimensions, 3) << conv->ToString(); + CHECK_LE(num_dimensions, 3) << cudnn_call->ToString(); // cuDNN does not support 1D convolutions. We therefore express 1D // convolutions as 2D convolutions where the first spatial dimension is 1. @@ -353,18 +344,18 @@ StatusOr GetGpuConvParams( window.dimensions_size() > 0 && window.dimensions()[0].window_reversal(); CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size()) - << conv->ToString(); + << cudnn_call->ToString(); CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size()) - << conv->ToString(); + << cudnn_call->ToString(); CHECK_EQ(num_dimensions, dnums.output_spatial_dimensions_size()) - << conv->ToString(); + << cudnn_call->ToString(); for (const WindowDimension& dim : window.dimensions()) { - CHECK_EQ(dims_reversed, dim.window_reversal()) << conv->ToString(); - CHECK_EQ(dim.padding_low(), dim.padding_high()) << conv->ToString(); + CHECK_EQ(dims_reversed, dim.window_reversal()) << cudnn_call->ToString(); + CHECK_EQ(dim.padding_low(), dim.padding_high()) << cudnn_call->ToString(); CHECK_EQ(dim.base_dilation(), 1) << "cudnn does not support base dilation; it " "must be made explicit with a kPad: " - << conv->ToString(); + << cudnn_call->ToString(); } // cuDNN's convolution APIs support the BDYX layout for activations/output and @@ -373,12 +364,16 @@ StatusOr GetGpuConvParams( FilterLayout filter_dl; DataLayout output_dl; + const Shape* input_shape = &config.input_shape; + const Shape* filter_shape = &config.filter_shape; + const Shape* output_shape = &config.output_shape; + TF_ASSIGN_OR_RETURN(std::tie(input_dl, filter_dl, output_dl), XlaConvLayoutsToStreamExecutorLayouts( dnums, input_shape->layout(), filter_shape->layout(), output_shape->layout())); - BatchDescriptor& input_descriptor = params.input_descriptor; + BatchDescriptor& input_descriptor = config.input_descriptor; input_descriptor = BatchDescriptor(effective_num_dimensions); input_descriptor.set_layout(input_dl) .set_feature_map_count( @@ -391,7 +386,7 @@ StatusOr GetGpuConvParams( input_shape->dimensions(dnums.input_spatial_dimensions(dim))); } - FilterDescriptor& filter_descriptor = params.filter_descriptor; + FilterDescriptor& filter_descriptor = config.filter_descriptor; filter_descriptor = FilterDescriptor(effective_num_dimensions); filter_descriptor.set_layout(filter_dl) .set_input_feature_map_count( @@ -404,11 +399,11 @@ StatusOr GetGpuConvParams( filter_shape->dimensions(dnums.kernel_spatial_dimensions(dim))); } - params.conv_desc = 
ConvolutionDescriptor(effective_num_dimensions); - params.conv_desc.set_group_count(conv->feature_group_count()); - params.conv_desc.set_convolution_not_crosscorr(dims_reversed); + config.conv_desc = ConvolutionDescriptor(effective_num_dimensions); + config.conv_desc.set_group_count(cudnn_call->feature_group_count()); + config.conv_desc.set_convolution_not_crosscorr(dims_reversed); for (int dim = 0; dim < num_dimensions; ++dim) { - params.conv_desc + config.conv_desc .set_zero_padding( static_cast(effective_num_dimensions - dim - 1), window.dimensions(dim).padding_low()) @@ -420,7 +415,7 @@ StatusOr GetGpuConvParams( window.dimensions(dim).window_dilation()); } - BatchDescriptor& output_descriptor = params.output_descriptor; + BatchDescriptor& output_descriptor = config.output_descriptor; output_descriptor = BatchDescriptor(effective_num_dimensions); output_descriptor.set_layout(output_dl) .set_feature_map_count( @@ -437,32 +432,70 @@ StatusOr GetGpuConvParams( input_descriptor.set_spatial_dim(static_cast(dim), 1); output_descriptor.set_spatial_dim(static_cast(dim), 1); filter_descriptor.set_spatial_dim(static_cast(dim), 1); - params.conv_desc.set_zero_padding(static_cast(dim), 0) + config.conv_desc.set_zero_padding(static_cast(dim), 0) .set_filter_stride(static_cast(dim), 1); } + return config; +} + +StatusOr GetGpuConvParams( + const GpuConvConfig& config, + absl::Span operand_buffers, + se::DeviceMemoryBase result_buffer) { + GpuConvParams params; + params.config = config; + + switch (config.kind) { + case CudnnConvKind::kForward: + case CudnnConvKind::kForwardActivation: + params.input_buf = operand_buffers[0]; + params.filter_buf = operand_buffers[1]; + params.output_buf = result_buffer; + break; + case CudnnConvKind::kBackwardInput: + params.input_buf = result_buffer; + params.filter_buf = operand_buffers[1]; + params.output_buf = operand_buffers[0]; + break; + case CudnnConvKind::kBackwardFilter: + params.input_buf = operand_buffers[0]; + params.filter_buf = result_buffer; + params.output_buf = operand_buffers[1]; + break; + } + + if (config.kind == CudnnConvKind::kForwardActivation) { + params.fusion.emplace(); + GpuConvParams::FusionParams& fusion = *params.fusion; + fusion.bias_buf = operand_buffers[2]; + if (operand_buffers.size() >= 4) { + fusion.side_input_buf = operand_buffers[3]; + } + } + return params; } -Status RunGpuConv(const HloCustomCallInstruction* conv, +Status RunGpuConv(const gpu::GpuConvConfig& config, absl::Span operand_buffers, se::DeviceMemoryBase result_buffer, se::DeviceMemoryBase scratch_buf, se::Stream* stream, RunConvOptions options) { ScratchBufAllocator scratch_allocator(scratch_buf); - return RunGpuConv(conv, operand_buffers, result_buffer, &scratch_allocator, + return RunGpuConv(config, operand_buffers, result_buffer, &scratch_allocator, stream, options); } -Status RunGpuConv(const HloCustomCallInstruction* conv, +Status RunGpuConv(const gpu::GpuConvConfig& config, absl::Span operand_buffers, se::DeviceMemoryBase result_buffer, se::ScratchAllocator* scratch_allocator, se::Stream* stream, RunConvOptions options) { TF_ASSIGN_OR_RETURN(GpuConvParams params, - GetGpuConvParams(conv, operand_buffers, result_buffer)); + GetGpuConvParams(config, operand_buffers, result_buffer)); - PrimitiveType input_primitive_type = conv->operand(0)->shape().element_type(); + PrimitiveType input_primitive_type = config.input_type; switch (input_primitive_type) { case F16: return RunGpuConvImpl( @@ -474,8 +507,7 @@ Status RunGpuConv(const HloCustomCallInstruction* conv, 
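// [Editorial sketch, not part of the diff] Minimal example of how the split API
// introduced above might be used: GetGpuConvConfig is called once while the HLO
// custom call is still available, and only the config plus device buffers are
// needed at execution time. Names are placeholders and error handling is
// elided.
Status RunConvExample(const HloCustomCallInstruction* custom_call,
                      absl::Span<se::DeviceMemoryBase> operand_buffers,
                      se::DeviceMemoryBase result_buffer,
                      se::DeviceMemoryBase scratch_buffer, se::Stream* stream) {
  // Derive the static description (shapes, layouts, algorithm, fusion) once.
  TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(custom_call));
  // Bind buffers and launch; RunGpuConv calls GetGpuConvParams internally.
  return RunGpuConv(config, operand_buffers, result_buffer, scratch_buffer,
                    stream);
}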
return RunGpuConvImpl(params, scratch_allocator, stream, options); case S8: { - PrimitiveType output_primitive_type = - conv->shape().tuple_shapes(0).element_type(); + PrimitiveType output_primitive_type = config.output_type; switch (output_primitive_type) { case F32: return RunGpuConvImpl(params, scratch_allocator, @@ -484,12 +516,11 @@ Status RunGpuConv(const HloCustomCallInstruction* conv, return RunGpuConvImpl(params, scratch_allocator, stream, options); default: - return Unimplemented("Unimplemented convolution %s", - conv->ToString()); + return Unimplemented("Unimplemented convolution"); } } default: - return Unimplemented("Unimplemented convolution %s", conv->ToString()); + return Unimplemented("Unimplemented convolution"); } } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h index 3b8ce0f0f1c..5d27e6d6da7 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONV_RUNNER_H_ #include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" @@ -25,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/stream_executor/dnn.h" namespace xla { namespace gpu { @@ -40,10 +42,10 @@ struct RunConvOptions { absl::optional scratch_size_override; }; -// Implementation struct exposed for debugging and log analysis. -struct GpuConvParams { - // Here are the fields related to cuDNN's fused convolution. The result thus - // is defined as: +// Structure to describe static properties of a GPU convolution. +struct GpuConvConfig { + // Field related to cuDNN's fused convolution are in FusionConfig & + // FusionParams structures. The result thus is defined as: // activation(conv_result_scale * conv(x, w) + // side_input_scale * side_input + broadcast(bias)) // @@ -54,23 +56,39 @@ struct GpuConvParams { // added to the final results. // // side_input_buf, if valid, must have the same shape as the output buffer. - struct FusionParams { + struct FusionConfig { se::dnn::ActivationMode mode; double side_input_scale; + }; + + PrimitiveType input_type; + PrimitiveType output_type; + CudnnConvKind kind; + se::dnn::AlgorithmConfig algorithm; + double conv_result_scale; + + se::dnn::BatchDescriptor input_descriptor; + se::dnn::FilterDescriptor filter_descriptor; + se::dnn::BatchDescriptor output_descriptor; + se::dnn::ConvolutionDescriptor conv_desc; + + Shape input_shape; + Shape filter_shape; + Shape output_shape; + absl::optional fusion; +}; + +// Implementation struct exposed for debugging and log analysis. 
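// [Editorial sketch, not part of the diff] For a conv + bias + ReLU fusion with
// no side input, the static fields of GpuConvConfig would be populated roughly
// as below (the values and the kRelu enum name are illustrative assumptions):
GpuConvConfig fused_example;
fused_example.kind = CudnnConvKind::kForwardActivation;
fused_example.conv_result_scale = 1.0;
fused_example.fusion.emplace();
fused_example.fusion->mode = se::dnn::ActivationMode::kRelu;  // assumed name
fused_example.fusion->side_input_scale = 0.0;  // side input not used
// The runner then computes relu(1.0 * conv(x, w) + 0.0 * side_input + bias),
// matching the formula in the comment above; the bias and side-input buffers
// themselves live in GpuConvParams::FusionParams, bound at execution time.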
+struct GpuConvParams { + GpuConvConfig config; + struct FusionParams { se::DeviceMemoryBase bias_buf; se::DeviceMemoryBase side_input_buf; // nullable }; - CudnnConvKind kind; - se::dnn::BatchDescriptor input_descriptor; - se::dnn::FilterDescriptor filter_descriptor; - se::dnn::BatchDescriptor output_descriptor; se::DeviceMemoryBase input_buf; se::DeviceMemoryBase filter_buf; se::DeviceMemoryBase output_buf; - se::dnn::ConvolutionDescriptor conv_desc; - se::dnn::AlgorithmConfig algorithm; - double conv_result_scale; absl::optional fusion; }; @@ -89,21 +107,24 @@ struct GpuConvParams { // allocator and take note of how much memory is used. The next time you call // the same conv, you can provide an explicitly preallocated scratch buffer of // that size, if you like. -Status RunGpuConv(const HloCustomCallInstruction* conv, +Status RunGpuConv(const GpuConvConfig& conv_config, absl::Span operand_buffers, se::DeviceMemoryBase result_buffer, se::DeviceMemoryBase scratch_buf, se::Stream* stream, RunConvOptions = {}); -Status RunGpuConv(const HloCustomCallInstruction* conv, +Status RunGpuConv(const GpuConvConfig& conv_config, absl::Span operand_buffers, se::DeviceMemoryBase result_buffer, se::ScratchAllocator* scratch_allocator, se::Stream* stream, RunConvOptions = {}); +StatusOr GetGpuConvConfig( + const HloCustomCallInstruction* cudnn_call); + // Implementation details exposed for debugging and log analysis. StatusOr GetGpuConvParams( - const HloCustomCallInstruction* conv, + const GpuConvConfig& conv_config, absl::Span operand_buffers, se::DeviceMemoryBase result_buffer); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_device_info.h b/tensorflow/compiler/xla/service/gpu/gpu_device_info.h index 7352bad1a66..afb773c4527 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_device_info.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_device_info.h @@ -32,6 +32,8 @@ struct GpuDeviceInfo { int threads_per_block_limit; int threads_per_warp; int shared_memory_per_block; + int threads_per_core_limit; + int core_count; }; } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 726f1963545..1a0d1e0beb6 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -60,14 +60,16 @@ GpuExecutable::GpuExecutable( std::shared_ptr hlo_module, std::shared_ptr assignment, std::unique_ptr hlo_profile_printer_data, - std::unique_ptr hlo_profile_index_map) + std::unique_ptr hlo_profile_index_map, + std::vector globals) : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map)), text_(text), binary_(binary), gpu_version_(gpu_version), thunk_schedule_(std::move(thunk_schedule)), - assignment_(std::move(assignment)) { + assignment_(std::move(assignment)), + constants_(std::move(globals)) { CHECK(has_module() && assignment_); GpuDebugInfoManager::Get()->RegisterModule(module().name(), shared_module(), assignment_); @@ -280,28 +282,23 @@ GpuExecutable::ResolveConstantGlobals(se::Stream* stream) { se::ModuleHandle module_handle; TF_RETURN_IF_ERROR(executor->LoadModule(module_spec, &module_handle)); - for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); - ++i) { - const BufferAllocation& allocation = assignment_->GetAllocation(i); - if (allocation.is_constant()) { - TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase global, - executor->GetUntypedSymbol( - 
llvm_ir::ConstantBufferAllocationToGlobalName(allocation), - module_handle)); - VLOG(3) << "Resolved global " - << llvm_ir::ConstantBufferAllocationToGlobalName(allocation) - << " to " << global.opaque(); - InsertOrDie(&globals, i, global); + for (const auto& info : constants_) { + const Literal& literal = info.content; - const Literal& literal = - llvm_ir::LiteralForConstantAllocation(allocation); - CHECK(literal.shape().IsArray()); - if (!ShouldEmitLiteralInLlvmIr(literal)) { - VLOG(3) << "H2D memcpy for constant with shape " - << ShapeUtil::HumanString(literal.shape()); - stream->ThenMemcpy(&global, literal.untyped_data(), allocation.size()); - } + TF_ASSIGN_OR_RETURN(auto global, executor->GetUntypedSymbol( + info.symbol_name, module_handle)); + VLOG(3) << "Resolved global " << info.symbol_name << " to " + << global.opaque(); + + CHECK(literal.shape().IsArray()); + if (!ShouldEmitLiteralInLlvmIr(literal)) { + VLOG(3) << "H2D memcpy for constant with shape " + << ShapeUtil::HumanString(literal.shape()); + stream->ThenMemcpy(&global, literal.untyped_data(), literal.size_bytes()); + } + + if (info.allocation_index != -1) { + InsertOrDie(&globals, info.allocation_index, global); } } @@ -334,7 +331,11 @@ StatusOr GpuExecutable::BufferForAllocation( } return registered_buffer; } else if (allocation.is_constant()) { - return FindOrDie(*globals, arg_idx); + auto it = globals->find(arg_idx); + if (it == globals->end()) { + return se::DeviceMemoryBase(); + } + return it->second; } else { // Allocate each allocation that might escape, or is the temp buffer. CHECK(allocation.maybe_live_out() || allocation.IsPreallocatedTempBuffer()); @@ -449,8 +450,7 @@ StatusOr GpuExecutable::ExecuteAsyncOnStream( HloInstruction* root = hlo_module_->entry_computation()->root_instruction(); const Shape& root_shape = root->shape(); auto device_ordinal = executor->device_ordinal(); - ExecutionOutput result(/*on_host_shape=*/root->shape(), - /*on_device_shape=*/root->shape(), memory_allocator, + ExecutionOutput result(/*on_device_shape=*/root->shape(), memory_allocator, device_ordinal); TF_ASSIGN_OR_RETURN(BufferAllocations buffer_allocations, diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index 516fa9b269a..613880fd44b 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -49,6 +49,12 @@ namespace gpu { // This is an immutable data type after initialization, and thus thread safe. class GpuExecutable : public Executable { public: + struct ConstantInfo { + std::string symbol_name; + xla::Literal content; + int allocation_index = -1; + }; + // We need to share ownership of hlo_module and assignment with profiler to // safely keep a reference to these objects during tracing period, thus they // are passed as shared pointers. 
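// [Editorial sketch, not part of the diff] Each constant emitted as an LLVM
// global is expected to be described by one ConstantInfo entry; a hypothetical
// example of how the compiler side might populate the vector passed to the
// constructor:
std::vector<GpuExecutable::ConstantInfo> constants;
GpuExecutable::ConstantInfo info;
info.symbol_name = "buffer_for_constant_7";  // name of the emitted global
info.content = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
info.allocation_index = 7;  // -1 when no buffer allocation refers to it
constants.push_back(std::move(info));
// ResolveConstantGlobals() above then looks the symbol up with
// GetUntypedSymbol(), copies the literal to the device if it was not emitted
// directly in the IR, and records the allocation_index -> global mapping.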
@@ -58,7 +64,8 @@ class GpuExecutable : public Executable { std::shared_ptr hlo_module, std::shared_ptr assignment, std::unique_ptr hlo_profile_printer_data, - std::unique_ptr hlo_profile_index_map); + std::unique_ptr hlo_profile_index_map, + std::vector constants); ~GpuExecutable() override; int64 SizeOfGeneratedCodeInBytes() const override; @@ -169,6 +176,8 @@ class GpuExecutable : public Executable { std::map module_globals_ TF_GUARDED_BY(module_handle_mutex_); + std::vector constants_; + TF_DISALLOW_COPY_AND_ASSIGN(GpuExecutable); }; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc index bb4184ff76f..b69b32c17c5 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc @@ -143,29 +143,27 @@ bool IsInputFusibleReduction(const HloInstruction& instr) { IsReductionFromOrToContiguousDimensions(instr); } +const HloInstruction* GetRealHeroForMultiOutputFusion( + const HloInstruction& instr) { + if (instr.opcode() != HloOpcode::kFusion) { + return &instr; + } + auto fused_expression_root = instr.fused_expression_root(); + if (!instr.IsMultiOutputFusion()) { + return fused_expression_root; + } + // If possible, we want to pick a reduction-from-or-to-contiguous-dims + // operand of the fusion root, because it has the most constraints. + for (const auto* inst : fused_expression_root->operands()) { + if (IsReductionFromOrToContiguousDimensions(*inst)) { + return inst; + } + } + return fused_expression_root->operands()[0]; +} + bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1, const HloInstruction& instr2) { - // Returns the instructions that determines the emitter used for lowering, - // sometimes referred to as "the real hero". - auto get_real_hero = - [&](const HloInstruction* instr) -> const HloInstruction* { - if (instr->opcode() != HloOpcode::kFusion) { - return instr; - } - auto fused_expression_root = instr->fused_expression_root(); - if (!instr->IsMultiOutputFusion()) { - return fused_expression_root; - } - // If possible, we want to pick a reduction-to-vector operand of the - // fusion root, because it has the most constraints. - for (const auto* inst : fused_expression_root->operands()) { - if (IsReductionFromOrToContiguousDimensions(*inst)) { - return inst; - } - } - return fused_expression_root->operands()[0]; - }; - // Multi-output fusion kernels share a common parallel loop. The loop // dimensions are determined by instruction shapes. auto get_loop_shape = [&](const HloInstruction* element_instr) { @@ -181,8 +179,8 @@ bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1, // root ops should have equal output shapes. An exception are // reduction-to-vector ops. Here the input shapes of the reduction (first // operand shape) and the reduction dimensions need to match. 
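// [Editorial note, not part of the diff] The rule applied below, restated as a
// small helper: the loop shape both fusions must agree on is the reduction
// *input* shape when the real hero is a reduction-from-or-to-contiguous-dims
// op, and the hero's own output shape otherwise.
auto loop_shape_of = [](const HloInstruction* hero) -> const Shape& {
  return IsReductionFromOrToContiguousDimensions(*hero)
             ? hero->operand(0)->shape()
             : hero->shape();
};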
- auto* instr_1 = get_real_hero(&instr1); - auto* instr_2 = get_real_hero(&instr2); + auto* instr_1 = GetRealHeroForMultiOutputFusion(instr1); + auto* instr_2 = GetRealHeroForMultiOutputFusion(instr2); if (IsReductionFromOrToContiguousDimensions(*instr_1) && IsReductionFromOrToContiguousDimensions(*instr_2) && !AreFusedReductionOutputsConsistent({instr_1, instr_2}, instr_1)) { @@ -240,28 +238,42 @@ bool IsLoopFusible(const HloInstruction& instr) { instr.opcode() == HloOpcode::kTranspose); } -bool IsFusible(const HloInstruction& instr) { - return IsInputFusible(instr) || IsLoopFusible(instr); -} - bool IsProducerConsumerFusible(const HloInstruction& producer, const HloInstruction& consumer) { - if (!IsLoopFusible(producer) || !IsFusible(consumer)) { + if (!IsLoopFusible(producer)) { + VLOG(5) << "Producer " << producer.name() << " is not loop-fusible"; return false; } + + if (!IsInputFusible(consumer) && !IsLoopFusible(consumer)) { + VLOG(5) << "Consumer " << consumer.name() + << "is not input-fusible and not loop-fusible"; + return false; + } + // Skip multiple output fusion. It's not yet supported. if (producer.IsMultiOutputFusion()) { + VLOG(5) << "Producer " << producer.name() + << " is not fusible as it is a multi-output fusion"; return false; } + if (CreatesNestedLoop(producer, consumer)) { + VLOG(5) << "Fusing " << producer.name() << " into " << consumer.name() + << " creates nested loop"; return false; } + // Do not fuse into reduce input fusions if the resulting kernel would suffer // from poor data locality (due to unfriendly input layouts). if (IsInputFusibleReduction(consumer) && !LayoutsAreReduceInputFusionFriendly(producer, consumer)) { + VLOG(5) << "Layout of " << producer.name() + << " is not fusion-friendly for consumer reduction " + << consumer.name(); return false; } + // Fuse scalar constants into loop fusion nodes. This reduces the number of // parameters and makes matching scalar broadcasts easier. // @@ -270,10 +282,14 @@ bool IsProducerConsumerFusible(const HloInstruction& producer, // but fused constants are handled by shrared CPU/GPU code and always emitted // in the IR/PTX. The external constant representation makes for faster // compiles and significantly smaller assembly code. - if (producer.opcode() == HloOpcode::kConstant) { - return ShapeUtil::IsEffectiveScalar(producer.shape()) && - consumer.opcode() == HloOpcode::kFusion; + if (producer.opcode() == HloOpcode::kConstant && + (!ShapeUtil::IsEffectiveScalar(producer.shape()) || + consumer.opcode() != HloOpcode::kFusion)) { + VLOG(5) << "Not fusing constant " << producer.name() << " into " + << consumer.name(); + return false; } + return true; } @@ -347,8 +363,13 @@ static int64 SharedMemoryUsage(const HloInstruction& instr) { // This limit is also often good for performance. In a fusion with many // operands, each GPU thread likely has to do a lot of work, and so possibly // uses a lot of registers, thus limiting occupancy. +// +// If the fusion is a producer/consumer fusion and instr1 is the +// consumer and instr2 is the producer, set is_consumer_producer_fusion +// to true to enable more fusion. 
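// [Editorial, self-contained illustration; not part of the diff] The budget the
// function below enforces, reduced to plain integers (the real code counts
// distinct HloInstruction operands and output buffers of the merged fusion):
constexpr int kMaxOperandsAndOutputsExample = 64;

bool WouldBeTooLargeExample(int merged_operand_count, int num_outputs,
                            int consumer_operand_count,
                            bool is_consumer_producer_fusion) {
  // A producer/consumer fusion that does not add operands compared to the
  // consumer alone cannot get bigger, so it is always accepted.
  if (is_consumer_producer_fusion &&
      merged_operand_count <= consumer_operand_count) {
    return false;
  }
  return merged_operand_count + num_outputs > kMaxOperandsAndOutputsExample;
}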
bool FusionWouldBeTooLarge(const HloInstruction& instr1, - const HloInstruction& instr2) { + const HloInstruction& instr2, + bool is_consumer_producer_fusion) { if (SharedMemoryUsage(instr1) + SharedMemoryUsage(instr2) > kSharedMemoryBudgetInBytes) { VLOG(5) << "Shared memory usage of fusion of " << instr1.ToString() @@ -404,6 +425,17 @@ bool FusionWouldBeTooLarge(const HloInstruction& instr1, // producer -> consumer relationship. operands.erase(&instr1); operands.erase(&instr2); + + // If we generate the same numbers of inputs and outputs as + // before, it won't be bigger after fusion. So accept the fusion. + // As this is a consumer_producer fusion, this does not change the + // consumer numbers of output. So no need to check it. + if (is_consumer_producer_fusion && + operands.size() <= instr1.operands().size()) { + return false; + } + + // Does the new fusion have more operands and outputs than the max? return operands.size() + num_output_buffers > kMaxOperandsAndOutputsPerFusion; } @@ -490,5 +522,24 @@ HloInstruction::FusionKind ChooseFusionKind(const HloInstruction& /*producer*/, : HloInstruction::FusionKind::kLoop; } +bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr, + const HloInstruction& consumer) { + return absl::c_all_of(instr.users(), [&](const HloInstruction* user) { + if (user->opcode() == HloOpcode::kGetTupleElement) { + // Skip GTE. + return IsConsumerTheOnlyNonRootUser(*user, consumer); + } + if (user == &consumer) { + // `user` is `consumer`. + return true; + } + if (user == user->parent()->root_instruction()) { + // Consumed by ROOT. + return true; + } + return false; + }); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h index e2a42ecb0a3..9fa098a3394 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h @@ -26,11 +26,6 @@ namespace gpu { constexpr int64 kMaxOperandsAndOutputsPerFusion = 64; -// Whether 'instr' can occur inside fusions, i.e. whether it is a candidate -// for being fused. Note that further restrictions apply, e.g. Scatter must -// be the root of an input fusion. -bool IsFusible(const HloInstruction& instr); - bool IsInputFusible(const HloInstruction& instr); bool IsLoopFusible(const HloInstruction& instr); @@ -64,14 +59,23 @@ bool IsInputFusibleScatter(const HloInstruction& instr); // Determines whether the combination of `instr1` and `instr2` into a (possibly // multi-output) fusion would be "too large" -- i.e., have more operands and // outputs than is allowed or occupy too much shared memory. +// If the fusion is a producer/consumer fusion and instr1 is the +// consumer and instr2 is the producer, set consumer_producer_fusion +// to true to enable more fusion. bool FusionWouldBeTooLarge(const HloInstruction& instr1, - const HloInstruction& instr2); + const HloInstruction& instr2, + bool is_consumer_producer_fusion = false); // Check if fusing producer and consumer will generate a nested loop, e.g. both // producer and consumer are `reduce-window` HLO instructions. bool CreatesNestedLoop(const HloInstruction& producer, const HloInstruction& consumer); +// Returns the instruction that determines the emitter used for lowering, +// sometimes referred to as "the real hero". +const HloInstruction* GetRealHeroForMultiOutputFusion( + const HloInstruction& instr); + // Whether instruction shapes are compatible for multi-output fusion, i.e. 
// whether the emitters support lowering the resulting fusion. // This function works for both, sibling and producer-consumer multi-output @@ -101,6 +105,10 @@ bool IsFusibleAsMultiOutputFusionRoot(const HloInstruction& instr); HloInstruction::FusionKind ChooseFusionKind(const HloInstruction& producer, const HloInstruction& consumer); +// Returns whether `consumer` is the only non-root user of `instr`. +bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr, + const HloInstruction& consumer); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc index 1f83ec71984..e73c4885e9e 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc @@ -90,6 +90,15 @@ void HloExecutionProfiler::FinishHloComputation( } } +void HloExecutionProfiler::FinishHloComputation( + absl::optional profile_index) { + if (do_profile_) { + profile_->SetCyclesTakenBy( + profile_index.value(), + GetCyclesTaken(&timers_, sub_streams_, stream_, clock_rate_ghz_)); + } +} + void HloExecutionProfiler::StartHloInstruction() { if (do_profile_) { InitAndStartTimer(&timers_, stream_); diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h index 1189143e3f9..860fa167790 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h @@ -53,6 +53,11 @@ class HloExecutionProfiler { // the time that the computation took to execute in the profile. void FinishHloComputation(const HloComputation* computation); + // If profiling is enabled stops the timer for a (sub)computation with the + // given profile index and records the time that the computation took to + // execute in the profile. + void FinishHloComputation(absl::optional profile_index); + // If profiling is enabled, starts a per-operation timer. void StartHloInstruction(); diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 332db83b6ad..26a22005dae 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -83,6 +83,8 @@ void HloToIrBindings::EmitBasePointersForHlos( if (non_io_hlo->opcode() == HloOpcode::kConstant) { llvm::Value* global_for_constant = module_->getGlobalVariable( llvm_ir::ConstantHloToGlobalName(*non_io_hlo)); + CHECK(global_for_constant) + << llvm_ir::ConstantHloToGlobalName(*non_io_hlo); BindHloToIrValue(*non_io_hlo, global_for_constant); } else { llvm::Type* pointee_type = diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc new file mode 100644 index 00000000000..9287f9a92b7 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc @@ -0,0 +1,167 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h" + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "tensorflow/core/platform/errors.h" + +namespace xla { +namespace gpu { + +namespace { + +// Gets the representative input shape of the multi-output fusion. +Shape GetInputShapeForMultiOutputFusion(const HloInstruction& instr) { + // Get the HLO that determines the emitter used for lowering. + const HloInstruction* real_hero = GetRealHeroForMultiOutputFusion(instr); + if (real_hero->operands().empty()) { + // Simply return an empty shape if the representative node has no input + // operands. + return Shape(); + } else { + return real_hero->operand(0)->shape(); + } +} + +class HorizontalInputFusionImpl { + public: + explicit HorizontalInputFusionImpl(HloComputation* computation) + : computation_(computation) {} + + ~HorizontalInputFusionImpl() {} + + StatusOr Run(); + + private: + HloComputation* computation_; +}; // HorizontalInputFusionImpl + +// Compares one-by-one the dimensions of `shape_a` and `shape_b` from left to +// right. +bool CompareShapeDimsFromLeftToRight(const Shape& shape_a, + const Shape& shape_b) { + if (shape_a.rank() != shape_b.rank()) { + return shape_a.rank() < shape_b.rank(); + } + auto dims_a = shape_a.dimensions(); + auto dims_b = shape_b.dimensions(); + for (size_t i = 0; i < dims_a.size(); ++i) { + if (dims_a[i] != dims_b[i]) { + return dims_a[i] < dims_b[i]; + } + } + return true; +} + +std::vector FindAndSortFusionCandidates( + HloInstruction* consumer) { + absl::flat_hash_set fusion_instr_set; + for (auto opnd : consumer->operands()) { + HloInstruction* predecessor = opnd->LatestNonGteAncestor(); + // Find out the input fusion instructions whose only consumer is `consumer`. + // This guarantees that fusing these candidates will never create cycles, as + // there is no back edge. + if (IsReduceInputFusion(*predecessor) && + IsConsumerTheOnlyNonRootUser(*predecessor, *consumer)) { + fusion_instr_set.insert(predecessor); + } + } + + std::vector fusion_instrs; + fusion_instrs.insert(fusion_instrs.end(), fusion_instr_set.begin(), + fusion_instr_set.end()); + + std::sort(fusion_instrs.begin(), fusion_instrs.end(), + [&](const HloInstruction* a, const HloInstruction* b) { + Shape shape_a = GetInputShapeForMultiOutputFusion(*a); + Shape shape_b = GetInputShapeForMultiOutputFusion(*b); + if (!ShapeUtil::EqualIgnoringElementType(shape_a, shape_b)) { + // Sort shapes according to dimensions, so that the same input + // shapes will be placed adjacent each other. + return CompareShapeDimsFromLeftToRight(shape_a, shape_b); + } + // Sort `fusion_instrs` according to instruction counts, because + // we'd like to fuse together computations of similar sizes. 
+ return a->fused_instruction_count() < + b->fused_instruction_count(); + }); + + return fusion_instrs; +} + +StatusOr HorizontalInputFusionImpl::Run() { + bool changed = false; + XLA_VLOG_LINES(3, computation_->ToString()); + + // Using def-to-use order is sound since we do not modify users. + std::vector def_to_use_order = + computation_->MakeInstructionPostOrder(); + for (auto consumer : def_to_use_order) { + auto candidates = FindAndSortFusionCandidates(consumer); + if (candidates.empty()) { + continue; + } + + size_t fusion_anchor_id = 0; + for (size_t j = 1; j < candidates.size(); ++j) { + HloInstruction* fusion_anchor = candidates[fusion_anchor_id]; + HloInstruction* fused = candidates[j]; + if (ShapesCompatibleForMultiOutputFusion(*fusion_anchor, *fused) && + !FusionWouldBeTooLarge(*fusion_anchor, *fused)) { + VLOG(3) << "Fuse " << fused->ToString() << " into " + << fusion_anchor->ToString(); + fusion_anchor->MergeFusionInstructionIntoMultiOutput(fused); + changed = true; + } else { + // Update the `fusion_anchor_id` since `fused` is either not + // compatible or not beneficial to be fused with current fusion anchor. + VLOG(3) << j - fusion_anchor_id - 1 << " instructions are fused."; + fusion_anchor_id = j; + } + } + } + + return changed; +} + +} // namespace + +StatusOr GpuHorizontalInputFusion::RunOnComputation( + HloComputation* computation) { + HorizontalInputFusionImpl horizontal_fusion_impl(computation); + return horizontal_fusion_impl.Run(); +} + +StatusOr GpuHorizontalInputFusion::Run(HloModule* module) { + bool changed = false; + VLOG(2) << "Run horizontal input fusion."; + for (auto* comp : module->MakeNonfusionComputations()) { + TF_ASSIGN_OR_RETURN(changed, RunOnComputation(comp)); + } + + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h new file mode 100644 index 00000000000..85313d03412 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h @@ -0,0 +1,57 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_INPUT_FUSION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_INPUT_FUSION_H_ + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/core/platform/macros.h" + +namespace xla { +namespace gpu { + +// This optimization pass horizontally fuses kInput fusions to both reduce the +// kernel launch overhead and increase parallelism degree. See +// GpuHorizontalFusion for general description and motivation about horizontal +// fusion. 
GpuHorizontalFusion deals with kLoop fusions while this pass deals +// with kInput fusions. +// +// Following GpuHorizontalFusion, a simple yet effective heuristic is used +// to search the fusion candidates while avoiding creating cycles. That is, +// we simply search for fusion candidates by looking for instructions whose +// outputs are all consumed by the same instruction. This catches the typical +// target cases; often, the candidate instructions are just consumed by the +// ROOT tuple of the entry computation. +class GpuHorizontalInputFusion : public HloModulePass { + public: + GpuHorizontalInputFusion() {} + + absl::string_view name() const override { + return "gpu_horizontal_input_fusion"; + } + + StatusOr Run(HloModule* module) override; + + private: + StatusOr RunOnComputation(HloComputation*); +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_INPUT_FUSION_H_ diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion_test.cc new file mode 100644 index 00000000000..8ecfbb5a8d2 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion_test.cc @@ -0,0 +1,216 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h" + +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/filecheck.h" + +namespace xla { +namespace gpu { +namespace { + +namespace op = xla::testing::opcode_matchers; + +class HorizontalInputFusionTest : public GpuCodegenTest {}; + +TEST_F(HorizontalInputFusionTest, BasicTest) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule BasicTest + + %add_f16 { + %x = f16[] parameter(0) + %y = f16[] parameter(1) + ROOT %add = f16[] add(%x, %y) + } + + fused_computation.1 { + arg.1 = f16[1024]{0} parameter(0) + constant0 = f16[] constant(0) + ROOT reduce1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16 + } + + fused_computation.2 { + arg.1 = f16[1024]{0} parameter(0) + constant0 = f16[] constant(0) + ROOT reduce1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16 + } + + ENTRY entry_computation { + arg.1 = f16[1024]{0} parameter(0) + arg.2 = f16[1024]{0} parameter(1) + fusion.1 = f16[] fusion(arg.1), kind=kInput, calls=fused_computation.1 + fusion.2 = f16[] fusion(arg.2), kind=kInput, calls=fused_computation.2 + ROOT tuple.1 = (f16[], f16[]) tuple(fusion.1, fusion.2) + } +)") + .ValueOrDie(); + + EXPECT_TRUE(GpuHorizontalInputFusion().Run(module.get()).ValueOrDie()); + + const HloInstruction* entry_root = + module->entry_computation()->root_instruction(); + EXPECT_THAT(entry_root, op::Tuple((op::GetTupleElement(op::Fusion())), + (op::GetTupleElement(op::Fusion())))); + + const HloInstruction* fusion = entry_root->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Reduce())); +} + +TEST_F(HorizontalInputFusionTest, ManyInputFusions) { + auto module = CreateNewVerifiedModule(); + + HloComputation* reduce_computation; + { + auto embedded_builder = HloComputation::Builder("add"); + auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "lhs")); + auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {}), "rhs")); + embedded_builder.AddInstruction( + HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs)); + reduce_computation = + module->AddEmbeddedComputation(embedded_builder.Build()); + } + + HloComputation::Builder builder(TestName()); + std::vector var_outs; + auto input_shape = ShapeUtil::MakeShape(F32, {1024, 1024}); + auto output_shape = ShapeUtil::MakeShape(F32, {1024}); + for (int64 i = 0; i < 130; ++i) { + // %fused_computation.3 (param_0: f32[1024,1024], param_1: f32[]) -> + // f32[1024] { + // %param_0 = f32[1024,1024]{1,0} parameter(0) + // %param_1 = f32[] parameter(1) + // %broadcast = f32[1024,1024]{1,0} broadcast(f32[] %param_1), + // dimensions={} + // %multiply = f32[1024,1024]{1,0} + // multiply(f32[1024,1024]{1,0} %param_0, f32[1024,1024]{1,0} + // %broadcast) + // %constant0 = f32[] constant(0) + // ROOT %reduce = f32[1024]{0} + // reduce(f32[1024,1024]{1,0} %multiply, f32[] %constant0), 
+ // dimensions={1}, to_apply=%add + // } + HloInstruction* param_var_in = builder.AddInstruction( + HloInstruction::CreateParameter(i * 2 + 0, input_shape, "var.in")); + HloInstruction* param_alpha = + builder.AddInstruction(HloInstruction::CreateParameter( + i * 2 + 1, ShapeUtil::MakeShape(F32, {}), "alpha")); + auto alpha_broadcasted = builder.AddInstruction( + HloInstruction::CreateBroadcast(input_shape, param_alpha, {})); + auto mul = builder.AddInstruction(HloInstruction::CreateBinary( + input_shape, HloOpcode::kMultiply, param_var_in, alpha_broadcasted)); + HloInstruction* const0 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + auto reduce = builder.AddInstruction(HloInstruction::CreateReduce( + output_shape, mul, const0, {1}, reduce_computation)); + var_outs.push_back(reduce); + } + builder.AddInstruction(HloInstruction::CreateTuple(var_outs)); + module->AddEntryComputation(builder.Build()); + + // Verify that horizontal fusion is kicked in. Check that there are multiple + // `reduce` instructions fused into the same fusion. 6 is just a randomly + // picked number as we don't exactly know how large the fusion will be + // created due to the `FusionWouldBeTooLarge` constraint. + CompileAndVerifyIr(module->Clone(), R"(CHECK: reduce-group-6)", + /*match_optimized_ir=*/false); + + // Testing with the entire gpu optimization pipeline. + EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(HorizontalInputFusionTest, MultiOutputFusionTest) { + // This tests the below pattern. One known issue is that gtes (to fusions) can + // be removed after their producer fusions are merged. In the below case, gte2 + // and gte6 will be gone if Fusion2 is fused into Fusion1. + // + // Fusion1 Fusion2 + // | | | | + // | gte1 gte2 | + // | | | | + // | Fusion3 | + // | | | | + // gte3 gte4 gte5 gte6 + // \ | | / + // =====ROOT===== + // + auto module = ParseAndReturnVerifiedModule(R"( + HloModule MultiOutputFusionTest + + %add_f16 { + %x = f16[] parameter(0) + %y = f16[] parameter(1) + ROOT %add = f16[] add(%x, %y) + } + + fused_computation.1 { + arg.1 = f16[1024]{0} parameter(0) + constant0 = f16[] constant(0) + reduce.1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16 + add.0 = f16[1024] add(arg.1, arg.1) + ROOT tuple.1 = (f16[], f16[1024]) tuple(reduce.1, add.0) + } + + fused_computation.2 { + arg.1 = f16[1024]{0} parameter(0) + constant0 = f16[] constant(0) + reduce.1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16 + add.0 = f16[1024] add(arg.1, arg.1) + ROOT tuple.1 = (f16[], f16[1024]) tuple(reduce.1, add.0) + } + + fused_computation.3 { + arg.0 = f16[1024]{0} parameter(0) + arg.1 = f16[1024]{0} parameter(1) + add.0 = f16[1024] add(arg.0, arg.1) + mul.0 = f16[1024] multiply(arg.0, arg.1) + ROOT tuple.1 = (f16[1024], f16[1024]) tuple(add.0, mul.0) + } + + ENTRY entry_computation { + arg.1 = f16[1024]{0} parameter(0) + arg.2 = f16[1024]{0} parameter(1) + fusion.1 = (f16[],f16[1024]) fusion(arg.1), kind=kInput, calls=fused_computation.1 + fusion.2 = (f16[],f16[1024]) fusion(arg.2), kind=kInput, calls=fused_computation.2 + gte.3 = f16[] get-tuple-element(fusion.1), index=0 + gte.1 = f16[1024]{0} get-tuple-element(fusion.1), index=1 + gte.2 = f16[1024]{0} get-tuple-element(fusion.2), index=1 + gte.6 = f16[] get-tuple-element(fusion.2), index=0 + fusion.3 = (f16[1024],f16[1024]) fusion(gte.1, gte.2), + kind=kLoop, calls=fused_computation.3 + gte.4 = f16[1024] get-tuple-element(fusion.3), index=0 + 
gte.5 = f16[1024]{0} get-tuple-element(fusion.3), index=1 + ROOT tuple.1 = (f16[], f16[1024]{0}, f16[], f16[1024]{0}) + tuple(gte.3, gte.4, gte.5, gte.6) + } +)") + .ValueOrDie(); + + EXPECT_TRUE(GpuHorizontalInputFusion().Run(module.get()).ValueOrDie()); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_fusion.cc b/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc similarity index 93% rename from tensorflow/compiler/xla/service/gpu/horizontal_fusion.cc rename to tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc index 6d663c66b50..9d1e0533a91 100644 --- a/tensorflow/compiler/xla/service/gpu/horizontal_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc @@ -13,13 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/horizontal_fusion.h" +#include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h" #include #include "absl/container/flat_hash_set.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/util/env_var.h" @@ -66,12 +67,12 @@ PrimitiveType GetUniqueOutputTypeOfFusion(const HloInstruction& instr) { return first_output_type; } -class HorizontalFusionImpl { +class HorizontalLoopFusionImpl { public: - explicit HorizontalFusionImpl(HloComputation* computation) + explicit HorizontalLoopFusionImpl(HloComputation* computation) : computation_(computation) {} - ~HorizontalFusionImpl() {} + ~HorizontalLoopFusionImpl() {} StatusOr Run(); @@ -113,7 +114,7 @@ class HorizontalFusionImpl { }; HloComputation* computation_; -}; // HorizontalFusionImpl +}; // HorizontalLoopFusionImpl bool IsFusionSupported(const HloInstruction& instr) { // Support only kLoop fusion now. @@ -137,25 +138,6 @@ bool IsFusionSupported(const HloInstruction& instr) { return true; } -bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr, - const HloInstruction& consumer) { - return absl::c_all_of(instr.users(), [&](const HloInstruction* user) { - if (user->opcode() == HloOpcode::kGetTupleElement) { - // Skip GTE. - return IsConsumerTheOnlyNonRootUser(*user, consumer); - } else if (user == &consumer) { - // `user` is `consumer`. - return true; - } else if (user == user->parent()->root_instruction()) { - // Consumed by ROOT is always fine, since it is impossible to create - // cycles through ROOT. - return true; - } else { - return false; - } - }); -} - // Returns whether `instr` is a profitable candidate to be horizontally fused. // Since the primary benefit of horizontal fusion comes from reducing the // kernel launch overhead, we want to exclude the instructions with @@ -221,7 +203,7 @@ bool HasOnlyRowMajorLayout(const HloInstruction& fusion_instr) { return true; } -void HorizontalFusionImpl::FusionCandidates::Initialize( +void HorizontalLoopFusionImpl::FusionCandidates::Initialize( HloInstruction* consumer) { // First, find out all fusion instructions. We will filter out // unsupported/non-profitable cases below. @@ -275,7 +257,7 @@ void HorizontalFusionImpl::FusionCandidates::Initialize( // Gets a next span of fusion instructions to be fused. 
absl::Span -HorizontalFusionImpl::FusionCandidates::GetNextSpanOfFusions() { +HorizontalLoopFusionImpl::FusionCandidates::GetNextSpanOfFusions() { if (pos_ >= fusion_instrs_.size()) { return absl::Span(); } @@ -333,7 +315,7 @@ HorizontalFusionImpl::FusionCandidates::GetNextSpanOfFusions() { return absl::MakeSpan(fusion_instrs_).subspan(left, right - left); } -Status HorizontalFusionImpl::CreateFusedComputation( +Status HorizontalLoopFusionImpl::CreateFusedComputation( absl::Span fused_fusion_instrs, std::unique_ptr* uniq_computation, std::vector* bound_operands) { @@ -441,7 +423,7 @@ Status HorizontalFusionImpl::CreateFusedComputation( return Status::OK(); } -Status HorizontalFusionImpl::Fuse( +Status HorizontalLoopFusionImpl::Fuse( absl::Span fused_fusion_instrs) { // Fuse fused_fusion_instrs and replace them with the new fused computation. std::unique_ptr uniq_computation; @@ -483,7 +465,7 @@ Status HorizontalFusionImpl::Fuse( return Status::OK(); } -StatusOr HorizontalFusionImpl::Run() { +StatusOr HorizontalLoopFusionImpl::Run() { bool changed = false; XLA_VLOG_LINES(3, computation_->ToString()); @@ -492,7 +474,7 @@ StatusOr HorizontalFusionImpl::Run() { computation_->MakeInstructionPostOrder(); for (size_t i = 0; i < def_to_use_order.size(); ++i) { auto consumer = def_to_use_order[i]; - HorizontalFusionImpl::FusionCandidates fusion_candidates(consumer); + HorizontalLoopFusionImpl::FusionCandidates fusion_candidates(consumer); while (true) { auto fusions = fusion_candidates.GetNextSpanOfFusions(); if (fusions.empty()) { @@ -512,13 +494,13 @@ StatusOr HorizontalFusionImpl::Run() { } // namespace -StatusOr GpuHorizontalFusion::RunOnComputation( +StatusOr GpuHorizontalLoopFusion::RunOnComputation( HloComputation* computation) { - HorizontalFusionImpl horizontal_fusion_impl(computation); + HorizontalLoopFusionImpl horizontal_fusion_impl(computation); return horizontal_fusion_impl.Run(); } -StatusOr GpuHorizontalFusion::Run(HloModule* module) { +StatusOr GpuHorizontalLoopFusion::Run(HloModule* module) { bool changed = false; VLOG(2) << "Run horizontal fusion."; for (auto* comp : module->MakeNonfusionComputations()) { diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_fusion.h b/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h similarity index 91% rename from tensorflow/compiler/xla/service/gpu/horizontal_fusion.h rename to tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h index 9a804949b1c..3824c5df352 100644 --- a/tensorflow/compiler/xla/service/gpu/horizontal_fusion.h +++ b/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_FUSION_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_FUSION_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_LOOP_FUSION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_LOOP_FUSION_H_ #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -94,11 +94,13 @@ namespace gpu { // output dims of the concatenate will be used as the kernel launch dims. // Instruction bitcasts can be used for Reshape2 and Reshape3 as long as the // outputs of Mul and Add are row-major. 
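// [Editorial sketch, not part of the diff] Assumed pipeline wiring, mirroring
// how the tests in this change exercise the passes: both horizontal fusion
// passes can be appended after the regular producer/consumer fusion passes.
// Includes and the surrounding function returning Status are elided.
HloPassPipeline horizontal("horizontal-fusion");
horizontal.AddPass<GpuHorizontalLoopFusion>();
horizontal.AddPass<GpuHorizontalInputFusion>();
horizontal.AddPass<HloDCE>();  // clean up fusions left dead by the merge
TF_ASSIGN_OR_RETURN(bool changed, horizontal.Run(module));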
-class GpuHorizontalFusion : public HloModulePass { +class GpuHorizontalLoopFusion : public HloModulePass { public: - GpuHorizontalFusion() {} + GpuHorizontalLoopFusion() {} - absl::string_view name() const override { return "gpu_horizontal_fusion"; } + absl::string_view name() const override { + return "gpu_horizontal_loop_fusion"; + } StatusOr Run(HloModule* module) override; @@ -109,4 +111,4 @@ class GpuHorizontalFusion : public HloModulePass { } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_FUSION_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_LOOP_FUSION_H_ diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion_test.cc similarity index 93% rename from tensorflow/compiler/xla/service/gpu/horizontal_fusion_test.cc rename to tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion_test.cc index bad589964ff..8091330cd47 100644 --- a/tensorflow/compiler/xla/service/gpu/horizontal_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/horizontal_fusion.h" +#include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h" @@ -37,9 +37,9 @@ namespace { namespace op = xla::testing::opcode_matchers; -class HorizontalFusionTest : public HloTestBase {}; +class HorizontalLoopFusionTest : public HloTestBase {}; -TEST_F(HorizontalFusionTest, BasicTest) { +TEST_F(HorizontalLoopFusionTest, BasicTest) { auto module = ParseAndReturnVerifiedModule(R"( HloModule BasicTest @@ -70,7 +70,7 @@ TEST_F(HorizontalFusionTest, BasicTest) { )") .ValueOrDie(); - EXPECT_TRUE(GpuHorizontalFusion().Run(module.get()).ValueOrDie()); + EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie()); EXPECT_TRUE(HloDCE().Run(module.get()).ValueOrDie()); const HloInstruction* entry_root = @@ -88,7 +88,7 @@ TEST_F(HorizontalFusionTest, BasicTest) { } // Horizontal fusion should not be triggered as fusion will create cycles. 
-TEST_F(HorizontalFusionTest, NegativeTestForCycle) { +TEST_F(HorizontalLoopFusionTest, NegativeTestForCycle) { auto module = ParseAndReturnVerifiedModule(R"( HloModule NegativeTestForCycle @@ -122,10 +122,10 @@ TEST_F(HorizontalFusionTest, NegativeTestForCycle) { )") .ValueOrDie(); - EXPECT_FALSE(GpuHorizontalFusion().Run(module.get()).ValueOrDie()); + EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie()); } -TEST_F(HorizontalFusionTest, NegativeTestForIncompatibleTypes) { +TEST_F(HorizontalLoopFusionTest, NegativeTestForIncompatibleTypes) { auto module = ParseAndReturnVerifiedModule(R"( HloModule NegativeTestForIncompatibleTypes @@ -158,10 +158,10 @@ TEST_F(HorizontalFusionTest, NegativeTestForIncompatibleTypes) { )") .ValueOrDie(); - EXPECT_FALSE(GpuHorizontalFusion().Run(module.get()).ValueOrDie()); + EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie()); } -TEST_F(HorizontalFusionTest, HorizontalFusionAfterVerticalFusion) { +TEST_F(HorizontalLoopFusionTest, HorizontalLoopFusionAfterVerticalFusion) { auto module = ParseAndReturnVerifiedModule(R"( HloModule MergeSharedFusionInstruction @@ -190,7 +190,7 @@ TEST_F(HorizontalFusionTest, HorizontalFusionAfterVerticalFusion) { fusion.AddPass(/*may_duplicate=*/false); fusion.AddPass(/*may_duplicate=*/true); EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); - EXPECT_TRUE(GpuHorizontalFusion().Run(module.get()).ValueOrDie()); + EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie()); VLOG(2) << "Dump after horizontal fusion:"; VLOG(2) << module->ToString(); @@ -198,7 +198,7 @@ TEST_F(HorizontalFusionTest, HorizontalFusionAfterVerticalFusion) { EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec{0, 0})); } -TEST_F(HorizontalFusionTest, GradientDescentOptimizerLike) { +TEST_F(HorizontalLoopFusionTest, GradientDescentOptimizerLike) { HloComputation::Builder builder(TestName()); std::vector var_outs; @@ -229,7 +229,7 @@ TEST_F(HorizontalFusionTest, GradientDescentOptimizerLike) { EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{0, 0})); } -TEST_F(HorizontalFusionTest, FusingDifferentOutputs) { +TEST_F(HorizontalLoopFusionTest, FusingDifferentOutputs) { auto module = ParseAndReturnVerifiedModule(R"( HloModule HeterogeneousMultiOutputFusions @@ -280,7 +280,7 @@ TEST_F(HorizontalFusionTest, FusingDifferentOutputs) { )") .ValueOrDie(); - EXPECT_TRUE(GpuHorizontalFusion().Run(module.get()).ValueOrDie()); + EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie()); EXPECT_TRUE(HloDCE().Run(module.get()).ValueOrDie()); VLOG(2) << "Dump after horizontal fusion:"; @@ -289,7 +289,7 @@ TEST_F(HorizontalFusionTest, FusingDifferentOutputs) { EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec{0, 0})); } -TEST_F(HorizontalFusionTest, RMSPropLike) { +TEST_F(HorizontalLoopFusionTest, RMSPropLike) { HloComputation::Builder builder(TestName()); std::vector all_outputs; @@ -364,7 +364,7 @@ TEST_F(HorizontalFusionTest, RMSPropLike) { EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1.0e-5, 1.0e-5})); } -TEST_F(HorizontalFusionTest, NegativeTestForDynamicUpdateSlice) { +TEST_F(HorizontalLoopFusionTest, NegativeTestForDynamicUpdateSlice) { auto module = ParseAndReturnVerifiedModule(R"( HloModule NegativeTestForDynamicUpdateSlice @@ -400,7 +400,7 @@ TEST_F(HorizontalFusionTest, NegativeTestForDynamicUpdateSlice) { })") .ValueOrDie(); - EXPECT_FALSE(GpuHorizontalFusion().Run(module.get()).ValueOrDie()); + 
EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie()); } } // namespace diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc index 5fe459a70bc..dc3a0c788ac 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc @@ -26,14 +26,13 @@ InfeedThunk::InfeedThunk( ThunkInfo thunk_info, const ShapeTree& infeed_slices) : Thunk(Kind::kInfeed, thunk_info), - hlo_instruction_(thunk_info.hlo_instruction), infeed_slices_(infeed_slices) {} Status InfeedThunk::ExecuteOnStream(const ExecuteParams& params) { auto& stream = *params.stream; auto& buffer_allocations = *params.buffer_allocations; - VLOG(2) << "Infeeding to GPU: " << hlo_instruction_->ToString(); + VLOG(2) << "Infeeding to GPU"; auto op_profiler = params.profiler->MakeScopedInstructionProfiler(profile_index()); diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h index ab410661ba1..ec33235c466 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h @@ -43,7 +43,6 @@ class InfeedThunk : public Thunk { Status ExecuteOnStream(const ExecuteParams& params) override; private: - const HloInstruction* hlo_instruction_; const ShapeTree infeed_slices_; }; diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index b994ead17ca..b90e4d85f80 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -60,18 +60,22 @@ bool GpuInstructionFusion::ShouldFuseInexpensiveChecks(HloInstruction* consumer, // Output fusions are not currently supported on GPUs. if (producer->opcode() == HloOpcode::kFusion) { + VLOG(4) << "Producer " << producer->name() << " is a fusion op"; return false; } // Cost condition: not fuse (simple, expensive producers) and (consumers who // reuse operand elements). - if (producer->opcode() != HloOpcode::kFusion && - consumer->ReusesOperandElements(operand_index) && - is_expensive(*producer)) { + if (producer->opcode() != HloOpcode::kFusion && is_expensive(*producer) && + ReusesOperandElements(consumer, operand_index)) { + VLOG(4) << "Do not fuse simple, expensive producer " << producer->name() + << " and consumer which reuses operand elements."; return false; } if (!IsProducerConsumerFusible(*producer, *consumer) || !InstructionFusion::ShouldFuse(consumer, operand_index)) { + VLOG(4) << "Producer " << producer->name() + << " is not fusible or should not be fused."; return false; } return true; @@ -87,7 +91,8 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, auto producer = consumer->operand(operand_index); // The following checks are potentially expensive. 
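The cost condition above avoids fusing a simple but expensive producer into a consumer that reuses operand elements, because fusion re-evaluates the producer expression at every consumer read. A standalone back-of-the-envelope sketch (toy element count and reuse factor, not taken from the patch):

#include <cstdio>

int main() {
  const long long n = 1 << 20;   // elements produced by the expensive op
  const int reuse_factor = 8;    // times the consumer reads each element
  // Unfused: the producer result is materialized once, then read 8 times.
  const long long unfused_evals = n;
  // Fused: the producer expression is re-evaluated at every consumer read.
  const long long fused_evals = n * reuse_factor;
  std::printf("producer evaluations: unfused=%lld fused=%lld\n",
              unfused_evals, fused_evals);
  return 0;
}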
- if (FusionWouldBeTooLarge(*consumer, *producer)) { + if (FusionWouldBeTooLarge(*consumer, *producer, + /*is_consumer_producer_fusion=*/true)) { VLOG(5) << "Fusion of (" << producer->ToString() << ") into (" << consumer->ToString() << ") would be too large"; return false; @@ -107,8 +112,12 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, fusion_node_evaluations_.emplace(consumer, FusionNodeIndexingEvaluation(consumer)); } - return !fusion_node_evaluations_.at(consumer).AverageCodeDuplicationTooHigh( - producer); + if (fusion_node_evaluations_.at(consumer).CodeDuplicationTooHigh(producer)) { + VLOG(5) << "Fusion of " << producer->name() << " into " << consumer->name() + << " would result in overly large code duplication."; + return false; + } + return true; } bool GpuInstructionFusion::ShouldFuseIntoMultiOutput(HloInstruction* consumer, diff --git a/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.cc b/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.cc index 154612824ef..4f4409ab896 100644 --- a/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.cc +++ b/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.cc @@ -35,8 +35,9 @@ XLAThunksDialect::XLAThunksDialect(MLIRContext *context) >(); } +} // namespace xla_thunks + #define GET_OP_CLASSES #include "tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.cc.inc" -} // namespace xla_thunks } // namespace mlir diff --git a/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.h b/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.h index ede9adb9ab1..bc0da6a8fc8 100644 --- a/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.h +++ b/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.h @@ -33,10 +33,11 @@ class XLAThunksDialect : public Dialect { static StringRef getDialectNamespace() { return "xla_thunks"; } }; +} // namespace xla_thunks + #define GET_OP_CLASSES #include "tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.h.inc" -} // namespace xla_thunks } // namespace mlir #endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_XLA_THUNKS_OPS_H_ diff --git a/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.td b/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.td index 38602550864..eb203e6917d 100644 --- a/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.td +++ b/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.td @@ -21,12 +21,6 @@ limitations under the License. 
include "mlir/Dialect/LLVMIR/LLVMOpBase.td" include "mlir/IR/OpBase.td" -class LLVMPointerTo - : ContainerType().isPointerTy()">, - "$_self.cast<::mlir::LLVM::LLVMType>().getPointerElementTy()", - "LLVM pointer">; - def XLAThunks_Dialect : Dialect { let name = "xla_thunks"; let cppNamespace = "xla_thunks"; @@ -45,12 +39,12 @@ def AllocationSlice : StructAttr<"AllocationSlice", XLAThunks_Dialect, [ def MemzeroThunkOp : ThunkOp<"execute_memzero_thunk"> { let arguments = (ins - LLVMPointerTo>:$execute_params, + LLVM_PointerTo:$execute_params, AllocationSlice:$allocation_slice ); let results = (outs I<1>:$ok, - LLVMPointerTo>:$error_message + LLVM_PointerTo:$error_message ); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 9d4ec358bd3..7743d19497d 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -472,39 +472,6 @@ StatusOr GetCudnnConvKind( return InternalError("Unexpected call target: %s", target); } -StatusOr GetDnnConvolutionKind( - const HloCustomCallInstruction* instr) { - absl::string_view target = instr->custom_call_target(); - if (target == kCudnnConvForwardCallTarget) { - return se::dnn::ConvolutionKind::FORWARD; - } - if (target == kCudnnConvBackwardInputCallTarget) { - return se::dnn::ConvolutionKind::BACKWARD_DATA; - } - if (target == kCudnnConvBackwardFilterCallTarget) { - return se::dnn::ConvolutionKind::BACKWARD_FILTER; - } - return InternalError("Unexpected call target: %s", target); -} - -StatusOr GetDnnDataType( - const HloCustomCallInstruction* conv) { - PrimitiveType output_primitive_type = - conv->shape().tuple_shapes(0).element_type(); - switch (output_primitive_type) { - case F16: - return se::dnn::ToDataType::value; - case F32: - return se::dnn::ToDataType::value; - case F64: - return se::dnn::ToDataType::value; - default: - break; - } - return InternalError("Unsupported convolution datatype : %s", - conv->ToString()); -} - string CudnnConvKindToString(CudnnConvKind kind) { switch (kind) { case CudnnConvKind::kForward: diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index 6f731b2936f..a782eb3f507 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -55,12 +55,6 @@ enum class CudnnConvKind { StatusOr GetCudnnConvKind(const HloCustomCallInstruction* instr); -StatusOr GetDnnConvolutionKind( - const HloCustomCallInstruction* instr); - -StatusOr GetDnnDataType( - const HloCustomCallInstruction* conv); - // Converts a CudnnConvKind value to a string. string CudnnConvKindToString(CudnnConvKind kind); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 31203b9c5f0..2215881271c 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -30,12 +30,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" @@ -98,6 +100,64 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { .MakeElementGenerator(hlo, operand_to_generator)); } +Status IrEmitter::EmitConstants(const HloComputation& computation, + bool lookup_indices) { + for (HloInstruction* instr : computation.instructions()) { + if (instr->opcode() != HloOpcode::kConstant) { + continue; + } + Literal& literal = *Cast(instr)->mutable_literal(); + const bool should_emit_initializer = ShouldEmitLiteralInLlvmIr(literal); + llvm::ArrayType* global_type = + llvm::ArrayType::get(b_.getInt8Ty(), literal.size_bytes()); + llvm::Constant* initializer = + should_emit_initializer + ? llvm_ir::ConvertLiteralToIrConstant(literal, module_) + : llvm::ConstantAggregateZero::get(global_type); + if (should_emit_initializer) { + VLOG(3) << "Emitted initializer for constant with shape " + << ShapeUtil::HumanString(literal.shape()); + } + + // These globals will be looked up by name by GpuExecutable so we need to + // give them an external linkage. Not all of their uses are visible in + // the LLVM IR (e.g. TupleThunk) so we can't give then a linkage that + // merely preserves their names (like available_externally), we also need + // to ensure that they stick around even if they're "unused". + // + // We may have to be more more clever here in the future if we notice that + // we're keeping around too many globals because of their linkage. 
+ unsigned global_address_space = llvm_ir::GetGlobalMemoryAddressSpace( + *ir_emitter_context_->llvm_module()); + + std::string global_name = llvm_ir::ConstantHloToGlobalName(*instr); + + llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable( + global_type, /*isConstant=*/should_emit_initializer, + llvm::GlobalValue::ExternalLinkage, + /*Initializer=*/initializer, global_name, + /*TLMode=*/llvm::GlobalValue::NotThreadLocal, + /*AddressSpace=*/global_address_space, + /*isExternallyInitialized=*/false); + global_for_const->setAlignment(llvm::Align(kConstantBufferAlignBytes)); + ir_emitter_context_->llvm_module()->getGlobalList().push_back( + global_for_const); + + GpuExecutable::ConstantInfo info; + info.symbol_name = global_name; + info.content = literal.Clone(); + if (lookup_indices) { + auto maybe_slice = + ir_emitter_context_->buffer_assignment().GetUniqueSlice(instr, {}); + if (maybe_slice.ok()) { + info.allocation_index = maybe_slice.ValueOrDie().index(); + } + } + ir_emitter_context_->constants().push_back(std::move(info)); + } + return Status::OK(); +} + Status IrEmitter::HandleConstant(HloInstruction* constant) { return Status::OK(); } @@ -175,10 +235,12 @@ Status IrEmitter::EmitCallToNestedComputation( llvm::Function*& emitted_function = computation_to_ir_function_[&nested_computation]; if (emitted_function == nullptr) { - IrEmitterNested ir_emitter_nested(hlo_module_config_, nested_computation, - ir_emitter_context_); - TF_RETURN_IF_ERROR(ir_emitter_nested.CodegenNestedComputation()); - emitted_function = ir_emitter_nested.GetEmittedFunction(); + TF_ASSIGN_OR_RETURN( + auto ir_emitter_nested, + IrEmitterNested::Create(hlo_module_config_, nested_computation, + ir_emitter_context_)); + TF_RETURN_IF_ERROR(ir_emitter_nested->CodegenNestedComputation()); + emitted_function = ir_emitter_nested->GetEmittedFunction(); } // Operands are in default address space for non-AMDGPU target. diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 50e9f06ef08..1a387528220 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -105,6 +105,12 @@ class IrEmitter : public DfsHloVisitorWithDefault, llvm::IRBuilder<>* builder() { return &b_; } + // Emits constants to generated LLVM IR, and also populate related + // inforamtion to ir_emitter_context for large-constant initializations. If + // `lookup_indices` is true, the allocation index associated with the constant + // is also populated. + Status EmitConstants(const HloComputation& computation, bool lookup_indices); + protected: // Constructs an IrEmitter with the given IrEmitter context. // ir_emitter_context is owned by the caller and should outlive the IrEmitter diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h index 7d5a8d032e6..34b93ca5b3f 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h @@ -17,14 +17,19 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_CONTEXT_H_ #include "llvm/IR/Module.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" namespace xla { namespace gpu { + // IrEmitterContext encapsulates common (mutable and immutable) data structures // used by both IrEmitterNested and IrEmitterUnnested, such as the buffer // assignment and the name uniquer. @@ -44,7 +49,11 @@ class IrEmitterContext { cuda_compute_capability_(cuda_compute_capability), profile_index_map_(profile_index_map), mlir_context_(mlir_context), - llvm_module_(llvm_module) {} + llvm_module_(llvm_module) { + mlir_context_ + ->loadDialect(); + } // Disallow copy and assign. IrEmitterContext(const IrEmitterContext&) = delete; IrEmitterContext& operator=(const IrEmitterContext&) = delete; @@ -64,6 +73,8 @@ class IrEmitterContext { llvm::Module* llvm_module() { return llvm_module_; } NameUniquer* name_uniquer() { return &name_uniquer_; } + std::vector& constants() { return constants_; } + private: const HloModule* hlo_module_; const BufferAssignment* buffer_assignment_; @@ -74,6 +85,7 @@ class IrEmitterContext { mlir::MLIRContext* mlir_context_; llvm::Module* llvm_module_; NameUniquer name_uniquer_; + std::vector constants_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc index e96c5f05e60..5fc091ed8e7 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc @@ -41,6 +41,16 @@ IrEmitterNested::IrEmitterNested(const HloModuleConfig& hlo_module_config, : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/true), nested_computation_(nested_computation) {} +StatusOr> IrEmitterNested::Create( + const HloModuleConfig& hlo_module_config, + const HloComputation& nested_computation, + IrEmitterContext* ir_emitter_context) { + std::unique_ptr emitter(new IrEmitterNested( + hlo_module_config, nested_computation, ir_emitter_context)); + TF_RETURN_IF_ERROR(emitter->EmitConstants(nested_computation, false)); + return emitter; +} + // Nested function serves the same purpose on GPU as a thread-local function on // a CPU. Status IrEmitterNested::CodegenNestedComputation() { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h index ce825851bcc..8ed76cabcda 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h @@ -39,12 +39,11 @@ namespace gpu { // class IrEmitterNested : public IrEmitter { public: - // Constructs an LLVM IR emitter for a nested HLO computation. `function` is - // the containing IR function this emitter produces IR to. See - // IrEmitter::IrEmitter for the meanings of other arguments. 
- IrEmitterNested(const HloModuleConfig& hlo_module_config, - const HloComputation& nested_computation, - IrEmitterContext* ir_emitter_context); + static StatusOr> Create( + const HloModuleConfig& hlo_module_config, + const HloComputation& nested_computation, + IrEmitterContext* ir_emitter_context); + IrEmitterNested(const IrEmitterNested&) = delete; IrEmitterNested& operator=(const IrEmitterNested&) = delete; @@ -62,6 +61,13 @@ class IrEmitterNested : public IrEmitter { Status CodegenNestedComputation(); private: + // Constructs an LLVM IR emitter for a nested HLO computation. `function` is + // the containing IR function this emitter produces IR to. See + // IrEmitter::IrEmitter for the meanings of other arguments. + IrEmitterNested(const HloModuleConfig& hlo_module_config, + const HloComputation& nested_computation, + IrEmitterContext* ir_emitter_context); + const HloComputation& nested_computation_; llvm::Function* emitted_function_; }; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index f88c70b1a33..b94a7458df2 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -41,6 +41,7 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Function.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "tensorflow/compiler/mlir/utils/name_utils.h" #include "tensorflow/compiler/mlir/xla/hlo_utils.h" #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" #include "tensorflow/compiler/mlir/xla/type_to_shape.h" @@ -90,6 +91,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -141,7 +143,7 @@ void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk, llvm::LLVMContext& llvm_context = llvm_module->getContext(); llvm::ConstantInt* threads_per_block_ir_value = llvm::ConstantInt::get( llvm::IntegerType::get(llvm_context, /*NumBits=*/32), - launch_dims.threads_per_block()); + launch_dims.thread_counts_per_block().x); // Our launch bounds are exact, so we can specify them as reqntidx rather than // maxntidx. 
nvvm_annotations_node->addOperand(llvm::MDNode::get( @@ -151,24 +153,22 @@ void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk, llvm::ConstantAsMetadata::get(threads_per_block_ir_value)})); } -const BufferAllocation* GetAllocation( - mlir::BlockArgument func_arg, const BufferAssignment& buffer_assignment) { +int64_t GetAllocationIndex(mlir::BlockArgument func_arg) { auto func_op = mlir::cast(func_arg.getParentRegion()->getParentOp()); - int64 allocation_index = func_op - .getArgAttrOfType( - func_arg.getArgNumber(), "lmhlo.alloc") - .getValue() - .getSExtValue(); - return &buffer_assignment.GetAllocation(allocation_index); + return func_op + .getArgAttrOfType(func_arg.getArgNumber(), + "lmhlo.alloc") + .getValue() + .getSExtValue(); } StatusOr GetAllocationSliceForMlir( - mlir::Value v, const BufferAssignment& buffer_assignment) { + mlir::Value v, absl::Span allocations) { int64 size = v.getType().cast().getSizeInBits() / 8; if (auto arg = v.dyn_cast()) { - return BufferAllocation::Slice(GetAllocation(arg, buffer_assignment), 0, + return BufferAllocation::Slice(&allocations[GetAllocationIndex(arg)], 0, size); } @@ -185,8 +185,8 @@ StatusOr GetAllocationSliceForMlir( } if (auto view = mlir::dyn_cast(op)) { return BufferAllocation::Slice( - GetAllocation(view.source().cast(), - buffer_assignment), + &allocations[GetAllocationIndex( + view.source().cast())], mlir::cast(view.byte_shift().getDefiningOp()) .value() .cast() @@ -202,12 +202,29 @@ StatusOr GetAllocationSliceForMlir( "StaticMemRefCastOp(ViewOp(arg))"); } -absl::string_view GetHloName(mlir::Operation* op) { - if (auto attr = op->getAttrOfType("name")) { - auto ref = attr.getValue(); - return absl::string_view(ref.data(), ref.size()); +StatusOr> GetMlirBufferSlices( + mlir::Operation* op, mlir::OperandRange operands, + absl::Span allocations) { + const auto buffer_is_written = [op](mlir::Value operand) { + llvm::SmallVector effects; + mlir::cast(op).getEffectsOnValue(operand, + effects); + return absl::c_any_of( + effects, [](const mlir::MemoryEffects::EffectInstance& instance) { + return mlir::isa(instance.getEffect()); + }); + }; + + std::vector slices; + for (mlir::Value operand : operands) { + slices.emplace_back(); + auto& slice = slices.back(); + TF_ASSIGN_OR_RETURN(slice.buffer_slice, + GetAllocationSliceForMlir(operand, allocations)); + slice.written = buffer_is_written(operand); + slice.shape = TypeToShape(operand.getType()); } - return ""; + return slices; } } // namespace @@ -229,6 +246,7 @@ StatusOr> IrEmitterUnnested::Create( auto emitter = std::unique_ptr(new IrEmitterUnnested( hlo_module_config, hlo_computation, ir_emitter_context)); TF_RETURN_IF_ERROR(emitter->lhlo_scratch_emitter_.Initialize()); + TF_RETURN_IF_ERROR(emitter->EmitConstants(*hlo_computation, true)); return std::move(emitter); } @@ -387,6 +405,62 @@ llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size, return b->getInt32Ty(); } +// The same as GetIndexTypeForKernel, but works with MLIR ops. 
+llvm::Type* GetIndexTypeForKernelFromMlir(mlir::Operation* op, + int64 launch_size, + llvm::IRBuilder<>* b) { + auto shape_in_range = [&](const Shape& s) { + bool in_range = true; + ShapeUtil::ForEachSubshape(s, [&](const Shape& sub_shape, + const ShapeIndex& /*index*/) { + if (sub_shape.IsArray() && !IsInt32(ShapeUtil::ElementsIn(sub_shape))) { + in_range = false; + } + }); + + return in_range; + }; + + llvm::Type* i64_ty = b->getInt64Ty(); + // Check launch dimension + if (!IsInt32(launch_size)) { + return i64_ty; + } + + // Check the size of result tensors + for (auto result : op->getResults()) { + if (!shape_in_range(TypeToShape(result.getType()))) { + return i64_ty; + } + } + + auto hlo_shape_in_range = [&](mlir::Value operand) -> bool { + return shape_in_range(TypeToShape(operand.getType())); + }; + + // Check the size of input tensors + if (!absl::c_all_of(op->getOperands(), hlo_shape_in_range)) { + return i64_ty; + } + + // Check the size of the internal result tensors + if (auto fusion = mlir::cast(op)) { + auto result = fusion.region().walk([&](mlir::Operation* op) { + for (mlir::Value result : op->getResults()) { + if (!hlo_shape_in_range(result)) { + return mlir::WalkResult::interrupt(); + } + } + return mlir::WalkResult::advance(); + }); + if (result.wasInterrupted()) { + return i64_ty; + } + } + + return b->getInt32Ty(); +} + // Gets the input shape of the ROOT slices, which will be used as the kernel // launch dims. The slice input fusion requires the input shapes of the ROOT // slices to be the same although the (slice) output shapes can be different. @@ -1366,13 +1440,6 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { TF_ASSIGN_OR_RETURN(auto sort_op, lhlo_scratch_emitter_.EmitSortOp(sort)); result.op = sort_op; - result.name = GetHloName(sort_op); - // The name in sort op has no semantics, and it's for debug only. If the name - // doesn't exist, we should use a namer (e.g. count-based). - // TODO(timshen): use a namer instead of relying on the HloInstruction names. - if (result.name.empty()) { - result.name = sort->name(); - } const auto& buffer_assignment = ir_emitter_context_->buffer_assignment(); auto& slice = result.extra_slice; TF_ASSIGN_OR_RETURN(slice.buffer_slice, @@ -1382,74 +1449,57 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { result.thunk_info = GetThunkInfo(sort); - return EmitMlirSort(result); + return EmitSortFromMlir(result); } -Status IrEmitterUnnested::EmitMlirSort(MlirEmitterInput input) { - const auto& buffer_assignment = ir_emitter_context_->buffer_assignment(); - auto sort_op = mlir::cast(input.op); - - int operand_count = sort_op.operands().size(); - std::vector operand_shapes(operand_count); - std::vector slices; - std::vector output_shapes(sort_op.output().size()); - - for (int i = 0; i < operand_count; i++) { - operand_shapes[i] = - TypeToShape(sort_op.operands()[i].getType().cast()); - } - - // Craft n + 1 slices, where the first n are output parameters, and the last - // is the on-device tuple storage. We don't need n operands because sorting - // kernels are always in-place. 
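As a standalone illustration of the index-type selection implemented in GetIndexTypeForKernelFromMlir above (simplified, with made-up sizes): 64-bit indices are chosen as soon as the launch size or any array's element count no longer fits in int32.

#include <cstdint>
#include <cstdio>
#include <limits>
#include <vector>

static bool IsInt32(int64_t x) {
  return x >= std::numeric_limits<int32_t>::min() &&
         x <= std::numeric_limits<int32_t>::max();
}

int main() {
  const int64_t launch_size = int64_t{1} << 20;
  // Element counts of the kernel's inputs/outputs; one exceeds 2^31 - 1.
  const std::vector<int64_t> element_counts = {int64_t{3} << 30, 1024};
  bool use_i64 = !IsInt32(launch_size);
  for (int64_t count : element_counts) {
    use_i64 = use_i64 || !IsInt32(count);
  }
  std::printf("kernel index type: %s\n", use_i64 ? "i64" : "i32");  // i64
  return 0;
}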
- for (int i = 0; i < operand_count; i++) { - output_shapes[i] = - TypeToShape(sort_op.output()[i].getType().cast()); - MlirBufferSlice slice; - TF_ASSIGN_OR_RETURN( - slice.buffer_slice, - GetAllocationSliceForMlir(sort_op.output()[i], buffer_assignment)); - slice.written = true; - slice.shape = operand_shapes[i]; - slices.push_back(slice); - } - slices.push_back(input.extra_slice); +Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput mlir_input) { + absl::Span allocations( + ir_emitter_context_->buffer_assignment().Allocations()); + auto sort_op = mlir::cast(mlir_input.op); + std::string name = mlir::GetNameFromLoc(sort_op.getLoc()); + TF_ASSIGN_OR_RETURN( + std::vector operands, + GetMlirBufferSlices(sort_op, sort_op.operands(), allocations)); + TF_ASSIGN_OR_RETURN( + std::vector outputs, + GetMlirBufferSlices(sort_op, sort_op.output(), allocations)); + outputs.push_back(mlir_input.extra_slice); std::vector> thunks; - Shape keys_shape = operand_shapes[0]; - int64 dimension_to_sort = sort_op.dimension().getSExtValue(); - for (int64 i = 0; i < operand_count; ++i) { + Shape keys_shape = operands[0].shape; + int64 dimension_to_sort = sort_op.dimension(); + for (int64 i = 0; i < operands.size(); ++i) { // We assume that the layout of all involved operands and outputs is the // same. TF_RET_CHECK( - LayoutUtil::LayoutsInShapesEqual(keys_shape, operand_shapes[i])); + LayoutUtil::LayoutsInShapesEqual(keys_shape, operands[i].shape)); TF_RET_CHECK( - LayoutUtil::LayoutsInShapesEqual(keys_shape, output_shapes[i])); + LayoutUtil::LayoutsInShapesEqual(keys_shape, outputs[i].shape)); // If possible, we share buffers. If that is not possible, we need to copy // the values, because the emitter does the sorting in-place. TF_ASSIGN_OR_RETURN( auto destination_buffer, - GetAllocationSliceForMlir(sort_op.output()[i], buffer_assignment)); + GetAllocationSliceForMlir(sort_op.output()[i], allocations)); TF_ASSIGN_OR_RETURN( auto source_address, - GetAllocationSliceForMlir(sort_op.operands()[i], buffer_assignment)); + GetAllocationSliceForMlir(sort_op.operands()[i], allocations)); if (destination_buffer != source_address) { // TODO(b/26783907): Figure out why we never seem to share buffers for // key/value sort. - VLOG(2) << input.name << " requires initial D2D copy for operand " << i; + VLOG(2) << name << " requires initial D2D copy for operand " << i; thunks.push_back(absl::make_unique( Thunk::ThunkInfo(), /*source_address=*/source_address, /*destination_buffer=*/destination_buffer, - /*mem_size=*/ShapeUtil::ByteSizeOf(operand_shapes[i]))); + /*mem_size=*/ShapeUtil::ByteSizeOf(operands[i].shape))); } } uint64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort); int64 num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound); - VLOG(2) << input.name << " requires " << num_stages << " stages."; + VLOG(2) << name << " requires " << num_stages << " stages."; CHECK_GE(1ULL << num_stages, dimension_to_sort_bound); CHECK_LT(1ULL << (num_stages - 1), dimension_to_sort_bound); @@ -1513,10 +1563,10 @@ Status IrEmitterUnnested::EmitMlirSort(MlirEmitterInput input) { // we have not enough threads, or not enough shared memory. Also it does not // give a speedup if the tile size is < 128. 
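A minimal standalone sketch of the stage-count arithmetic checked above: num_stages is the ceiling of log2 of the sort-dimension bound, so 2^num_stages is the first power of two that covers it (the example bound of 300 is arbitrary).

#include <cassert>
#include <cstdint>
#include <cstdio>

static int Log2Ceiling(uint64_t n) {
  int stages = 0;
  while ((uint64_t{1} << stages) < n) ++stages;
  return stages;
}

int main() {
  const uint64_t dimension_to_sort_bound = 300;
  const int num_stages = Log2Ceiling(dimension_to_sort_bound);            // 9
  assert((uint64_t{1} << num_stages) >= dimension_to_sort_bound);         // 512 >= 300
  assert((uint64_t{1} << (num_stages - 1)) < dimension_to_sort_bound);    // 256 < 300
  std::printf("bound=%llu needs %d bitonic sort stages\n",
              static_cast<unsigned long long>(dimension_to_sort_bound),
              num_stages);
  return 0;
}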
int64 total_shared_memory_needed = 0; - for (int64 i = 0; i < operand_count; ++i) { + for (int64 i = 0; i < operands.size(); ++i) { total_shared_memory_needed += kTileSize * - ShapeUtil::ByteSizeOfPrimitiveType(operand_shapes[i].element_type()); + ShapeUtil::ByteSizeOfPrimitiveType(operands[i].shape.element_type()); } bool no_tiling = kTileSize < 128 || @@ -1529,7 +1579,7 @@ Status IrEmitterUnnested::EmitMlirSort(MlirEmitterInput input) { "kTileSize=%d < 128, " "kThreadsPerBlock=%d > threads_per_block_limit=%d, " "total_shared_memory_needed=%d > shared_memory_per_block=%d", - input.name, (no_tiling ? "won't" : "will"), kTileSize, kThreadsPerBlock, + name, (no_tiling ? "won't" : "will"), kTileSize, kThreadsPerBlock, ir_emitter_context_->gpu_device_info().threads_per_block_limit, total_shared_memory_needed, ir_emitter_context_->gpu_device_info().shared_memory_per_block); @@ -1537,32 +1587,32 @@ Status IrEmitterUnnested::EmitMlirSort(MlirEmitterInput input) { uint64 num_blocks = CeilOfRatio(num_iterations, kThreadsPerBlock); LaunchDimensions tiled_launch_dimensions(num_blocks, kThreadsPerBlock); VLOG(2) << absl::StreamFormat("%s launch dims: %d blocks, %d threads/block", - input.name, num_blocks, kThreadsPerBlock); + name, num_blocks, kThreadsPerBlock); std::vector ir_arrays; auto emit_kernel = [&](absl::Span xor_masks) { VLOG(2) << absl::StreamFormat( - "%s uses kernel for xor masks [%s]", input.name, + "%s uses kernel for xor masks [%s]", name, absl::StrJoin(xor_masks, ", ", [](std::string* out, int64 xor_mask) { absl::StrAppendFormat(out, "0x%x", xor_mask); })); - thunks.push_back(BuildKernelThunkForMlir(input.name, Thunk::ThunkInfo(), - slices, &ir_arrays)); + thunks.push_back( + BuildKernelThunkForMlir(name, Thunk::ThunkInfo(), outputs, &ir_arrays)); LaunchDimensions launch_dimensions = xor_masks.size() > 1 ? tiled_launch_dimensions : standard_launch_dimensions; UpdateLaunchDimensions(launch_dimensions, thunks.back().get(), ir_emitter_context_->llvm_module()); std::vector values_arrays; - values_arrays.reserve(operand_count); - for (int64 i = 0; i < operand_count; ++i) { + values_arrays.reserve(operands.size()); + for (int64 i = 0; i < operands.size(); ++i) { values_arrays.push_back(ir_arrays[i]); } TF_ASSIGN_OR_RETURN( const HloComputation* comparator, GetOrCreateSubComputationFromRegion(&sort_op.comparator())); return llvm_ir::EmitSortInPlace( - dimension_to_sort, values_arrays, IrName(input.name), xor_masks, &b_, + dimension_to_sort, values_arrays, IrName(name), xor_masks, &b_, launch_dimensions, xor_masks.size() > 1 ? num_iterations_in_sort_dim : standard_num_iterations_in_sort_dim, @@ -1595,17 +1645,16 @@ Status IrEmitterUnnested::EmitMlirSort(MlirEmitterInput input) { TF_RETURN_IF_ERROR(emit_kernel(xor_masks)); } VLOG(2) << absl::StreamFormat( - "%s requires %d thunks (including any D2D copies)", input.name, - thunks.size()); + "%s requires %d thunks (including any D2D copies)", name, thunks.size()); - AddThunkToThunkSequence( - absl::make_unique(input.thunk_info, std::move(thunks))); - if (operand_count > 1) { + AddThunkToThunkSequence(absl::make_unique( + mlir_input.thunk_info, std::move(thunks))); + if (operands.size() > 1) { // Emit the tuple as part of the last stage of sorting. // We are currently in the block sorted.in_bounds.after. 
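The tiling decision above can be pictured with a standalone calculation; the tile size, thread count, and device limits below are placeholders rather than XLA's actual constants.

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t kTileSize = 2048;                     // assumed tile size
  const int64_t kThreadsPerBlock = 1024;              // assumed thread count
  const int64_t threads_per_block_limit = 1024;       // assumed device limit
  const int64_t shared_memory_per_block = 48 * 1024;  // assumed 48 KiB
  // Key/value sort with two f32 operands: 4 bytes per element, per operand.
  const int64_t total_shared_memory_needed = 2 * kTileSize * 4;  // 16 KiB
  const bool no_tiling = kTileSize < 128 ||
                         kThreadsPerBlock > threads_per_block_limit ||
                         total_shared_memory_needed > shared_memory_per_block;
  std::printf("tiled comparison loops used: %s\n", no_tiling ? "no" : "yes");
  return 0;
}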
b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator()); llvm_ir::EmitTuple( - ir_arrays[operand_count], + ir_arrays.back(), absl::MakeSpan(ir_arrays).subspan(0, ir_arrays.size() - 1), &b_); } return Status::OK(); @@ -1624,9 +1673,10 @@ Status IrEmitterUnnested::HandleReplicaId(HloInstruction* hlo) { } Status IrEmitterUnnested::HandleCollectivePermute(HloInstruction* hlo) { + CollectivePermuteConfig config = GetCollectivePermuteConfig(hlo); AddThunkToThunkSequence(absl::make_unique( - GetThunkInfo(hlo), GetAllocationSlice(*hlo->operand(0)), - GetAllocationSlice(*hlo))); + GetThunkInfo(hlo), std::move(config), + GetAllocationSlice(*hlo->operand(0)), GetAllocationSlice(*hlo))); return Status::OK(); } @@ -1658,9 +1708,10 @@ Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) { *crs, crs->shape().IsTuple() ? ShapeIndex({i}) : ShapeIndex({})); tuple_element_buffers.push_back(buffers[i].destination_buffer); } + NcclAllReduceConfig config = + GetNcclAllReduceConfig(crs, hlo_module_config_.replica_count()); auto all_reduce_thunk = absl::make_unique( - GetThunkInfo(crs), - /*replica_count=*/hlo_module_config_.replica_count(), + GetThunkInfo(crs), std::move(config), /*buffers=*/std::move(buffers)); if (crs->shape().IsTuple()) { std::vector> thunks; @@ -2252,11 +2303,19 @@ StatusOr> IrEmitterUnnested::BuildWhileThunk( IrEmitterUnnested::Create(hlo_module_config_, body, ir_emitter_context_)); TF_RETURN_IF_ERROR(body->Accept(ir_emitter_body.get())); + const auto* index_map = ir_emitter_context_->profile_index_map(); + absl::optional condition_profile_index, body_profile_index; + if (index_map) { + condition_profile_index = index_map->GetProfileIndexFor(*condition); + body_profile_index = index_map->GetProfileIndexFor(*body); + } + return std::unique_ptr(new WhileThunk( GetThunkInfo(hlo), GetAllocationSlice(*condition->root_instruction()), // cond result ir_emitter_condition->ConsumeThunkSequence(), - ir_emitter_body->ConsumeThunkSequence())); + ir_emitter_body->ConsumeThunkSequence(), condition_profile_index, + body_profile_index)); } StatusOr> IrEmitterUnnested::BuildForThunk( @@ -2272,8 +2331,15 @@ StatusOr> IrEmitterUnnested::BuildForThunk( IrEmitterUnnested::Create(hlo_module_config_, body, ir_emitter_context_)); TF_RETURN_IF_ERROR(body->Accept(ir_emitter_body.get())); + const auto* index_map = ir_emitter_context_->profile_index_map(); + absl::optional body_profile_index; + if (index_map) { + body_profile_index = index_map->GetProfileIndexFor(*body); + } + return std::unique_ptr(new ForThunk( - GetThunkInfo(hlo), loop_limit, ir_emitter_body->ConsumeThunkSequence())); + GetThunkInfo(hlo), loop_limit, ir_emitter_body->ConsumeThunkSequence(), + body_profile_index)); } StatusOr> IrEmitterUnnested::BuildConditionalThunk( @@ -2285,7 +2351,15 @@ StatusOr> IrEmitterUnnested::BuildConditionalThunk( std::vector branch_operands; std::vector branch_thunks; - for (int j = 0; j < hlo->branch_count(); ++j) { + std::vector> branch_profile_indices; + + int branch_count = hlo->branch_count(); + branch_thunks.reserve(branch_count); + branch_profile_indices.reserve(branch_count); + + const auto* index_map = ir_emitter_context_->profile_index_map(); + + for (int j = 0; j < branch_count; ++j) { branch_operands.emplace_back(GetAllocationSlice(*hlo->operand(j + 1))); HloComputation* branch_computation = hlo->branch_computation(j); TF_ASSIGN_OR_RETURN( @@ -2294,17 +2368,25 @@ StatusOr> IrEmitterUnnested::BuildConditionalThunk( ir_emitter_context_)); TF_CHECK_OK(branch_computation->Accept(ir_emitter.get())); 
branch_thunks.push_back(std::move(*ir_emitter->ConsumeThunkSequence())); + + absl::optional profile_index; + if (index_map) { + profile_index = index_map->GetProfileIndexFor(*branch_computation); + } + branch_profile_indices.push_back(profile_index); } + ConditionalThunkConfig config = GetConditionalThunkConfig( + hlo, std::move(branch_thunks), std::move(branch_profile_indices)); return std::unique_ptr(new ConditionalThunk( - GetThunkInfo(hlo), GetAllocationSlice(*hlo->operand(0)), branch_operands, - std::move(branch_thunks))); + GetThunkInfo(hlo), std::move(config), + GetAllocationSlice(*hlo->operand(0)), branch_operands)); } Status IrEmitterUnnested::EmitTargetElementLoopInThunk( const HloInstruction& hlo, const llvm_ir::ElementGenerator& element_generator, KernelThunk* thunk, - int unroll_factor) { + int unroll_factor, bool few_waves) { VLOG(3) << bindings_.ToString(); bool multi_output = hlo.shape().IsTuple(); @@ -2315,7 +2397,8 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( << ShapeUtil::HumanStringWithLayout(hlo.shape()) << " for unroll_factor " << unroll_factor; LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - element_shape, ir_emitter_context_->gpu_device_info(), unroll_factor); + element_shape, ir_emitter_context_->gpu_device_info(), unroll_factor, + few_waves); UpdateLaunchDimensions(launch_dimensions, thunk, ir_emitter_context_->llvm_module()); if (!multi_output) { @@ -2401,8 +2484,27 @@ Status IrEmitterUnnested::EmitTargetElementLoop( std::unique_ptr kernel_thunk = BuildKernelThunk(&hlo, /*implements_whole_instruction=*/true); + + // Check if we want to schedule grid size that has fewer SM waves. + // This speed up computations in some cases. + bool few_waves = false; + auto few_waves_allow_instr = [](const HloInstruction* instr) { + return instr->IsElementwise() || instr->opcode() == HloOpcode::kParameter || + // We need to make the codegen broadcast aware before enabling + // more broadcast pattern. 
+ (instr->opcode() == HloOpcode::kBroadcast && + instr->dimensions().empty()); + }; + if (hlo.opcode() == HloOpcode::kFusion) { + few_waves = + absl::c_all_of(hlo.fused_instructions_computation()->instructions(), + few_waves_allow_instr); + } else { + few_waves = few_waves_allow_instr(&hlo); + } + Status emit_status = EmitTargetElementLoopInThunk( - hlo, body_emitter, kernel_thunk.get(), unroll_factor); + hlo, body_emitter, kernel_thunk.get(), unroll_factor, few_waves); thunk_sequence_.emplace_back(std::move(kernel_thunk)); return emit_status; @@ -2886,7 +2988,7 @@ void IrEmitterUnnested::EmitEpilogueForReduction( current_output); llvm::Value* warp_id = b_.CreateUDiv(thread_id_info.thread_id_x, constant(kWarpSize)); - ksl.If(is_zero(thread_id_info.lane_id), [&] { + ksl.If("intra_warp_reduce_write", is_zero(thread_id_info.lane_id), [&] { llvm::Value* shmem_output_addr = shared_to_global(b_.CreateInBoundsGEP( shared_cache, {b_.getInt32(0), constant(j), warp_id})); @@ -2894,7 +2996,7 @@ void IrEmitterUnnested::EmitEpilogueForReduction( }); EmitSyncThreads(); - ksl.If(is_zero(warp_id), [&] { + ksl.If("inter_warp_reduce", is_zero(warp_id), [&] { llvm::Value* block_accum_addr = shared_to_global(b_.CreateInBoundsGEP( shared_cache, {b_.getInt32(0), constant(j), thread_id_info.lane_id})); @@ -2914,10 +3016,11 @@ void IrEmitterUnnested::EmitEpilogueForReduction( EmitFullWarpShuffleDownLoopForReduce( reducers[i], element_type, /*block_accum_addr*/ selected_value); - ksl.If(is_zero(thread_id_info.thread_id_x), [&] { - TF_CHECK_OK(EmitAtomicOperationForNestedComputation( - *reducers[i], output_address, block_accum_addr)); - }); + ksl.If("reduction_atomic_update", is_zero(thread_id_info.thread_id_x), + [&] { + TF_CHECK_OK(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, block_accum_addr)); + }); }); } else { @@ -2952,10 +3055,11 @@ void IrEmitterUnnested::EmitEpilogueForReduction( b_.CreateICmpULT(thread_id_info.thread_id_x, tiling_kernel_info.output_tile_bounds[kDimY])); - ksl.If(b_.CreateAnd(has_output, is_zero(thread_id_info.lane_id)), [&] { - TF_CHECK_OK(EmitAtomicOperationForNestedComputation( - *reducers[i], output_address, shmem_transposed_addr)); - }); + ksl.If("reduction_atomic_update", + b_.CreateAnd(has_output, is_zero(thread_id_info.lane_id)), [&] { + TF_CHECK_OK(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, shmem_transposed_addr)); + }); } } } @@ -2991,6 +3095,28 @@ void IrEmitterUnnested::EmitPrintfWithThreadId( }); } +namespace { + +// Obtains the corresponding index of the out_instr in the outputs of the +// `unnested_hlo`. 
+ShapeIndex CreateShapeIndexForOutputInstruction( + const HloInstruction& unnested_hlo, const HloInstruction& out_instr) { + if (!unnested_hlo.IsMultiOutputFusion()) { + return ShapeIndex({}); + } + const auto& all_outputs = unnested_hlo.fused_expression_root()->operands(); + for (size_t i = 0; i < all_outputs.size(); ++i) { + if (all_outputs[i] == &out_instr) { + return ShapeIndex({static_cast(i)}); + } + } + LOG(FATAL) << " Fusion root does not contain output instruction; " + << " fusion: " << unnested_hlo.ToString() + << ", output instruction: " << out_instr.ToString(); +} + +} // namespace + void IrEmitterUnnested::EmitTileElementForReduction( HloInstruction* unnested_hlo, const Shape& reduction_operand_shape, absl::Span output_instructions, @@ -2998,7 +3124,6 @@ void IrEmitterUnnested::EmitTileElementForReduction( const ReductionCodegenInfo& reduction_info, absl::Span reducers, int64 x_iter_num) { VLOG(10) << "Emit tile element for reduce " << unnested_hlo->ToString(); - bool returns_tuple = output_instructions.size() > 1; int partial_result_index = reduction_info.IsRowReduction() ? 0 : x_iter_num; InlinedVector input_gens; @@ -3015,7 +3140,8 @@ void IrEmitterUnnested::EmitTileElementForReduction( for (int i = 0, e = output_instructions.size(); i != e; ++i) { const HloInstruction* inst = output_instructions[i]; - ShapeIndex idx = returns_tuple ? ShapeIndex({i}) : ShapeIndex({}); + ShapeIndex idx = + CreateShapeIndexForOutputInstruction(*unnested_hlo, *inst); if (IsReductionFromOrToContiguousDimensions(*inst)) { input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0))); } else { @@ -3748,71 +3874,41 @@ ReductionCodegenInfo IrEmitterUnnested::ComputeReductionCodegenInfo( reduction_dimensions.is_row_reduction); } -Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions( +void IrEmitterUnnested::EmitIRForReduction( HloInstruction* unnested_hlo, - absl::Span output_instructions) { - bool returns_tuple = output_instructions.size() > 1; - VLOG(10) << "Emitting reduction to vector " << unnested_hlo->ToString(); - + absl::Span output_instructions, + ReductionCodegenInfo* reduction_info, const Shape& input_shape) { std::vector reduce_instructions; InlinedVector reduction_output_shape_indices; InlinedVector reducers; - - // Build an initializer thunk to initialize each reduction output. - std::vector> thunks; - for (int i = 0; i < output_instructions.size(); ++i) { + for (size_t i = 0; i < output_instructions.size(); ++i) { if (!IsReductionFromOrToContiguousDimensions(*output_instructions[i])) { continue; } HloInstruction* output_instruction = output_instructions[i]; reduce_instructions.push_back(output_instruction); - ShapeIndex idx = returns_tuple ? ShapeIndex({i}) : ShapeIndex({}); - reduction_output_shape_indices.push_back(idx); + reduction_output_shape_indices.push_back( + CreateShapeIndexForOutputInstruction(*unnested_hlo, + *output_instruction)); reducers.push_back(output_instruction->to_apply()); - - TF_ASSIGN_OR_RETURN(std::unique_ptr initializer_thunk, - BuildInitializerThunk(unnested_hlo, idx)); - thunks.push_back(std::move(initializer_thunk)); } + CHECK(reduce_instructions.size() != 0) + << " expect at least one reduce instructions."; - const HloInstruction* first_reduce = reduce_instructions.at(0); - if (output_instructions.size() > 1) { - if (!AreFusedReductionOutputsConsistent(output_instructions, - first_reduce)) { - return InternalError("Inconsistent reduction fusion outputs"); - } - } - - // Build a kernel thunk to compute all the outputs. 
- std::unique_ptr kernel_thunk = - BuildKernelThunk(unnested_hlo, /*implements_whole_instruction=*/false); - - const Shape& input_shape = first_reduce->operand(0)->shape(); - // The layout of a reduction input is either set by LayoutAssignment for - // unnested kReduce or by InstructionFusion for fused kReduce. - CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion " - "doesn't set the input layout of " - << first_reduce->ToString(); - - ReductionCodegenInfo reduction_info = - ComputeReductionCodegenInfo(unnested_hlo, first_reduce); const KernelMappingScheme& mapping_scheme = - reduction_info.GetKernelMappingScheme(); + reduction_info->GetKernelMappingScheme(); LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(), mapping_scheme.GetThreadsPerBlock()); - VLOG(3) << "Launch dimensions of " << unnested_hlo->name() - << ": number of blocks: " << mapping_scheme.GetNumberOfBlocks() - << " - threads per block: " << mapping_scheme.GetThreadsPerBlock(); llvm::Type* index_ty = GetIndexTypeForKernel( unnested_hlo, launch_dimensions.launch_bound(), &b_); - EmitPrologueForReduction(unnested_hlo, &reduction_info, reduce_instructions, + EmitPrologueForReduction(unnested_hlo, reduction_info, reduce_instructions, index_ty); EmitElementFunction emit_reduction_tile = [&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc, llvm::Value* x_loc, int64 x_iter_num) { EmitTileElementForReduction(unnested_hlo, input_shape, - output_instructions, index, reduction_info, + output_instructions, index, *reduction_info, reducers, x_iter_num); }; @@ -3821,70 +3917,185 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions( [&](const ThreadIdInfo& thread_id_info, const IrArray::Index& index, const string& loop_name, llvm::Value* tile_height, llvm::Value* tile_width, KernelSupportLibrary* ksl) { - EmitTile(reduction_info.GetKernelMappingScheme(), index, loop_name, ksl, - thread_id_info, tile_height, tile_width, emit_reduction_tile); + EmitTile(reduction_info->GetKernelMappingScheme(), index, loop_name, + ksl, thread_id_info, tile_height, tile_width, + emit_reduction_tile); }); - EmitEpilogueForReduction(index_ty, unnested_hlo, reduction_info, + EmitEpilogueForReduction(index_ty, unnested_hlo, *reduction_info, reduce_instructions, reduction_output_shape_indices, reducers, tiling_kernel_info); +} +namespace { + +// Returns whether the `instr` is either a constant, a scalar, or a +// broadcasted constant/scalar. +bool IsBroadcastedConstantOrScalar(const HloInstruction& instr) { + return instr.IsConstant() || ShapeUtil::IsScalar(instr.shape()) || + (HloOpcode::kBroadcast == instr.opcode() && + (instr.operand(0)->IsConstant() || + ShapeUtil::IsScalar(instr.operand(0)->shape()))); +} + +// Divides output_instructions into groups. Different groups will be executed +// in parallel. Generally speaking, we'd like to run the reduce instructions +// in parallel without incurring too much recomputation overhead. The current +// heuristic is to place reduce instructions who share nothing or only +// (broadcasted) scalars/constants into different groups; otherwise, they are +// placed in the same group. Non-reduce instructions always go with the reduce +// instructions into the same group so long as they share any predecessors. 
+std::vector> DivideOutputInstructionsIntoGroups( + HloInstruction* unnested_hlo, + absl::Span output_instructions) { + CHECK(!output_instructions.empty()); + if (output_instructions.size() == 1) { + return {{output_instructions[0]}}; + } + + std::vector> disjoint_sets( + output_instructions.size()); + for (size_t i = 0; i < output_instructions.size(); ++i) { + disjoint_sets[i].Get() = output_instructions[i]; + } + + std::unique_ptr reachability_map = + HloReachabilityMap::Build(unnested_hlo->fused_instructions_computation()); + for (auto* instr : unnested_hlo->fused_instructions()) { + std::vector reached_output_ids; + for (size_t oid = 0; oid < output_instructions.size(); ++oid) { + if (HloOpcode::kReduce == output_instructions[oid]->opcode() && + (IsBroadcastedConstantOrScalar(*instr))) { + // Do not group output reduce instructions through broadcasted + // constants or scalars, as the recomputation should be acceptable. + VLOG(3) << "Skip broadcasted constant or scalar " << instr->ToString(); + continue; + } + // Now group output instructions if they have common predecessors. + if (reachability_map->IsReachable(instr, output_instructions[oid])) { + VLOG(3) << "Reaching " << output_instructions[oid]->ToString() + << " from " << instr->ToString(); + reached_output_ids.push_back(oid); + } + } + for (size_t j = 1; j < reached_output_ids.size(); ++j) { + disjoint_sets[reached_output_ids[0]].Merge( + &disjoint_sets[reached_output_ids[j]]); + } + } + // Place output instructions in the same set into the same group. + absl::flat_hash_map> groups; + for (size_t oid = 0; oid < output_instructions.size(); ++oid) { + groups[disjoint_sets[oid].Get()].push_back(output_instructions.at(oid)); + } + + std::vector> ret; + absl::c_for_each( + groups, [&](auto& iter) { ret.emplace_back(std::move(iter.second)); }); + return ret; +} + +} // namespace + +Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions( + HloInstruction* unnested_hlo, + absl::Span output_instructions) { + bool returns_tuple = output_instructions.size() > 1; + VLOG(10) << "Emitting reduction to vector " << unnested_hlo->ToString(); + + // Build an initializer thunk to initialize each reduction output. + std::vector> thunks; + for (int i = 0; i < output_instructions.size(); ++i) { + if (!IsReductionFromOrToContiguousDimensions(*output_instructions[i])) { + continue; + } + + ShapeIndex idx = returns_tuple ? ShapeIndex({i}) : ShapeIndex({}); + TF_ASSIGN_OR_RETURN(std::unique_ptr initializer_thunk, + BuildInitializerThunk(unnested_hlo, idx)); + thunks.push_back(std::move(initializer_thunk)); + } + + // Build a kernel thunk to compute all the outputs. + const HloInstruction* first_reduce = nullptr; + for (int i = 0; i < output_instructions.size(); ++i) { + if (IsReductionFromOrToContiguousDimensions(*output_instructions[i])) { + first_reduce = output_instructions[i]; + break; + } + } + CHECK(first_reduce); + if (output_instructions.size() > 1) { + if (!AreFusedReductionOutputsConsistent(output_instructions, + first_reduce)) { + return InternalError("Inconsistent reduction fusion outputs"); + } + } + const Shape& input_shape = first_reduce->operand(0)->shape(); + // The layout of a reduction input is either set by LayoutAssignment for + // unnested kReduce or by InstructionFusion for fused kReduce. + CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion " + "doesn't set the input layout of " + << first_reduce->ToString(); + + // Group output instructions. Each group will be executed in parallel. 
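The grouping heuristic above can be sketched standalone with a small union-find: outputs reached from a common non-trivial predecessor are merged into one group, while outputs that only share broadcasted scalars or constants stay separate. The three-output example below is hypothetical.

#include <cstdio>
#include <map>
#include <numeric>
#include <vector>

struct UnionFind {
  std::vector<int> parent;
  explicit UnionFind(int n) : parent(n) {
    std::iota(parent.begin(), parent.end(), 0);
  }
  int Find(int x) { return parent[x] == x ? x : parent[x] = Find(parent[x]); }
  void Merge(int a, int b) { parent[Find(a)] = Find(b); }
};

int main() {
  // reaches[i] lists the fusion outputs reachable from the i-th shared,
  // non-scalar predecessor. Outputs 0 and 1 share a predecessor; output 2
  // only shares broadcasted scalars and is therefore never merged.
  const std::vector<std::vector<int>> reaches = {{0, 1}, {2}};
  UnionFind uf(3);
  for (const auto& outs : reaches) {
    for (size_t j = 1; j < outs.size(); ++j) uf.Merge(outs[0], outs[j]);
  }
  std::map<int, std::vector<int>> groups;
  for (int out = 0; out < 3; ++out) groups[uf.Find(out)].push_back(out);
  std::printf("%zu parallel reduction groups\n", groups.size());  // prints 2
  return 0;
}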
+ std::vector> instr_groups = + DivideOutputInstructionsIntoGroups(unnested_hlo, output_instructions); + VLOG(2) << StrCat("Generate in ", instr_groups.size(), " groups for ", + unnested_hlo->ToString()); + std::unique_ptr kernel_thunk = + BuildKernelThunk(unnested_hlo, /*implements_whole_instruction=*/false); + KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); + for (size_t i = 0; i < instr_groups.size(); ++i) { + // Create a new ReductionCodegenInfo instance as it contains states for + // code generation per reduction group. For now, let's always use the very + // first reduce as representative to construct ReductionCodegenInfo, since + // all the reductions are required to have the same shape and layout as + // verified by `AreFusedReductionOutputsConsistent()`. We can loosen the + // constraint later when the needs arise. + ReductionCodegenInfo reduction_info = + ComputeReductionCodegenInfo(unnested_hlo, first_reduce); + auto emit_reduction_func = [&] { + EmitIRForReduction(unnested_hlo, instr_groups[i], &reduction_info, + input_shape); + }; + // Use raw block_id_y to select the i-th parallel reduction to run. Using + // block_id_y instead of block_id_x simplifies the index calculation + // for reduction code generation as the block_id_y is orthogonal to + // the indices used within the reductions. + llvm::CallInst* raw_block_id_y = gpu::EmitCallToTargetIntrinsic( + gpu::TargetIntrinsicID::kBlockIdy, {}, {}, &b_); + llvm_ir::AddRangeMetadata(0, instr_groups.size(), + llvm::cast(raw_block_id_y)); + llvm::Value* guarding_cond = + b_.CreateICmpEQ(raw_block_id_y, b_.getInt32(i)); + ksl.If(StrCat("reduce-group-", i), guarding_cond, emit_reduction_func); + } + ReductionCodegenInfo reduction_info = + ComputeReductionCodegenInfo(unnested_hlo, first_reduce); + const KernelMappingScheme& mapping_scheme = + reduction_info.GetKernelMappingScheme(); + // block_y_count is set to instr_groups.size(), so that each reduction group + // can be run in parallel by a different BlockIdy. + LaunchDimensions launch_dimensions( + {/*x=*/mapping_scheme.GetNumberOfBlocks(), + /*y=*/static_cast(instr_groups.size()), + /*z=*/1}, + {/*x=*/mapping_scheme.GetThreadsPerBlock(), /*y=*/1, /*z=*/1}); + VLOG(3) << "Launch dimensions of " << unnested_hlo->name() + << ": number of blocks: " << mapping_scheme.GetNumberOfBlocks() + << " - threads per block: " << mapping_scheme.GetThreadsPerBlock(); UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(), ir_emitter_context_->llvm_module()); thunks.push_back(std::move(kernel_thunk)); - auto sequential_thunk = absl::make_unique( - GetThunkInfo(unnested_hlo), std::move(thunks)); + std::unique_ptr sequential_thunk = + absl::make_unique(GetThunkInfo(unnested_hlo), + std::move(thunks)); AddThunkToThunkSequence(std::move(sequential_thunk)); return Status::OK(); } -Status IrEmitterUnnested::EmitConstantGlobals() { - for (const BufferAllocation& allocation : - ir_emitter_context_->buffer_assignment().Allocations()) { - if (!allocation.is_constant()) { - continue; - } - - const Literal& literal = llvm_ir::LiteralForConstantAllocation(allocation); - const bool should_emit_initializer = ShouldEmitLiteralInLlvmIr(literal); - llvm::ArrayType* global_type = - llvm::ArrayType::get(b_.getInt8Ty(), allocation.size()); - llvm::Constant* initializer = - should_emit_initializer - ? 
llvm_ir::ConvertLiteralToIrConstant(literal, module_) - : llvm::ConstantAggregateZero::get(global_type); - if (should_emit_initializer) { - VLOG(3) << "Emitted initializer for constant with shape " - << ShapeUtil::HumanString(literal.shape()); - } - - // These globals will be looked up by name by GpuExecutable so we need to - // give them an external linkage. Not all of their uses are visible in - // the LLVM IR (e.g. TupleThunk) so we can't give then a linkage that - // merely preserves their names (like available_externally), we also need - // to ensure that they stick around even if they're "unused". - // - // We may have to be more more clever here in the future if we notice that - // we're keeping around too many globals because of their linkage. - unsigned global_address_space = llvm_ir::GetGlobalMemoryAddressSpace( - *ir_emitter_context_->llvm_module()); - llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable( - global_type, /*isConstant=*/should_emit_initializer, - llvm::GlobalValue::ExternalLinkage, - /*Initializer=*/initializer, - llvm_ir::ConstantBufferAllocationToGlobalName(allocation), - /*TLMode=*/llvm::GlobalValue::NotThreadLocal, - /*AddressSpace=*/global_address_space, - /*isExternallyInitialized=*/false); - global_for_const->setAlignment(llvm::Align(kConstantBufferAlignBytes)); - ir_emitter_context_->llvm_module()->getGlobalList().push_back( - global_for_const); - } - - return Status::OK(); -} - // Emits code for slices based on the below structure. An if statement with // a guarding condition is generated for each ROOT slice. // diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index b9146dd8fae..5cc5e206167 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -58,7 +58,6 @@ struct MlirBufferSlice : public BufferSlice { struct MlirEmitterInput { mlir::Operation* op; - absl::string_view name; Thunk::ThunkInfo thunk_info; MlirBufferSlice extra_slice; }; @@ -161,7 +160,7 @@ class IrEmitterUnnested : public IrEmitter, Status HandleScatter(HloInstruction* scatter) override; Status HandleSelect(HloInstruction* select) override; Status HandleSort(HloInstruction* sort) override; - Status EmitMlirSort(MlirEmitterInput input); + Status EmitSortFromMlir(MlirEmitterInput mlir_input); Status HandleTriangularSolve(HloInstruction* hlo) override; Status HandleTupleSelect(HloInstruction* tuple_select) override; Status HandleAllReduce(HloInstruction* crs) override; @@ -178,10 +177,7 @@ class IrEmitterUnnested : public IrEmitter, // `unroll_factor` is greater than one. Status EmitTargetElementLoopInThunk( const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter, - KernelThunk* thunk, int unroll_factor); - - // Emits LLVM global variables corresponding to constant instructions. - Status EmitConstantGlobals(); + KernelThunk* thunk, int unroll_factor, bool few_waves = false); Status Postprocess(HloInstruction* hlo) override; @@ -372,6 +368,16 @@ class IrEmitterUnnested : public IrEmitter, // } // ``` // + // Moreover, a heuristic is implemented to divide the reduce instructions + // into groups for parallelization (see `DivideOutputInstructionsIntoGroups` + // for details about the heuristic.) Reduce instructions in the same group + // will run sequentially while different groups will run in parallel. 
+ // + // We use raw block_id_y to select the reduce groups for execution without + // complicating the index calculation in the code generation of the reduce + // instructions. In other words, a block_id_y is assigned to a group and so + // different groups can be run in parallel. + // + // output_instructions: Output instructions in the computation: instruction + // itself if it's not a fusion, fusion root if fusion is not multi-output, and + // elements of the fusion multi-output tuple otherwise. @@ -404,11 +410,10 @@ class IrEmitterUnnested : public IrEmitter, // the process. `scatter` may be fused, scatter indices are taken from // `scatter_indices_gen`, updates from `updates_gen`. The output buffer is // expected to have the operand values in it already. If unique_indices - // is false, we will use an atomic update. Using false for unique_indices - // is safe only when it is guaranteed that there are no duplicate - // indices. - // When using unique_indices=true, it is the caller's responsibility to - // ensure there is no overlap. + // is false, we will use an atomic update. Using true for unique_indices + // behaves properly only when it is guaranteed that the indices to be + // updated do not overlap. The caller is responsible for ensuring this is + // the case. Status EmitScatter(Thunk* thunk, HloInstruction* scatter, const llvm_ir::ElementGenerator& scatter_indices_gen, const llvm_ir::ElementGenerator& updates_gen); @@ -519,6 +524,12 @@ class IrEmitterUnnested : public IrEmitter, absl::Span<HloComputation* const> reducers, const TilingKernelInfo& tiling_kernel_info); + // Emits code for reductions in the output_instructions. + void EmitIRForReduction(HloInstruction* unnested_hlo, + absl::Span<HloInstruction* const> output_instructions, + ReductionCodegenInfo* reduction_info, + const Shape& input_shape); + // For each reducer, emits the shuffle-down loop to accumulate the partial + result to the global result. 
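[Editor's note] The revised EmitScatter comment above says an atomic update is emitted whenever unique_indices is false. A toy C++ stand-in for the generated update, only to illustrate why duplicate indices require the atomic path; none of these names exist in XLA.

    #include <atomic>
    #include <cstdint>
    #include <vector>

    // When indices may collide, the combine step must be atomic or concurrent
    // updates from different threads are lost; when the caller guarantees
    // unique indices, a plain read-modify-write is enough.
    void ScatterAddExample(std::vector<std::atomic<int32_t>>& output,
                           const std::vector<int64_t>& indices,
                           const std::vector<int32_t>& updates,
                           bool unique_indices) {
      for (size_t i = 0; i < indices.size(); ++i) {
        std::atomic<int32_t>& slot = output[indices[i]];
        if (unique_indices) {
          slot.store(slot.load() + updates[i]);  // safe only without duplicates
        } else {
          slot.fetch_add(updates[i]);            // atomic; duplicates are fine
        }
      }
    }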
void EmitFullWarpShuffleDownLoopForAllReduces( diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc index 19fef37db7e..6c138258aa0 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -115,9 +115,8 @@ Status KernelThunk::ExecuteOnStream(const ExecuteParams& params) { auto op_profiler = params.profiler->MakeScopedInstructionProfiler(profile_index()); - return ExecuteKernelOnStream(*kernel, buffer_args, - launch_dimensions.threads_per_block(), - launch_dimensions.block_count(), params.stream); + return ExecuteKernelOnStream(*kernel, buffer_args, launch_dimensions, + params.stream); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc b/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc index 3668a521ec7..5dbbb2d65da 100644 --- a/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc +++ b/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc @@ -26,8 +26,11 @@ namespace gpu { std::ostream& operator<<(std::ostream& out, const LaunchDimensions& launch_dims) { - out << absl::StrFormat("[block: %d, thread: %d]", launch_dims.block_count(), - launch_dims.threads_per_block()); + LaunchDimensions::Dim3D block_counts = launch_dims.block_counts(); + LaunchDimensions::Dim3D thread_counts = launch_dims.thread_counts_per_block(); + out << absl::StrFormat("[block: {%d, %d, %d}, thread: {%d, %d, %d}]", + block_counts.x, block_counts.y, block_counts.z, + thread_counts.x, thread_counts.y, thread_counts.z); return out; } @@ -53,7 +56,7 @@ static int64 ThreadsPerBlockLimit(GpuDeviceInfo gpu_device_info) { // Calculates the launch dimensions used to invoke `hlo`. LaunchDimensions CalculateLaunchDimensions(const Shape& shape, GpuDeviceInfo gpu_device_info, - int unroll_factor) { + int unroll_factor, bool few_waves) { int64 num_elements = ShapeUtil::ElementsIn(shape); if (num_elements <= 1) { return LaunchDimensions(); @@ -87,6 +90,11 @@ LaunchDimensions CalculateLaunchDimensions(const Shape& shape, } int64 block_count = CeilOfRatio(num_elements, threads_per_block); + if (few_waves) { + threads_per_block = std::min(threads_per_block, int64{128}); + block_count = gpu_device_info.core_count * + (gpu_device_info.threads_per_core_limit / threads_per_block); + } VLOG(2) << absl::StrFormat( "Initialized the block count to ceil(# of elements / threads per " "block) = ceil(%d/%d) = %d", diff --git a/tensorflow/compiler/xla/service/gpu/launch_dimensions.h b/tensorflow/compiler/xla/service/gpu/launch_dimensions.h index 1a5a9d618e4..1472141a80e 100644 --- a/tensorflow/compiler/xla/service/gpu/launch_dimensions.h +++ b/tensorflow/compiler/xla/service/gpu/launch_dimensions.h @@ -29,24 +29,37 @@ namespace gpu { // number of threads per block. class LaunchDimensions { public: + struct Dim3D { + int64 x, y, z; + }; + // The default constructor creates a launch dimension that indicate // single-threaded execution. 
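[Editor's note] The few_waves path added to CalculateLaunchDimensions above caps the block size at 128 threads and sizes the grid from the device description instead of the element count. A small sketch with made-up device numbers (80 cores, 2048 resident threads per core); the values are assumptions for illustration only.

    #include <algorithm>
    #include <cstdint>

    // Mirrors the few_waves arithmetic above with hypothetical device limits.
    int64_t FewWavesBlockCount() {
      const int64_t core_count = 80;
      const int64_t threads_per_core_limit = 2048;
      int64_t threads_per_block = 256;  // whatever the normal path picked
      threads_per_block = std::min<int64_t>(threads_per_block, 128);
      // 80 * (2048 / 128) = 1280 blocks: enough to occupy the device for a
      // few waves regardless of element count; leftover elements are handled
      // by the in-kernel loop that ParallelLoopEmitter::EmitLoop adds.
      return core_count * (threads_per_core_limit / threads_per_block);
    }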
- LaunchDimensions() : block_count_(1), threads_per_block_(1) {} + LaunchDimensions() + : block_counts_({1, 1, 1}), thread_counts_per_block_({1, 1, 1}) {} - LaunchDimensions(int64 block_count, int64 threads_per_block) - : block_count_(block_count), threads_per_block_(threads_per_block) {} + LaunchDimensions(int64 block_x_count, int64 thread_x_count_per_block) + : block_counts_({block_x_count, 1, 1}), + thread_counts_per_block_({thread_x_count_per_block, 1, 1}) {} - bool IsSinglethreaded() const { - return block_count_ == 1 && threads_per_block_ == 1; + LaunchDimensions(const Dim3D& block_counts, + const Dim3D& thread_counts_per_block) + : block_counts_(block_counts), + thread_counts_per_block_(thread_counts_per_block) {} + + Dim3D block_counts() const { return block_counts_; } + + Dim3D thread_counts_per_block() const { return thread_counts_per_block_; } + + int64 launch_bound() const { + return block_counts_.x * thread_counts_per_block_.x * block_counts_.y * + thread_counts_per_block_.y * block_counts_.z * + thread_counts_per_block_.z; } - int64 block_count() const { return block_count_; } - int64 threads_per_block() const { return threads_per_block_; } - int64 launch_bound() const { return block_count() * threads_per_block(); } - private: - int64 block_count_; - int64 threads_per_block_; + Dim3D block_counts_; + Dim3D thread_counts_per_block_; }; std::ostream& operator<<(std::ostream& out, @@ -54,7 +67,8 @@ std::ostream& operator<<(std::ostream& out, LaunchDimensions CalculateLaunchDimensions(const Shape& shape, GpuDeviceInfo gpu_device_info, - int unroll_factor = 1); + int unroll_factor = 1, + bool few_waves = false); } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD index c3ef02a04f2..eb6291172fe 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") package( diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index 04af67a70b9..36b676565b5 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -149,7 +149,7 @@ std::unique_ptr GetTargetMachine( } llvm::TargetOptions target_options = - llvm::codegen::InitTargetOptionsFromCodeGenFlags(); + llvm::codegen::InitTargetOptionsFromCodeGenFlags(llvm::Triple()); // Set the verbose assembly options. 
target_options.MCOptions.AsmVerbose = false; diff --git a/tensorflow/compiler/xla/service/gpu/memset_thunk.h b/tensorflow/compiler/xla/service/gpu/memset_thunk.h index 8a1890a0769..fb18b7041b7 100644 --- a/tensorflow/compiler/xla/service/gpu/memset_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/memset_thunk.h @@ -55,7 +55,7 @@ class Memset32BitValueThunk : public Thunk { Status ExecuteOnStream(const ExecuteParams& params) override; private: - uint32 value_; + const uint32 value_; const BufferAllocation::Slice dest_; }; diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index fcbd9e760c6..fa73ac261f8 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" +#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/platform/types.h" @@ -104,6 +105,13 @@ HloInstruction* SelectPreferredFusionCandidate( std::vector GetProducerConsumerMultiOutputFusionCandidates( const HloInstruction* producer, const HloReachabilityMap& reachability) { std::vector fusion_candidates; + // If there is only one user, and it is not a multi-output fusion node, this + // fusion possibility was already considered and rejected by the FusionMerger + // pass. No need to try again! + if (producer->user_count() == 1 && + !producer->users()[0]->IsMultiOutputFusion()) { + return fusion_candidates; + } for (HloInstruction* consumer : producer->users()) { VLOG(3) << "Looking at producer " << producer->name() << " and its consumer " << consumer->name(); @@ -141,6 +149,16 @@ std::vector GetProducerConsumerMultiOutputFusionCandidates( << " would be too large of a fusion."; continue; } + // Make sure the emitter can codegen the fusion op efficiently. We currently + // can have exponential time/memory requirements for emitting certain fusion + // ops, in which case we don't want to fuse. + // TODO(b/119692968): Remove this once fixed in the emitter. + if (FusedIrEmitter::IsFusedIrEmitterInefficient(consumer, producer)) { + VLOG(3) << "Fusion of " << producer->name() << " into " + << consumer->name() + << " would result in overly large code duplication."; + continue; + } fusion_candidates.push_back(consumer); } return fusion_candidates; diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc index 4d269665b42..6cb66290a9a 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -798,6 +798,86 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionDUS) { // Check that we don't fuse too many reductions together. 
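[Editor's note] The SharedMemoryBudget test below exercises the comment above: merging too many reductions into one multi-output fusion would exceed the per-block shared-memory budget. The sketch only shows the flavor of such a check; the 32x32 float tile per output and the 48 KB cap are assumed numbers, not the pass's actual accounting.

    #include <cstdint>

    // Back-of-the-envelope only: one staging tile per reduction output must
    // fit under a fixed per-block shared-memory cap.
    bool FitsSharedMemoryBudget(int64_t num_reduce_outputs) {
      constexpr int64_t kBytesPerOutput = 32 * 32 * sizeof(float);  // 4 KB
      constexpr int64_t kBudgetBytes = 48 * 1024;                   // assumed cap
      return num_reduce_outputs * kBytesPerOutput <= kBudgetBytes;
    }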
TEST_F(MultiOutputFusionTest, SharedMemoryBudget) { auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"( + fused_computation0 { + p0 = f32[64,64] parameter(0) + p1 = f32[64,64] parameter(1) + p2 = f32[] parameter(2) + add = f32[64,64] add(p0, p1) + ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0}, + to_apply=scalar_add_computation + } + fused_computation1 { + p0 = f32[64,64] parameter(0) + p1 = f32[64,64] parameter(1) + p2 = f32[] parameter(2) + add = f32[64,64] add(p0, p1) + ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0}, + to_apply=scalar_add_computation + } + fused_computation2 { + p0 = f32[64,64] parameter(0) + p1 = f32[64,64] parameter(1) + p2 = f32[] parameter(2) + add = f32[64,64] add(p0, p1) + ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0}, + to_apply=scalar_add_computation + } + fused_computation3 { + p0 = f32[64,64] parameter(0) + p1 = f32[64,64] parameter(1) + p2 = f32[] parameter(2) + add = f32[64,64] add(p0, p1) + ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0}, + to_apply=scalar_add_computation + } + fused_computation4 { + p0 = f32[64,64] parameter(0) + p1 = f32[64,64] parameter(1) + p2 = f32[] parameter(2) + add = f32[64,64] add(p0, p1) + ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0}, + to_apply=scalar_add_computation + } + fused_computation5 { + p0 = f32[64,64] parameter(0) + p1 = f32[64,64] parameter(1) + p2 = f32[] parameter(2) + add = f32[64,64] add(p0, p1) + ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0}, + to_apply=scalar_add_computation + } + fused_computation6 { + p0 = f32[64,64] parameter(0) + p1 = f32[64,64] parameter(1) + p2 = f32[] parameter(2) + add = f32[64,64] add(p0, p1) + ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0}, + to_apply=scalar_add_computation + } + fused_computation7 { + p0 = f32[64,64] parameter(0) + p1 = f32[64,64] parameter(1) + p2 = f32[] parameter(2) + add = f32[64,64] add(p0, p1) + ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0}, + to_apply=scalar_add_computation + } + fused_computation8 { + p0 = f32[64,64] parameter(0) + p1 = f32[64,64] parameter(1) + p2 = f32[] parameter(2) + add = f32[64,64] add(p0, p1) + ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0}, + to_apply=scalar_add_computation + } + fused_computation9 { + p0 = f32[64,64] parameter(0) + p1 = f32[64,64] parameter(1) + p2 = f32[] parameter(2) + add = f32[64,64] add(p0, p1) + ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0}, + to_apply=scalar_add_computation + } ENTRY computation { zero = f32[] constant(0) param0 = f32[64,64] parameter(0) @@ -810,36 +890,16 @@ TEST_F(MultiOutputFusionTest, SharedMemoryBudget) { param7 = f32[64,64] parameter(7) param8 = f32[64,64] parameter(8) param9 = f32[64,64] parameter(9) - add0 = f32[64,64] add(param0, param1) - add1 = f32[64,64] add(param1, param2) - add2 = f32[64,64] add(param2, param3) - add3 = f32[64,64] add(param3, param4) - add4 = f32[64,64] add(param4, param5) - add5 = f32[64,64] add(param5, param6) - add6 = f32[64,64] add(param6, param7) - add7 = f32[64,64] add(param7, param8) - add8 = f32[64,64] add(param8, param9) - add9 = f32[64,64] add(param9, param0) - out0 = f32[64] reduce(f32[64,64] add0, f32[] zero), dimensions={0}, - to_apply=scalar_add_computation - out1 = f32[64] reduce(f32[64,64] add1, f32[] zero), dimensions={0}, - to_apply=scalar_add_computation - out2 = f32[64] 
reduce(f32[64,64] add2, f32[] zero), dimensions={0}, - to_apply=scalar_add_computation - out3 = f32[64] reduce(f32[64,64] add3, f32[] zero), dimensions={0}, - to_apply=scalar_add_computation - out4 = f32[64] reduce(f32[64,64] add4, f32[] zero), dimensions={0}, - to_apply=scalar_add_computation - out5 = f32[64] reduce(f32[64,64] add5, f32[] zero), dimensions={0}, - to_apply=scalar_add_computation - out6 = f32[64] reduce(f32[64,64] add6, f32[] zero), dimensions={0}, - to_apply=scalar_add_computation - out7 = f32[64] reduce(f32[64,64] add7, f32[] zero), dimensions={0}, - to_apply=scalar_add_computation - out8 = f32[64] reduce(f32[64,64] add8, f32[] zero), dimensions={0}, - to_apply=scalar_add_computation - out9 = f32[64] reduce(f32[64,64] add9, f32[] zero), dimensions={0}, - to_apply=scalar_add_computation + out0 = f32[64] fusion(param0, param1, zero), kind=kInput, calls=fused_computation0 + out1 = f32[64] fusion(param1, param2, zero), kind=kInput, calls=fused_computation1 + out2 = f32[64] fusion(param2, param3, zero), kind=kInput, calls=fused_computation2 + out3 = f32[64] fusion(param3, param4, zero), kind=kInput, calls=fused_computation3 + out4 = f32[64] fusion(param4, param5, zero), kind=kInput, calls=fused_computation4 + out5 = f32[64] fusion(param5, param6, zero), kind=kInput, calls=fused_computation5 + out6 = f32[64] fusion(param6, param7, zero), kind=kInput, calls=fused_computation6 + out7 = f32[64] fusion(param7, param8, zero), kind=kInput, calls=fused_computation7 + out8 = f32[64] fusion(param8, param9, zero), kind=kInput, calls=fused_computation8 + out9 = f32[64] fusion(param9, param0, zero), kind=kInput, calls=fused_computation9 ROOT out = (f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64]) tuple(f32[64] out0, f32[64] out1, f32[64] out2, f32[64] out3, f32[64] out4, f32[64] out5, f32[64] out6, f32[64] out7, f32[64] out8, f32[64] out9) } )")) @@ -849,5 +909,165 @@ TEST_F(MultiOutputFusionTest, SharedMemoryBudget) { EXPECT_EQ(2, CountMultiOutputFusions(module.get())); } +TEST_F(MultiOutputFusionTest, NoFusionToAvoidCodeDuplication) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule module + +and.reduce_sub_computation { + x = pred[] parameter(0) + y = pred[] parameter(1) + ROOT and = pred[] and(x, y) +} + +fused_computation.1 { + param_4.658 = f32[2,20,256]{2,0,1} parameter(4) + slice.1385 = f32[2,1,256]{2,0,1} slice(param_4.658), slice={[0:2], [11:12], [0:256]} + constant.6847 = s32[] constant(0) + broadcast.4823 = s32[3]{0} broadcast(constant.6847), dimensions={} + param_9.415 = s32[3]{0} parameter(9) + compare.700 = pred[3]{0} compare(broadcast.4823, param_9.415), direction=LE + constant.6846 = pred[] constant(true) + reduce.221 = pred[] reduce(compare.700, constant.6846), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2933 = pred[2,1,256]{2,0,1} broadcast(reduce.221), dimensions={} + param_5.528 = f32[2,512]{1,0} parameter(5) + slice.1384 = f32[2,256]{1,0} slice(param_5.528), slice={[0:2], [0:256]} + bitcast.341 = f32[2,1,256]{2,0,1} bitcast(slice.1384) + constant.5418 = f32[] constant(0) + broadcast.3227 = f32[2,1,256]{2,0,1} broadcast(constant.5418), dimensions={} + select.173 = f32[2,1,256]{2,0,1} select(broadcast.2933, bitcast.341, broadcast.3227) + add.573 = f32[2,1,256]{2,0,1} add(slice.1385, select.173) + param_0.299 = s32[] parameter(0) + constant.5157 = s32[] constant(11) + dynamic-update-slice.189 = f32[2,20,256]{2,0,1} dynamic-update-slice(param_4.658, add.573, param_0.299, constant.5157, param_0.299) + 
slice.1383 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.189), slice={[0:2], [10:11], [0:256]} + constant.6800 = s32[] constant(0) + broadcast.4803 = s32[3]{0} broadcast(constant.6800), dimensions={} + param_8.484 = s32[3]{0} parameter(8) + compare.681 = pred[3]{0} compare(broadcast.4803, param_8.484), direction=LE + constant.6798 = pred[] constant(true) + reduce.203 = pred[] reduce(compare.681, constant.6798), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2932 = pred[2,1,256]{2,0,1} broadcast(reduce.203), dimensions={} + param_3.1169 = f32[2,512]{1,0} parameter(3) + slice.1382 = f32[2,256]{1,0} slice(param_3.1169), slice={[0:2], [0:256]} + bitcast.340 = f32[2,1,256]{2,0,1} bitcast(slice.1382) + select.172 = f32[2,1,256]{2,0,1} select(broadcast.2932, bitcast.340, broadcast.3227) + add.572 = f32[2,1,256]{2,0,1} add(slice.1383, select.172) + constant.5154 = s32[] constant(10) + dynamic-update-slice.188 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.189, add.572, param_0.299, constant.5154, param_0.299) + slice.1381 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.188), slice={[0:2], [9:10], [0:256]} + constant.6794 = s32[] constant(0) + broadcast.4801 = s32[3]{0} broadcast(constant.6794), dimensions={} + param_7.478 = s32[3]{0} parameter(7) + compare.679 = pred[3]{0} compare(broadcast.4801, param_7.478), direction=LE + constant.6793 = pred[] constant(true) + reduce.201 = pred[] reduce(compare.679, constant.6793), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2930 = pred[2,1,256]{2,0,1} broadcast(reduce.201), dimensions={} + param_2.1685 = f32[2,512]{1,0} parameter(2) + slice.1380 = f32[2,256]{1,0} slice(param_2.1685), slice={[0:2], [0:256]} + bitcast.339 = f32[2,1,256]{2,0,1} bitcast(slice.1380) + select.171 = f32[2,1,256]{2,0,1} select(broadcast.2930, bitcast.339, broadcast.3227) + add.571 = f32[2,1,256]{2,0,1} add(slice.1381, select.171) + constant.5153 = s32[] constant(9) + dynamic-update-slice.187 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.188, add.571, param_0.299, constant.5153, param_0.299) + slice.1379 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.187), slice={[0:2], [8:9], [0:256]} + constant.6788 = s32[] constant(0) + broadcast.4799 = s32[3]{0} broadcast(constant.6788), dimensions={} + param_6.495 = s32[3]{0} parameter(6) + compare.677 = pred[3]{0} compare(broadcast.4799, param_6.495), direction=LE + constant.6786 = pred[] constant(true) + reduce.199 = pred[] reduce(compare.677, constant.6786), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2929 = pred[2,1,256]{2,0,1} broadcast(reduce.199), dimensions={} + param_1.1408 = f32[2,512]{1,0} parameter(1) + slice.1378 = f32[2,256]{1,0} slice(param_1.1408), slice={[0:2], [0:256]} + bitcast.338 = f32[2,1,256]{2,0,1} bitcast(slice.1378) + select.170 = f32[2,1,256]{2,0,1} select(broadcast.2929, bitcast.338, broadcast.3227) + add.570 = f32[2,1,256]{2,0,1} add(slice.1379, select.170) + constant.5152 = s32[] constant(8) + ROOT dynamic-update-slice.186 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.187, add.570, param_0.299, constant.5152, param_0.299) +} + +fused_computation.2 { + param_4.655 = f32[2,20,256]{2,0,1} parameter(4) + slice.1369 = f32[2,1,256]{2,0,1} slice(param_4.655), slice={[0:2], [7:8], [0:256]} + param_6.483 = pred[] parameter(6) + broadcast.2927 = pred[2,1,256]{2,0,1} broadcast(param_6.483), dimensions={} + param_5.525 = f32[2,512]{1,0} parameter(5) + slice.1368 = f32[2,256]{1,0} slice(param_5.525), 
slice={[0:2], [0:256]} + bitcast.333 = f32[2,1,256]{2,0,1} bitcast(slice.1368) + constant.5415 = f32[] constant(0) + broadcast.3225 = f32[2,1,256]{2,0,1} broadcast(constant.5415), dimensions={} + select.161 = f32[2,1,256]{2,0,1} select(broadcast.2927, bitcast.333, broadcast.3225) + add.549 = f32[2,1,256]{2,0,1} add(slice.1369, select.161) + param_0.265 = s32[] parameter(0) + constant.5151 = s32[] constant(7) + dynamic-update-slice.185 = f32[2,20,256]{2,0,1} dynamic-update-slice(param_4.655, add.549, param_0.265, constant.5151, param_0.265) + slice.1367 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.185), slice={[0:2], [6:7], [0:256]} + constant.6782 = s32[] constant(0) + broadcast.4797 = s32[3]{0} broadcast(constant.6782), dimensions={} + param_9.391 = s32[3]{0} parameter(9) + compare.675 = pred[3]{0} compare(broadcast.4797, param_9.391), direction=LE + constant.6781 = pred[] constant(true) + reduce.197 = pred[] reduce(compare.675, constant.6781), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2926 = pred[2,1,256]{2,0,1} broadcast(reduce.197), dimensions={} + param_3.1167 = f32[2,512]{1,0} parameter(3) + slice.1366 = f32[2,256]{1,0} slice(param_3.1167), slice={[0:2], [0:256]} + bitcast.332 = f32[2,1,256]{2,0,1} bitcast(slice.1366) + select.160 = f32[2,1,256]{2,0,1} select(broadcast.2926, bitcast.332, broadcast.3225) + add.548 = f32[2,1,256]{2,0,1} add(slice.1367, select.160) + constant.5150 = s32[] constant(6) + dynamic-update-slice.184 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.185, add.548, param_0.265, constant.5150, param_0.265) + slice.1365 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.184), slice={[0:2], [5:6], [0:256]} + constant.6776 = s32[] constant(0) + broadcast.4794 = s32[3]{0} broadcast(constant.6776), dimensions={} + param_8.464 = s32[3]{0} parameter(8) + compare.673 = pred[3]{0} compare(broadcast.4794, param_8.464), direction=LE + constant.6775 = pred[] constant(true) + reduce.195 = pred[] reduce(compare.673, constant.6775), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2925 = pred[2,1,256]{2,0,1} broadcast(reduce.195), dimensions={} + param_2.1684 = f32[2,512]{1,0} parameter(2) + slice.1364 = f32[2,256]{1,0} slice(param_2.1684), slice={[0:2], [0:256]} + bitcast.331 = f32[2,1,256]{2,0,1} bitcast(slice.1364) + select.159 = f32[2,1,256]{2,0,1} select(broadcast.2925, bitcast.331, broadcast.3225) + add.547 = f32[2,1,256]{2,0,1} add(slice.1365, select.159) + constant.5149 = s32[] constant(5) + dynamic-update-slice.183 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.184, add.547, param_0.265, constant.5149, param_0.265) + slice.1363 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.183), slice={[0:2], [4:5], [0:256]} + constant.6770 = s32[] constant(0) + broadcast.4792 = s32[3]{0} broadcast(constant.6770), dimensions={} + param_7.458 = s32[3]{0} parameter(7) + compare.671 = pred[3]{0} compare(broadcast.4792, param_7.458), direction=LE + constant.6769 = pred[] constant(true) + reduce.193 = pred[] reduce(compare.671, constant.6769), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2924 = pred[2,1,256]{2,0,1} broadcast(reduce.193), dimensions={} + param_1.1405 = f32[2,512]{1,0} parameter(1) + slice.1362 = f32[2,256]{1,0} slice(param_1.1405), slice={[0:2], [0:256]} + bitcast.330 = f32[2,1,256]{2,0,1} bitcast(slice.1362) + select.158 = f32[2,1,256]{2,0,1} select(broadcast.2924, bitcast.330, broadcast.3225) + add.546 = f32[2,1,256]{2,0,1} add(slice.1363, select.158) + constant.5148 = s32[] 
constant(4) + ROOT dynamic-update-slice.182 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.183, add.546, param_0.265, constant.5148, param_0.265) +} + +ENTRY main { + param_0.0 = s32[] parameter(0) + param_1.0 = f32[2,512]{1,0} parameter(1) + param_2.0 = f32[2,512]{1,0} parameter(2) + param_3.0 = f32[2,512]{1,0} parameter(3) + param_4.0 = f32[2,20,256]{2,1,0} parameter(4) + param_5.0 = f32[2,512]{1,0} parameter(5) + param_6.0 = s32[3]{0} parameter(6) + param_7.0 = s32[3]{0} parameter(7) + param_8.0 = s32[3]{0} parameter(8) + param_9.0 = s32[3]{0} parameter(9) + fusion.1 = f32[2,20,256]{2,0,1} fusion(param_0.0, param_1.0, param_2.0, param_3.0, param_4.0, param_5.0, param_6.0, param_7.0, param_8.0, param_9.0), kind=kLoop, calls=fused_computation.1 + param_10 = pred[] parameter(10) + fusion.2 = f32[2,20,256]{2,0,1} fusion(param_0.0, param_1.0, param_2.0, param_3.0, fusion.1, param_5.0, param_10, param_7.0, param_8.0, param_9.0), kind=kLoop, calls=fused_computation.2 + ROOT root = (f32[2,20,256]{2,0,1}, f32[2,20,256]{2,0,1}) tuple(fusion.1, fusion.2) +} + )") + .ValueOrDie(); + EXPECT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc index 25ab9a7ce6e..b13f71c5a13 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc @@ -514,11 +514,40 @@ void RendezvousNcclAllReduce::CleanupImpl(std::shared_ptr handle, // header. In particular, this stores the thunk's cache of all NcclCliques it's // ever used. This causes those cliques to stay alive as long as the thunk // lives, which is how we avoid expensive reinitialization of NCCL cliques. 
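[Editor's note] The NcclAllReduceThunk changes that follow replace the stored HloInstruction pointer with an NcclAllReduceConfig captured at construction time. A simplified sketch of that pattern with fake types; none of these are the real XLA classes.

    #include <cstdint>
    #include <utility>
    #include <vector>

    struct FakeInstr {                 // stand-in for HloInstruction
      std::vector<int> operand_types;
      int64_t channel_id;
    };

    struct FakeAllReduceConfig {       // stand-in for NcclAllReduceConfig
      std::vector<int> operand_types;
      int64_t op_id;
    };

    // Everything the runtime needs is copied out of the instruction once,
    // so the thunk never dereferences the HLO graph at execution time.
    FakeAllReduceConfig MakeConfig(const FakeInstr& instr) {
      return FakeAllReduceConfig{instr.operand_types, instr.channel_id};
    }

    class FakeAllReduceThunk {
     public:
      explicit FakeAllReduceThunk(FakeAllReduceConfig config)
          : config_(std::move(config)) {}
      int64_t op_id() const { return config_.op_id; }  // used at execute time

     private:
      const FakeAllReduceConfig config_;
    };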
-struct NcclAllReduceThunk::AuxData { +struct NcclAllReduceConfig::AuxData { tensorflow::mutex mu; absl::flat_hash_set> cliques TF_GUARDED_BY(mu); }; +NcclAllReduceConfig::NcclAllReduceConfig(NcclAllReduceConfig&&) = default; +NcclAllReduceConfig::~NcclAllReduceConfig() = default; + +NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction* instr, + int64 replica_count) { + NcclAllReduceConfig config; + config.operand_count = instr->operands().size(); + config.operand_element_type.reserve(config.operand_count); + for (int i = 0; i < config.operand_count; i++) { + config.operand_element_type.push_back( + instr->operand(i)->shape().element_type()); + } + config.replica_count = replica_count; + config.replica_groups = instr->replica_groups(); + auto reduction_kind = MatchReductionComputation(instr->to_apply()); + CHECK(reduction_kind.has_value()); + config.reduction_kind = reduction_kind.value(); + + if (instr->channel_id().has_value()) { + config.collective_op_kind = RendezvousKey::kCrossModule; + config.op_id = instr->channel_id().value(); + } else { + config.collective_op_kind = RendezvousKey::kCrossReplica; + config.op_id = static_cast(instr->GetModule()->unique_id()); + } + config.aux_data = std::make_unique(); + return config; +} + /*static*/ bool NcclAllReduceThunk::CanImplement(const HloInstruction* crs) { auto operands_are_supported = [crs]() { return absl::c_all_of(crs->operands(), [](HloInstruction* operand) { @@ -541,14 +570,12 @@ NcclAllReduceThunk::DevicesWithOpenNcclChannels() { } NcclAllReduceThunk::NcclAllReduceThunk( - ThunkInfo thunk_info, int64 replica_count, + ThunkInfo thunk_info, NcclAllReduceConfig&& config, std::vector buffers) : Thunk(Thunk::kNcclAllReduce, thunk_info), - hlo_instruction_(thunk_info.hlo_instruction), - replica_count_(replica_count), - buffers_(std::move(buffers)), - aux_data_(absl::make_unique()) { - CHECK_EQ(hlo_instruction_->operand_count(), buffers_.size()); + config_(std::move(config)), + buffers_(std::move(buffers)) { + CHECK_EQ(config_.operand_count, buffers_.size()); } // Figures out which devices (named by their replica-ids) are participating in @@ -558,7 +585,6 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) { auto op_profiler = params.profiler->MakeScopedInstructionProfiler(profile_index()); - auto* instr = Cast(hlo_instruction_); int64 local_device_ordinal = params.stream->parent()->device_ordinal(); GlobalDeviceId global_device_id; if (params.gpu_global_device_ids) { @@ -574,10 +600,10 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) { // the same collective group as the caller. TF_ASSIGN_OR_RETURN( std::vector global_participating_replicas, - GetParticipatingReplicas(global_device_id, instr->replica_groups(), - replica_count_, *params.device_assn)); + GetParticipatingReplicas(global_device_id, config_.replica_groups, + config_.replica_count, *params.device_assn)); if (IsGlobalNcclConfig() && - global_participating_replicas.size() != replica_count_) { + global_participating_replicas.size() != config_.replica_count) { return InvalidArgument( "Partial replica groups are not allowed when using NCCL_COMM_ID " "environment configuration."); @@ -605,10 +631,10 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) { } absl::c_sort(global_devices); - // Find or create the rendezvous for this collective operation. 
- RendezvousKey rendezvous_key = RendezvousKey::FromInstruction( - params.run_id, global_devices, local_devices.size(), hlo_instruction_); - + // Create the rendezvous for this collective operation. + RendezvousKey rendezvous_key(params.run_id, global_devices, + local_devices.size(), config_.collective_op_kind, + config_.op_id); if (VLOG_IS_ON(2)) { std::vector local_participants; for (const auto& entry : local_devices) { @@ -633,15 +659,12 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) { params.buffer_allocations->GetDeviceAddress(buffer.source_buffer); pbuffer.destination_data = params.buffer_allocations->GetDeviceAddress(buffer.destination_buffer); - pbuffer.primitive_type = - hlo_instruction_->operand(i)->shape().element_type(); + pbuffer.primitive_type = config_.operand_element_type[i]; participant.buffers.push_back(pbuffer); } participant.local_devices = std::move(local_devices); participant.nccl_unique_id_callback = params.nccl_unique_id_callback; - auto reduction_kind = MatchReductionComputation(hlo_instruction_->to_apply()); - CHECK(reduction_kind.has_value()); - participant.reduction_kind = *reduction_kind; + participant.reduction_kind = config_.reduction_kind; auto rendezvous_factory = [](const RendezvousKey& k) { return absl::make_unique(k); @@ -658,13 +681,11 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) { // Keep the clique we used alive for as long as this Thunk lives. Creating // new NCCL cliques is expensive, and this is how we avoid thrashing them. { - tensorflow::mutex_lock lock(aux_data_->mu); - aux_data_->cliques.insert(std::move(clique)); + tensorflow::mutex_lock lock(config_.aux_data->mu); + config_.aux_data->cliques.insert(std::move(clique)); } return Status::OK(); } -NcclAllReduceThunk::~NcclAllReduceThunk() {} - } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h index cbd4fd3aa51..20e4adef7b1 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h @@ -18,11 +18,13 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/collective_ops_utils.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" @@ -30,6 +32,30 @@ limitations under the License. namespace xla { namespace gpu { +struct NcclAllReduceConfig { + int64 operand_count; + std::vector operand_element_type; + int64 replica_count; + std::vector replica_groups; + ReductionKind reduction_kind; + RendezvousKey::CollectiveOpKind collective_op_kind; + int64 op_id; + + NcclAllReduceConfig() = default; + NcclAllReduceConfig(NcclAllReduceConfig &&); + ~NcclAllReduceConfig(); + + // Extra data stored in NcclAllReduceThunk whose types we don't want exposed + // in the header file. 
(This is mainly because the implementation of + // NcclAllReduceThunk is different depending on whether CUDA is enabled in the + // build, and we don't want to expose *that* mess in the header.) + struct AuxData; + std::unique_ptr aux_data; +}; + +NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction *instr, + int64 replica_count); + // Thunk that performs a NCCL-based All-Reduce among CUDA GPU-based replicas. class NcclAllReduceThunk : public Thunk { public: @@ -56,9 +82,8 @@ class NcclAllReduceThunk : public Thunk { BufferAllocation::Slice source_buffer; BufferAllocation::Slice destination_buffer; }; - NcclAllReduceThunk(ThunkInfo thunk_info, int64 replica_count, + NcclAllReduceThunk(ThunkInfo thunk_info, NcclAllReduceConfig &&config, std::vector buffers); - ~NcclAllReduceThunk() override; Status ExecuteOnStream(const ExecuteParams& params) override; @@ -67,16 +92,8 @@ class NcclAllReduceThunk : public Thunk { static bool CanImplement(const HloInstruction* crs); private: - // Extra data stored in NcclAllReduceThunk whose types we don't want exposed - // in the header file. (This is mainly because the implementation of - // NcclAllReduceThunk is different depending on whether CUDA is enabled in the - // build, and we don't want to expose *that* mess in the header.) - struct AuxData; - - const HloInstruction* hlo_instruction_; - const int64 replica_count_; + const NcclAllReduceConfig config_; const std::vector buffers_; - std::unique_ptr aux_data_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index eefa4661d37..77c54e48a70 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -198,6 +198,42 @@ absl::optional CanShareBufferHint(const HloInstruction* user, return absl::nullopt; } +// Try to load ptx from files defined in the FLAGS. If successful, return true. +bool MaybeLoadPtxFromFile(const HloModule* module, std::string* ptx) { + // If the xla_gpu_ptx_file options is set, be explicit when a file is used + // and warn when a file is not used to ease catching typo in filename. + std::string prefix = xla::FilenameFor(*module, "", *ptx); + std::string matched_filename; + for (const string& full_filename : + module->config().debug_options().xla_gpu_ptx_file()) { + // To ease comparing many PTX versions, accept different suffixes then + // the original filename. + auto filename = tensorflow::io::Basename(full_filename); + if (absl::StartsWith(filename, prefix)) { + matched_filename = full_filename; + VLOG(0) << "RunBackend() - Will load PTX from file: " << full_filename; + break; + } + } + if (module->config().debug_options().xla_gpu_ptx_file().size() > 0 && + matched_filename.empty()) { + VLOG(0) << "RunBackend() - For module with prefix '" << prefix + << "', we did not found a PTX file to load."; + } + + if (!matched_filename.empty()) { + std::ifstream ifs(matched_filename, std::ifstream::in); + *ptx = std::string(std::istreambuf_iterator(ifs), + std::istreambuf_iterator()); + CHECK(!ptx->empty()) << "Empty or non existing PTX file: " + << matched_filename; + return true; + } + return false; +} + +} // namespace + // Prints a warning if the ptx->sass JIT in the driver has known bugs. // // Using such a driver only a problem if we fail to use ptxas to compile our ptx @@ -238,42 +274,6 @@ void WarnIfBadDriverJITVersion() { }); } -// Try to load ptx from files defined in the FLAGS. If successful, return true. 
-bool MaybeLoadPtxFromFile(const HloModule* module, std::string* ptx) { - // If the xla_gpu_ptx_file options is set, be explicit when a file is used - // and warn when a file is not used to ease catching typo in filename. - std::string prefix = xla::FilenameFor(*module, "", *ptx); - std::string matched_filename; - for (const string& full_filename : - module->config().debug_options().xla_gpu_ptx_file()) { - // To ease comparing many PTX versions, accept different suffixes then - // the original filename. - auto filename = tensorflow::io::Basename(full_filename); - if (absl::StartsWith(filename, prefix)) { - matched_filename = full_filename; - VLOG(0) << "RunBackend() - Will load PTX from file: " << full_filename; - break; - } - } - if (module->config().debug_options().xla_gpu_ptx_file().size() > 0 && - matched_filename.empty()) { - VLOG(0) << "RunBackend() - For module with prefix '" << prefix - << "', we did not found a PTX file to load."; - } - - if (!matched_filename.empty()) { - std::ifstream ifs(matched_filename, std::ifstream::in); - *ptx = std::string(std::istreambuf_iterator(ifs), - std::istreambuf_iterator()); - CHECK(!ptx->empty()) << "Empty or non existing PTX file: " - << matched_filename; - return true; - } - return false; -} - -} // namespace - NVPTXCompiler::NVPTXCompiler() : GpuCompiler(stream_executor::cuda::kCudaPlatformId, nvptx::kTargetTriple, nvptx::kDataLayout) {} @@ -415,7 +415,9 @@ std::vector NVPTXCompiler::CompileGpuAsmOrGetCachedResult( "using $PATH.", hlo_module_config); } - } else { + } else if (maybe_cubin.status().code() != + tensorflow::error::Code::UNIMPLEMENTED) { + // If unimplemented is returned, we fallback to the driver. LOG(FATAL) << "ptxas returned an error during compilation of ptx " "to sass: '" << maybe_cubin.status() << "' " diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h index e69be947522..3e19b35af19 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h @@ -30,6 +30,8 @@ limitations under the License. namespace xla { namespace gpu { +void WarnIfBadDriverJITVersion(); + // NVPTXCompiler generates efficient GPU executables for NVPTX target. class NVPTXCompiler : public GpuCompiler { public: diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc index 83066a4addf..6eef1b9f0b9 100644 --- a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc @@ -14,26 +14,34 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h" + #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/outfeed_manager.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { namespace gpu { -OutfeedThunk::OutfeedThunk(ThunkInfo thunk_info, +OutfeedConfig GetOutfeedConfig(const HloInstruction* instr) { + OutfeedConfig config; + config.input_shape = instr->operand(0)->shape(); + return config; +} + +OutfeedThunk::OutfeedThunk(ThunkInfo thunk_info, OutfeedConfig&& config, ShapeTree outfeed_slices) : Thunk(Kind::kOutfeed, thunk_info), - hlo_instruction_(thunk_info.hlo_instruction), + config_(std::move(config)), outfeed_slices_(std::move(outfeed_slices)) {} Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) { auto& stream = *params.stream; auto& buffer_allocations = *params.buffer_allocations; - VLOG(2) << "Outfeeding from GPU: " << hlo_instruction_->ToString(); + VLOG(2) << "Outfeeding from GPU"; auto op_profiler = params.profiler->MakeScopedInstructionProfiler(profile_index()); @@ -42,13 +50,12 @@ Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) { outfeed_manager->BlockingGetNextDestination(); // Nothing to be done for empty tuples. - if (ShapeUtil::IsEmptyTuple(hlo_instruction_->operand(0)->shape())) { + if (ShapeUtil::IsEmptyTuple(config_.input_shape)) { return Status::OK(); } - CHECK(ShapeUtil::Compatible(hlo_instruction_->operand(0)->shape(), - outfeed_buffers->shape())) + CHECK(ShapeUtil::Compatible(config_.input_shape, outfeed_buffers->shape())) << "XLA program outfeed request of shape " - << hlo_instruction_->operand(0)->shape().ToString() + << config_.input_shape.ToString() << " did not match the runtime's outfeed buffer of shape " << outfeed_buffers->shape().ToString(); diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h index 9174e605783..60c64858ee7 100644 --- a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h @@ -25,6 +25,12 @@ limitations under the License. namespace xla { namespace gpu { +struct OutfeedConfig { + Shape input_shape; +}; + +OutfeedConfig GetOutfeedConfig(const HloInstruction* instr); + // A thunk that outfeeds data. Data must be already resident on the host. This // thunk performs a host to device copy from the buffer allocated for the // outfeed op to the host location. @@ -32,7 +38,7 @@ class OutfeedThunk : public Thunk { public: // Constructs a OutfeedThunk that copies data to the host-side // outfeed queue from the buffers in the given shape tree. 
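[Editor's note] With the constructor change just below, callers build an OutfeedConfig first and hand it to the thunk. A hedged sketch of such a call site; the real one lives in the GPU IR emitter and is not part of this diff, and the ShapeTree template argument is inferred from the surrounding code.

    #include <memory>
    #include <utility>

    #include "absl/memory/memory.h"
    #include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h"

    namespace xla {
    namespace gpu {

    // Builds the thunk from a config instead of holding the HloInstruction.
    std::unique_ptr<OutfeedThunk> MakeOutfeedThunkSketch(
        Thunk::ThunkInfo thunk_info, const HloInstruction* outfeed,
        ShapeTree<BufferAllocation::Slice> outfeed_slices) {
      OutfeedConfig config = GetOutfeedConfig(outfeed);  // snapshots operand(0) shape
      return absl::make_unique<OutfeedThunk>(std::move(thunk_info),
                                             std::move(config),
                                             std::move(outfeed_slices));
    }

    }  // namespace gpu
    }  // namespace xla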
- OutfeedThunk(ThunkInfo thunk_info, + OutfeedThunk(ThunkInfo thunk_info, OutfeedConfig&& config, ShapeTree outfeed_slices); OutfeedThunk(const OutfeedThunk&) = delete; @@ -41,7 +47,7 @@ class OutfeedThunk : public Thunk { Status ExecuteOnStream(const ExecuteParams& params) override; private: - const HloInstruction* hlo_instruction_; + const OutfeedConfig config_; const ShapeTree outfeed_slices_; }; diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc index f9937ba77de..45c4f25d8e8 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc @@ -23,6 +23,7 @@ limitations under the License. #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/gpu/target_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -58,7 +59,8 @@ ParallelLoopEmitter::ParallelLoopEmitter( std::vector ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, - llvm::Type* index_type) { + llvm::Type* index_type, + llvm::Value* base_index) { // Emit the following code in LLVM IR: // linear_index = blockIdx.x * blockDim.x + threadIdx.x; // if (linear_index < num_elements) { @@ -75,7 +77,7 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, std::vector array_indices; llvm::Value* block_id = EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b_); - llvm_ir::AddRangeMetadata(0, launch_dimensions_.block_count(), + llvm_ir::AddRangeMetadata(0, launch_dimensions_.block_counts().x, static_cast(block_id)); block_id = b_->CreateZExtOrTrunc(block_id, index_type, "block_id"); @@ -85,16 +87,17 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, // %ntid.x is currently specified as 1024. llvm::Value* thread_id = EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b_); - llvm_ir::AddRangeMetadata(0, launch_dimensions_.threads_per_block(), + llvm_ir::AddRangeMetadata(0, launch_dimensions_.thread_counts_per_block().x, static_cast(thread_id)); thread_id = b_->CreateZExtOrTrunc(thread_id, index_type, "thread_id"); llvm::Value* linear_index_base = b_->CreateAdd( - b_->CreateMul(block_id, - llvm::ConstantInt::get( - index_type, launch_dimensions_.threads_per_block()), - "", - /*HasNUW=*/true, /*HasNSW=*/true), + b_->CreateMul( + block_id, + llvm::ConstantInt::get( + index_type, launch_dimensions_.thread_counts_per_block().x), + "", + /*HasNUW=*/true, /*HasNSW=*/true), thread_id, "linear_index", /*HasNUW=*/true, /*HasNSW=*/true); // Add an @llvm.assume(linear_index < threads_per_block * num_blocks). 
@@ -109,9 +112,9 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, llvm::Intrinsic::assume, {b_->CreateICmpULT( linear_index_base, - llvm::ConstantInt::get(index_type, - launch_dimensions_.threads_per_block() * - launch_dimensions_.block_count()), + llvm::ConstantInt::get( + index_type, launch_dimensions_.thread_counts_per_block().x * + launch_dimensions_.block_counts().x), "linear_index_in_range")}, {}, b_); @@ -121,6 +124,12 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, "linear_index_base", /*HasNUW=*/true, /*HasNSW=*/true); } + if (base_index != nullptr) { + linear_index_base = + b_->CreateAdd(linear_index_base, base_index, "linear_index_plus_base", + /*HasNUW=*/true, /*HasNSW=*/true); + } + array_indices.emplace_back(linear_index_base, shape_, b_); for (int i = 1; i < unroll_factor_; ++i) { llvm::Value* linear_index = @@ -146,5 +155,43 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, return array_indices; } +Status ParallelLoopEmitter::EmitLoop(absl::string_view loop_name, + llvm::Type* index_type) { + if (index_type == nullptr) { + index_type = b_->getInt64Ty(); + } + int64 total_threads = launch_dimensions_.launch_bound(); + int64 num_elements = ShapeUtil::ElementsIn(shape_); + // If all the elements are handled by the current threads, no need + // to add a loop inside the kernel. + if (total_threads * unroll_factor_ >= num_elements) { + VLOG(1) << "ParallelLoopEmitter::EmitLoop fallback"; + return LoopEmitter::EmitLoop(loop_name, index_type); + } + + KernelSupportLibrary ksl(b_, llvm_ir::UnrollMode::kDefaultUnroll); + auto constant = [&](int64 val) { + return llvm::ConstantInt::get(index_type, val); + }; + + TF_RETURN_IF_ERROR(ksl.ForWithStatus( + "loop", constant(0), constant(num_elements), + constant(total_threads * unroll_factor_), [&](llvm::Value* base_indvar) { + for (const llvm_ir::IrArray::Index& array_index : + EmitIndexAndSetExitBasicBlock(loop_name, index_type, + base_indvar)) { + TF_RETURN_IF_ERROR(body_emitter_(array_index)); + } + return Status::OK(); + })); + + // Set the insertion point of b_ to the loop exit, so that + // code emitted for later instructions will be correctly placed. + if (exit_bb_ != nullptr) { + b_->SetInsertPoint(exit_bb_); + } + return Status::OK(); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h index 0a6b5430b23..5e142ec3832 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h @@ -57,7 +57,11 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { ~ParallelLoopEmitter() override = default; std::vector EmitIndexAndSetExitBasicBlock( - absl::string_view loop_name, llvm::Type* index_type) override; + absl::string_view loop_name, llvm::Type* index_type, + llvm::Value* base_index) override; + + Status EmitLoop(absl::string_view loop_name = "", + llvm::Type* index_type = nullptr); private: // The thread and block dimension to parallelize the loop on. 
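[Editor's note] When the launch covers fewer than (threads * unroll_factor) elements, the new ParallelLoopEmitter::EmitLoop above wraps the body in a strided loop instead of launching more blocks. A CPU-side model of the resulting iteration pattern, illustrative only, not the emitted IR.

    #include <cstdint>
    #include <functional>

    // The outer step matches EmitLoop's stride of total_threads * unroll;
    // each thread covers `unroll_factor` consecutive elements starting at
    // base + thread_id * unroll_factor, guarded against running off the end.
    void GridStrideLoopModel(int64_t global_thread_id, int64_t total_threads,
                             int64_t unroll_factor, int64_t num_elements,
                             const std::function<void(int64_t)>& body) {
      const int64_t step = total_threads * unroll_factor;
      for (int64_t base = 0; base < num_elements; base += step) {
        for (int64_t u = 0; u < unroll_factor; ++u) {
          const int64_t linear_index =
              base + global_thread_id * unroll_factor + u;
          if (linear_index < num_elements) {
            body(linear_index);
          }
        }
      }
    }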
diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc index d7468a31377..7293b1485fc 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc @@ -209,16 +209,18 @@ StatusOr> CreateKernel( Status ExecuteKernelOnStream(const se::KernelBase& kernel, absl::Span args, - int64 threads_per_block, int64 block_count, - se::Stream* stream) { + const LaunchDimensions& dims, se::Stream* stream) { static constexpr int kKernelArgsLimit = 1024; auto kernel_args = absl::make_unique>(); for (const se::DeviceMemoryBase& buf : args) { kernel_args->add_device_memory_argument(buf); } - return stream->parent()->Launch(stream, se::ThreadDim(threads_per_block), - se::BlockDim(block_count), kernel, - *kernel_args); + LaunchDimensions::Dim3D thread_counts = dims.thread_counts_per_block(); + LaunchDimensions::Dim3D block_counts = dims.block_counts(); + return stream->parent()->Launch( + stream, se::ThreadDim(thread_counts.x, thread_counts.y, thread_counts.z), + se::BlockDim(block_counts.x, block_counts.y, block_counts.z), kernel, + *kernel_args); } se::GpuAsmOpts PtxOptsFromConfig(const HloModuleConfig& hlo_module_config) { @@ -317,5 +319,35 @@ void InitializeBuffer(se::Stream* stream, PrimitiveType buffer_type, } } +StatusOr GetDNNConvKindFromCudnnConvKind( + CudnnConvKind kind) { + switch (kind) { + case CudnnConvKind::kBackwardFilter: + return se::dnn::BACKWARD_FILTER; + case CudnnConvKind::kBackwardInput: + return se::dnn::BACKWARD_DATA; + case CudnnConvKind::kForward: + return se::dnn::FORWARD; + default: + break; + } + return InternalError("Unexpected convolution kind"); +} + +StatusOr GetDNNDataTypeFromPrimitiveType( + PrimitiveType type) { + switch (type) { + case F16: + return se::dnn::ToDataType::value; + case F32: + return se::dnn::ToDataType::value; + case F64: + return se::dnn::ToDataType::value; + default: + break; + } + return InternalError("Unsupported convolution datatype"); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h index 0a5e0e93a51..2b58496e05c 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h @@ -19,6 +19,8 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/layout.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -71,8 +73,7 @@ StatusOr> CreateKernel( // Runs loaded kernel on the stream with the provided arguments. Status ExecuteKernelOnStream(const se::KernelBase& kernel, absl::Span args, - int64 threads_per_block, int64 block_count, - se::Stream* stream); + const LaunchDimensions& dims, se::Stream* stream); // Create GpuAsmOpts out of HloModuleConfig. 
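[Editor's note] ExecuteKernelOnStream now receives the full 3D LaunchDimensions and forwards them to se::ThreadDim and se::BlockDim unchanged. A small sketch of building the grouped-reduction launch shape used earlier in this diff; the function name and sizes are illustrative.

    #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"

    namespace xla {
    namespace gpu {

    // blocks_x tiles the reduction input, num_groups rides on y (one
    // reduction group per blockIdx.y), and z stays unused.
    LaunchDimensions GroupedReductionLaunch(int64 blocks_x, int64 num_groups,
                                            int64 threads_per_block) {
      return LaunchDimensions({/*x=*/blocks_x, /*y=*/num_groups, /*z=*/1},
                              {/*x=*/threads_per_block, /*y=*/1, /*z=*/1});
    }

    }  // namespace gpu
    }  // namespace xla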
se::GpuAsmOpts PtxOptsFromConfig(const HloModuleConfig& hlo_module_config); @@ -86,6 +87,10 @@ se::GpuAsmOpts PtxOptsFromConfig(const HloModuleConfig& hlo_module_config); void InitializeBuffer(se::Stream* stream, PrimitiveType buffer_type, int64* rng_state, se::DeviceMemoryBase buffer); +StatusOr GetDNNConvKindFromCudnnConvKind( + CudnnConvKind kind); +StatusOr GetDNNDataTypeFromPrimitiveType(PrimitiveType type); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index 809b277317f..681e025ba1f 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -4,6 +4,8 @@ # TODO(jlebar): None of these tests actually use the GPU, so they should not # need to run on machines with GPUs present. +load("//tensorflow:tensorflow.bzl", "filegroup") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") load( @@ -219,6 +221,28 @@ tf_cc_test( ], ) +tf_cc_test( + name = "parallel_reduction_test", + srcs = [ + "parallel_reduction_test.cc", + ], + tags = tf_cuda_tests_tags() + ["no_rocm"], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service/gpu:gpu_executable", + "//tensorflow/compiler/xla/tests:filecheck", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:llvm_irgen_test_base", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_cc_test( name = "gpu_copy_test", srcs = ["gpu_copy_test.cc"], @@ -375,6 +399,8 @@ tf_cc_test( ":gpu_codegen_test", "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service/gpu:gpu_fusible", + "//tensorflow/compiler/xla/service/gpu:instruction_fusion", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -528,8 +554,15 @@ filegroup( # Binary with only the thunks dialect registered, for testing purposes. tf_cc_binary( name = "xla-thunks-opt", + srcs = ["xla_thunks_opt.cc"], deps = [ - "//tensorflow/compiler/mlir:tf_mlir_opt_main", - "//tensorflow/compiler/xla/service/gpu:xla_thunks_dialect_registration", + "//tensorflow/compiler/mlir:init_mlir", + "//tensorflow/compiler/mlir/hlo:hlo_dialect_registration", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/xla/service/gpu:xla_thunks_ops", + "//tensorflow/core:lib", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", + "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:Shape", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_fusion_test.cc index 674b436a8e3..811705d2b17 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_fusion_test.cc @@ -15,6 +15,8 @@ limitations under the License. 
#include +#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" +#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" @@ -54,6 +56,37 @@ TEST_F(GpuFusionTest, FusedReshape) { )"); } +// Check that we limit the number of operands to fusions we create. +TEST_F(GpuFusionTest, FusedBiggerThenThresholdButDoNotChangeTheFusionl) { + constexpr int64 kNumParams = kMaxOperandsAndOutputsPerFusion + 1; + + // Compute + // p0 + p1 + p2 + ... + pn, + // Use so many parameters that they do not fit into one fusion. + auto module = CreateNewVerifiedModule(); + HloComputation::Builder b(TestName()); + Shape input_shape = ShapeUtil::MakeShape(F32, {10, 100}); + Shape slice_shape = ShapeUtil::MakeShape(F32, {10, 2}); + Shape concat_shape = ShapeUtil::MakeShape(F32, {10, 2 * kNumParams}); + HloInstruction* input = + b.AddInstruction(HloInstruction::CreateParameter(0, input_shape, "p")); + + std::vector slice_params; + for (int64 i = 0; i < kNumParams; ++i) { + slice_params.push_back(b.AddInstruction(HloInstruction::CreateSlice( + slice_shape, input, {0, 0}, {10, 2}, {1, 1}))); + } + b.AddInstruction( + HloInstruction::CreateConcatenate(concat_shape, slice_params, 1)); + module->AddEntryComputation(b.Build()); + EXPECT_TRUE(GpuInstructionFusion(false).Run(module.get()).ValueOrDie()); + EXPECT_TRUE(module->entry_computation()->root_instruction()->opcode() == + HloOpcode::kFusion); + for (HloInstruction* instr : module->entry_computation()->instructions()) { + EXPECT_TRUE(instr->opcode() != HloOpcode::kSlice); + } +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc index d1bece038e0..6ed378adfeb 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc @@ -833,7 +833,7 @@ TEST_F(GpuKernelTilingTest, RowReductionCorrectShmemUsage) { )"; auto hlo_module = ParseAndReturnVerifiedModule(kHloString).ValueOrDie(); auto expected_ir = R"( -; CHECK: shared_cache_{{[0-9]*}} = private addrspace({{[0-9]*}}) global [1 x [32 x float]] +; CHECK: shared_cache_{{[0-9]*}} = private unnamed_addr addrspace({{[0-9]*}}) global [1 x [32 x float]] )"; CompileAndVerifyIr(std::move(hlo_module), expected_ir, /*match_optimized_ir=*/true); diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc index 2f139563b4a..200829efddb 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc @@ -148,8 +148,8 @@ TEST_F(GpuUnrollingTest, DisabledUnrollUnfusedSine) { HloModule test_module ENTRY SineFunc { - p0 = f32[160000]{0} parameter(0) - ROOT s = f32[160000]{0} sine(p0) + p0 = f32[1600000]{0} parameter(0) + ROOT s = f32[1600000]{0} sine(p0) })"; auto hlo_module = ParseAndReturnVerifiedModule(kUnfusedAddModule, config).ValueOrDie(); @@ -182,8 +182,8 @@ TEST_F(GpuUnrollingTest, DisabledUnrollUnfusedCosine) { HloModule test_module ENTRY SineFunc { - p0 = f32[160000]{0} parameter(0) - ROOT s = f32[160000]{0} cosine(p0) + p0 = f32[1600000]{0} parameter(0) + ROOT s = f32[1600000]{0} cosine(p0) })"; auto hlo_module = 
ParseAndReturnVerifiedModule(kUnfusedAddModule, config).ValueOrDie(); @@ -216,8 +216,8 @@ TEST_F(GpuUnrollingTest, DisabledUnrollUnfusedPower) { HloModule test_module ENTRY SineFunc { - p0 = f32[160000]{0} parameter(0) - ROOT s = f32[160000]{0} power(p0, p0) + p0 = f32[1600000]{0} parameter(0) + ROOT s = f32[1600000]{0} power(p0, p0) })"; auto hlo_module = ParseAndReturnVerifiedModule(kUnfusedAddModule, config).ValueOrDie(); @@ -241,8 +241,8 @@ TEST_F(GpuUnrollingTest, DisabledUnrollUnfusedAtan2) { HloModule test_module ENTRY SineFunc { - p0 = f32[160000]{0} parameter(0) - ROOT s = f32[160000]{0} atan2(p0, p0) + p0 = f32[16000000]{0} parameter(0) + ROOT s = f32[16000000]{0} atan2(p0, p0) })"; auto hlo_module = ParseAndReturnVerifiedModule(kUnfusedAddModule, config).ValueOrDie(); diff --git a/tensorflow/compiler/xla/service/gpu/tests/parallel_reduction_test.cc b/tensorflow/compiler/xla/service/gpu/tests/parallel_reduction_test.cc new file mode 100644 index 00000000000..06e547dfe34 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/parallel_reduction_test.cc @@ -0,0 +1,190 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/filecheck.h" + +namespace xla { +namespace gpu { + +namespace { + +class ParallelReductionTest : public GpuCodegenTest { + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); + // The test contains a MOF fusion and the XLA optimizer passes + // don't like this. 
+ debug_options.set_xla_disable_all_hlo_passes(true); + return debug_options; + } +}; + +TEST_F(ParallelReductionTest, TwoParallelReductions) { + const char* hlo_text = R"( +HloModule TwoParallelReductions + +%add_f32 { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +%fused_computation { + %param0 = f32[1024] parameter(0) + %param1 = f32[1024] parameter(1) + %constant0 = f32[] constant(0) + %reduce1 = f32[] reduce(%param0, %constant0), dimensions={0}, to_apply=%add_f32 + %reduce2 = f32[] reduce(%param1, %constant0), dimensions={0}, to_apply=%add_f32 + ROOT %tuple = (f32[], f32[]) tuple(%reduce1, %reduce2) +} + +ENTRY %cluster { + %param0 = f32[1024] parameter(0) + %param1 = f32[1024] parameter(1) + ROOT %fusion = (f32[], f32[]) + fusion(%param0, %param1), kind=kInput, calls=%fused_computation +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(hlo_text)); + CompileAndVerifyIr(std::move(hlo_module), + R"( +CHECK: reduce-group-0 +CHECK: reduce-group-1 +CHECK-NOT: reduce-group-2 +)", + /*match_optimized_ir=*/false); + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(ParallelReductionTest, ManyParallelReductions) { + std::unique_ptr module = CreateNewVerifiedModule(); + // Simply use a number not too large to avoid long compilation time + // and not too small for meaningful test. + const size_t num_reduces = 32; + + HloComputation* reduce_computation; + { + auto embedded_builder = HloComputation::Builder("add"); + HloInstruction* lhs = + embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "lhs")); + HloInstruction* rhs = + embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {}), "rhs")); + embedded_builder.AddInstruction( + HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs)); + reduce_computation = + module->AddEmbeddedComputation(embedded_builder.Build()); + } + + Shape input_shape = ShapeUtil::MakeShape(F32, {1024}); + Shape output_shape = ShapeUtil::MakeShape(F32, {}); + HloComputation* fusion_computation; + { + auto fusion_builder = HloComputation::Builder("fusion_computation"); + std::vector outputs; + HloInstruction* constant = fusion_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + for (size_t i = 0; i < num_reduces; ++i) { + HloInstruction* param = fusion_builder.AddInstruction( + HloInstruction::CreateParameter(i, input_shape, "param")); + HloInstruction* output = + fusion_builder.AddInstruction(HloInstruction::CreateReduce( + output_shape, param, constant, {0}, reduce_computation)); + outputs.push_back(output); + } + fusion_builder.AddInstruction(HloInstruction::CreateTuple(outputs)); + fusion_computation = module->AddEmbeddedComputation(fusion_builder.Build()); + } + + HloComputation::Builder b(TestName()); + std::vector entry_params; + std::vector output_shapes; + for (size_t i = 0; i < num_reduces; ++i) { + HloInstruction* param = b.AddInstruction( + HloInstruction::CreateParameter(i, input_shape, "param")); + entry_params.push_back(param); + output_shapes.push_back(output_shape); + } + b.AddInstruction(HloInstruction::CreateFusion( + ShapeUtil::MakeTupleShape(output_shapes), + HloInstruction::FusionKind::kInput, entry_params, fusion_computation)); + module->AddEntryComputation(b.Build()); + + EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(ParallelReductionTest, ThreeReductionGroups) { + 
const char* hlo_text = R"( +HloModule ThreeReductionGroups + +%add_f32 { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +%fused_computation { + %param0 = f32[1024,128] parameter(0) + %param1 = f32[1024,128] parameter(1) + %param2 = f32[1024,128] parameter(2) + %constant0 = f32[] constant(0) + // %mul0, %reduce0, and %reduce1 should go into a group. + %broadcast0 = f32[1024,128] broadcast(%constant0), dimensions={} + %mul0 = f32[1024,128] multiply(param0, broadcast0) + %reduce0 = f32[128] reduce(%mul0, %constant0), dimensions={0}, to_apply=%add_f32 + %reduce1 = f32[128] reduce(%param0, %constant0), dimensions={0}, to_apply=%add_f32 + // %reduce2 and %reduce3 should go into another group. + %reduce2 = f32[128] reduce(%param1, %constant0), dimensions={0}, to_apply=%add_f32 + %reduce3 = f32[128] reduce(%param1, %constant0), dimensions={0}, to_apply=%add_f32 + // %reduce4 and %mul2 should go into the other group, although broadcast0 is + // reused. + %mul1 = f32[1024,128] multiply(param2, broadcast0) + %reduce4 = f32[128] reduce(%mul1, %constant0), dimensions={0}, to_apply=%add_f32 + %mul2 = f32[1024,128] multiply(param2, param2) + ROOT %tuple = + (f32[1024, 128], f32[128], f32[128], f32[128], f32[128], f32[128], f32[1024, 128]) + tuple(%mul2, %reduce0, %reduce4, %reduce3, %reduce2, %reduce1, %mul0) +} + +ENTRY %cluster { + %param0 = f32[1024,128] parameter(0) + %param1 = f32[1024,128] parameter(1) + %param2 = f32[1024,128] parameter(2) + ROOT %fusion = + (f32[1024, 128], f32[128], f32[128], f32[128], f32[128], f32[128], f32[1024, 128]) + fusion(%param0, %param1, %param2), kind=kInput, calls=%fused_computation +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(hlo_text)); + CompileAndVerifyIr(std::move(hlo_module), + R"( +CHECK: reduce-group-0 +CHECK: reduce-group-1 +CHECK: reduce-group-2 +CHECK-NOT: reduce-group-3 +)", + /*match_optimized_ir=*/false); + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/reduction_vectorization_test.cc b/tensorflow/compiler/xla/service/gpu/tests/reduction_vectorization_test.cc index 215c2e627ae..5f97452ff71 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/reduction_vectorization_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/reduction_vectorization_test.cc @@ -336,8 +336,17 @@ ENTRY %cluster { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, ParseAndReturnVerifiedModule(hlo_text)); - CompileAndOptionallyVerifyPtx(std::move(optimized_module), - R"( + const se::DeviceDescription& device_description = + backend().default_stream_executor()->GetDeviceDescription(); + int cc_major = 0, cc_minor = 0; + device_description.cuda_compute_capability(&cc_major, &cc_minor); + + string expected; + if (cc_major < 6) { + // We do not vectorize for GPU before Pascal. 
+ expected = "CHECK-NOT: ld.global.nc.v2.f32"; + } else { + expected = R"( CHECK: ld.global.nc.v2.f32 CHECK: st.global.v2.f32 CHECK: st.global.v2.f32 @@ -350,7 +359,9 @@ CHECK: st.global.v2.f32 CHECK: ld.global.nc.v2.f32 CHECK: st.global.v2.f32 CHECK: st.global.v2.f32 -)"); +)"; + } + CompileAndOptionallyVerifyPtx(std::move(optimized_module), expected); EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); } diff --git a/tensorflow/compiler/xla/service/gpu/tests/scatter.hlo b/tensorflow/compiler/xla/service/gpu/tests/scatter.hlo index c9e7daeb3bc..f625abe6612 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/scatter.hlo +++ b/tensorflow/compiler/xla/service/gpu/tests/scatter.hlo @@ -1,6 +1,6 @@ // RUN: hlo_to_llvm_ir %s | FileCheck %s -// CHECK-LABEL: define void @scatter_TensorFlowScatterV1(i8* noalias align 64 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(36) %alloc1, i8* noalias align 16 dereferenceable(24) %alloc2, i8* noalias align 16 dereferenceable(8) %alloc3) { +// CHECK-LABEL: define void @scatter_TensorFlowScatterV1(i8* noalias align 16 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(24) %alloc1, i8* noalias align 16 dereferenceable(8) %alloc2) { // CHECK: entry: // CHECK: %[[VAL_32:.*]] = alloca i32, align 4 // CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0 @@ -43,8 +43,8 @@ // CHECK: store atomic i32 %[[VAL_36]], i32* %[[VAL_31]] unordered, align 4 // CHECK: br label %[[VAL_23]] // CHECK: !nvvm.annotations = !{!0, !1} -// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"kernel", i32 1} -// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"reqntidx", i32 6} +// CHECK: !0 = !{void (i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"kernel", i32 1} +// CHECK: !1 = !{void (i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"reqntidx", i32 6} // CHECK: !2 = !{i32 0, i32 1} // CHECK: !3 = !{i32 0, i32 6} // CHECK: !4 = !{} @@ -72,7 +72,7 @@ ENTRY main { // ----- -// CHECK-LABEL: define void @scatter_ScatterIntoScalar(i8* noalias align 64 dereferenceable(4) %alloc0, i8* noalias align 16 dereferenceable(4) %alloc1, i8* noalias align 16 dereferenceable(4) %alloc2, i8* noalias align 16 %alloc3) { +// CHECK-LABEL: define void @scatter_ScatterIntoScalar(i8* noalias align 16 dereferenceable(4) %alloc0, i8* noalias align 16 dereferenceable(4) %alloc1, i8* noalias align 16 %alloc2) { // CHECK: entry: // CHECK: %[[VAL_60:.*]] = alloca i32, align 4 // CHECK: %[[VAL_37:.*]] = getelementptr inbounds i8, i8* %[[VAL_38:.*]], i64 0 @@ -104,8 +104,8 @@ ENTRY main { // CHECK: store atomic i32 %[[VAL_62]], i32* %[[VAL_39]] unordered, align 4 // CHECK: br label %[[VAL_57]] // CHECK: !nvvm.annotations = !{!0, !1} -// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"kernel", i32 1} -// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"reqntidx", i32 1} +// CHECK: !0 = !{void (i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"kernel", i32 1} +// CHECK: !1 = !{void (i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"reqntidx", i32 1} // CHECK: !2 = !{i32 0, i32 1} // CHECK: !3 = !{} @@ -131,7 +131,7 @@ ENTRY main { // ----- -// CHECK-LABEL: define void @scatter_TensorFlowScatter_Mul(i8* noalias align 64 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(36) %alloc1, i8* noalias align 16 dereferenceable(24) %alloc2, i8* noalias align 16 dereferenceable(8) %alloc3) { +// CHECK-LABEL: define void @scatter_TensorFlowScatter_Mul(i8* noalias align 
16 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(24) %alloc1, i8* noalias align 16 dereferenceable(8) %alloc2) { // CHECK: %[[VAL_63:.*]] = alloca i32, align 4 // CHECK: %[[VAL_64:.*]] = alloca i32, align 4 // CHECK: %[[VAL_98:.*]] = alloca i32, align 4 @@ -188,8 +188,8 @@ ENTRY main { // CHECK: %[[VAL_109:.*]] = extractvalue { i32, i1 } %[[VAL_107]], 1 // CHECK: br i1 %[[VAL_109]], label %[[VAL_96]], label %[[VAL_104]] // CHECK: !nvvm.annotations = !{!0, !1} -// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"kernel", i32 1} -// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"reqntidx", i32 6} +// CHECK: !0 = !{void (i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"kernel", i32 1} +// CHECK: !1 = !{void (i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"reqntidx", i32 6} // CHECK: !2 = !{i32 0, i32 1} // CHECK: !3 = !{i32 0, i32 6} // CHECK: !4 = !{} @@ -216,7 +216,7 @@ ENTRY main { // ----- -// CHECK-LABEL: define void @scatter_ScalarUpdate(i8* noalias align 64 dereferenceable(16) %alloc0, i8* noalias align 16 dereferenceable(16) %alloc1, i8* noalias align 16 dereferenceable(4) %alloc2, i8* noalias align 16 dereferenceable(4) %alloc3) { +// CHECK-LABEL: define void @scatter_ScalarUpdate(i8* noalias align 16 dereferenceable(16) %alloc0, i8* noalias align 16 dereferenceable(4) %alloc1, i8* noalias align 16 dereferenceable(4) %alloc2) { // CHECK: entry: // CHECK: %[[VAL_146:.*]] = alloca i32, align 4 // CHECK: %[[VAL_118:.*]] = getelementptr inbounds i8, i8* %[[VAL_119:.*]], i64 0 @@ -253,8 +253,8 @@ ENTRY main { // CHECK: store atomic i32 %[[VAL_148]], i32* %[[VAL_145]] unordered, align 4 // CHECK: br label %[[VAL_138]] // CHECK: !nvvm.annotations = !{!0, !1} -// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScalarUpdate, !"kernel", i32 1} -// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScalarUpdate, !"reqntidx", i32 1} +// CHECK: !0 = !{void (i8*, i8*, i8*)* @scatter_ScalarUpdate, !"kernel", i32 1} +// CHECK: !1 = !{void (i8*, i8*, i8*)* @scatter_ScalarUpdate, !"reqntidx", i32 1} // CHECK: !2 = !{i32 0, i32 1} // CHECK: !3 = !{} diff --git a/tensorflow/compiler/xla/service/gpu/tests/xla_thunks_opt.cc b/tensorflow/compiler/xla/service/gpu/tests/xla_thunks_opt.cc new file mode 100644 index 00000000000..97c3b3a5bde --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/xla_thunks_opt.cc @@ -0,0 +1,39 @@ +/* Copyright 2020 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/InitAllDialects.h" // from @llvm-project +#include "mlir/InitAllPasses.h" // from @llvm-project +#include "mlir/Support/MlirOptMain.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h" +#include "tensorflow/compiler/mlir/init_mlir.h" +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" +#include "tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.h" +#include "tensorflow/core/platform/init_main.h" + +int main(int argc, char **argv) { + tensorflow::InitMlir y(&argc, &argv); + + mlir::registerAllPasses(); + + mlir::DialectRegistry registry; + mlir::registerAllDialects(registry); + mlir::RegisterAllTensorFlowDialects(registry); + mlir::mhlo::registerAllMhloDialects(registry); + registry.insert(); + registry.insert(); + return failed( + mlir::MlirOptMain(argc, argv, "XLA-Thunk pass driver\n", registry)); +} diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index 7a9fedec629..64b685db379 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -69,10 +69,6 @@ class Thunk { }; struct ThunkInfo { - // Optional. It's only used by subclasses which haven't been migrated away - // from HloInstructions. Once the migration is done, Thunks should be fully - // serializable. - const HloInstruction* hlo_instruction = nullptr; absl::optional profile_index; std::string profile_annotation; }; diff --git a/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc b/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc index 690d0c9de56..4c6c5bb846d 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc @@ -19,9 +19,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h" #include "tensorflow/compiler/xla/service/gpu/fft_thunk.h" #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h" #include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h" @@ -37,6 +39,65 @@ limitations under the License. namespace xla { namespace gpu { +namespace { +void CheckBatchNormInputOutputPrimitivetypeAreValid(const HloInstruction* hlo) { + // All input and output statistics variables must be F32. Also, the last + // operand for CudnnBatchNormForwardInference, CudnnBatchNormForwardTraining, + // and CudnnBatchNormBackward is the feature_index which must be S64. 
+ // The allowed types for non-statistics variables are as follows: + // CudnnBatchNormForwardInference: + // operand[0]: {half, float} + // out[0]: {half, float} + // CudnnBatchNormForwardTraining: + // operand[0]: {half, float} + // out[0]: {half, float} + // CudnnBatchNormBackward: + // operand[0]: {half, float} + // operand[4]: {half, float} + // out[0]: {half, float} + // Note non-statistics inputs and outputs mentioned above should be of the + // same type. + + // Check Inputs. + int64 num_operands = hlo->operand_count(); + PrimitiveType operand_primitive_type = + hlo->operand(0)->shape().element_type(); + CHECK(operand_primitive_type == F16 || operand_primitive_type == F32) + << "Not yet implemented"; + + for (int i = 1; i < num_operands - 2; i++) { + if (hlo->custom_call_target() == kCudnnBatchNormBackwardCallTarget && + i == 4) { + // The first operand to batchnorm grad is the input and the 4th operand is + // the grad_output, both of which can be Eigen::half. + CHECK_EQ(hlo->operand(i)->shape().element_type(), operand_primitive_type) + << "Invalid datatype"; + continue; + } + CHECK_EQ(hlo->operand(i)->shape().element_type(), F32) + << "Not yet implemented"; + } + + // The last operand is the feature index which must be int64. + CHECK_EQ(hlo->operand(num_operands - 1)->shape().element_type(), S64) + << "Not yet implemented"; + + // Check Outputs. + if (hlo->shape().IsTuple()) { + CHECK_EQ(hlo->shape().tuple_shapes(0).element_type(), + operand_primitive_type) + << "Invalid datatype"; + + for (int j = 1; j < hlo->shape().tuple_shapes_size(); j++) { + CHECK_EQ(hlo->shape().tuple_shapes(j).element_type(), F32) + << "Not yet implemented"; + } + } else { + CHECK_EQ(hlo->shape().element_type(), operand_primitive_type) + << "Invalid datatype"; + } +} +} // namespace std::unique_ptr ThunkEmitter::BuildFftThunk(const HloInstruction* inst) { const HloInstruction* operand = inst->operand(0); return absl::make_unique( @@ -72,15 +133,14 @@ std::unique_ptr ThunkEmitter::BuildTriangularSolveThunk( std::unique_ptr ThunkEmitter::BuildGemmThunk( const HloInstruction* inst) { - auto config_or = inst->backend_config(); - GemmBackendConfig gemm_config = std::move(config_or.ValueOrDie()); + GpuGemmConfig config = GetGpuGemmConfig(inst); const HloInstruction* lhs = inst->operand(0); const HloInstruction* rhs = inst->operand(1); // The bias is passed inside the output buffer. If those buffers are shared // we can just use it, otherwise copy the bias values into the output buffer // first. - if (gemm_config.beta() != 0.0) { + if (config.backend_config.beta() != 0.0) { const HloInstruction* bias = inst->operand(2); CHECK_EQ(bias->shape(), inst->shape()); if (GetAllocationSlice(*bias) != GetAllocationSlice(*inst)) { @@ -91,22 +151,22 @@ std::unique_ptr ThunkEmitter::BuildGemmThunk( /*destination_buffer=*/GetAllocationSlice(*inst), /*mem_size=*/ShapeUtil::ByteSizeOf(inst->shape()))); thunks.push_back(absl::make_unique( - context_->GetThunkInfo(inst), + context_->GetThunkInfo(inst), std::move(config), GetAllocationSlice(*lhs), // The buffer assigned to LHS. GetAllocationSlice(*rhs), // The buffer assigned to RHS. GetAllocationSlice(*inst), // The output buffer. 
- /*implements_whole_instruction=*/false, std::move(gemm_config))); + /*implements_whole_instruction=*/false)); return absl::make_unique(context_->GetThunkInfo(inst), std::move(thunks)); } } return absl::make_unique( - context_->GetThunkInfo(inst), + context_->GetThunkInfo(inst), std::move(config), GetAllocationSlice(*lhs), // The buffer assigned to LHS. GetAllocationSlice(*rhs), // The buffer assigned to RHS. GetAllocationSlice(*inst), // The output buffer. - /*implements_whole_instruction=*/true, std::move(gemm_config)); + /*implements_whole_instruction=*/true); } std::unique_ptr ThunkEmitter::BuildInfeedThunk( @@ -133,8 +193,9 @@ std::unique_ptr ThunkEmitter::BuildOutfeedThunk( *slice = status_or_slice.ValueOrDie(); } }); + OutfeedConfig config = GetOutfeedConfig(inst); return absl::make_unique(context_->GetThunkInfo(inst), - std::move(slices)); + std::move(config), std::move(slices)); } Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { @@ -154,16 +215,20 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { CHECK(feature_index->IsConstant()); int64 feature_index_value = feature_index->literal().Get({}); + CHECK_EQ(custom_call->shape().tuple_shapes_size(), 3); + CHECK(LayoutUtil::LayoutsInShapesEqual(custom_call->shape().tuple_shapes(0), + custom_call->operand(0)->shape())); + CheckBatchNormInputOutputPrimitivetypeAreValid(custom_call); + CudnnBatchNormConfig config = GetCudnnBatchNormConfig( + custom_call, epsilon_value, feature_index_value); AddThunkToThunkSequence( absl::make_unique( - context_->GetThunkInfo(custom_call), + context_->GetThunkInfo(custom_call), std::move(config), /*operand=*/GetAllocationSlice(*custom_call->operand(0)), /*scale=*/GetAllocationSlice(*custom_call->operand(1)), /*offset=*/GetAllocationSlice(*custom_call->operand(2)), /*mean=*/GetAllocationSlice(*custom_call->operand(3)), /*variance=*/GetAllocationSlice(*custom_call->operand(4)), - /*epsilon=*/epsilon_value, - /*feature_index=*/feature_index_value, /*output=*/GetAllocationSlice(*custom_call))); return Status::OK(); } @@ -183,14 +248,14 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { auto output_data = GetAllocationSlice(*custom_call, {0}); auto output_mean = GetAllocationSlice(*custom_call, {1}); auto output_inv_stddev = GetAllocationSlice(*custom_call, {2}); + CudnnBatchNormConfig config = GetCudnnBatchNormConfig( + custom_call, epsilon_value, feature_index_value); AddThunkToThunkSequence( absl::make_unique( - context_->GetThunkInfo(custom_call), + context_->GetThunkInfo(custom_call), std::move(config), /*operand=*/GetAllocationSlice(*custom_call->operand(0)), /*scale=*/GetAllocationSlice(*custom_call->operand(1)), /*offset=*/GetAllocationSlice(*custom_call->operand(2)), - /*epsilon=*/epsilon_value, - /*feature_index=*/feature_index_value, /*output_data=*/output_data, /*output_mean=*/output_mean, /*output_inv_stddev=*/output_inv_stddev, @@ -212,15 +277,22 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { auto output_grad_data = GetAllocationSlice(*custom_call, {0}); auto output_grad_scale = GetAllocationSlice(*custom_call, {1}); auto output_grad_offset = GetAllocationSlice(*custom_call, {2}); + CHECK_EQ(custom_call->shape().tuple_shapes_size(), 3); + CHECK(LayoutUtil::LayoutsInShapesEqual(custom_call->shape().tuple_shapes(0), + custom_call->operand(0)->shape())); + CHECK(LayoutUtil::LayoutsInShapesEqual(custom_call->shape().tuple_shapes(0), + custom_call->operand(4)->shape())); + 
CheckBatchNormInputOutputPrimitivetypeAreValid(custom_call); + + CudnnBatchNormConfig config = GetCudnnBatchNormConfig( + custom_call, epsilon_value, feature_index_value); AddThunkToThunkSequence(absl::make_unique( - context_->GetThunkInfo(custom_call), + context_->GetThunkInfo(custom_call), std::move(config), /*operand=*/GetAllocationSlice(*custom_call->operand(0)), /*scale=*/GetAllocationSlice(*custom_call->operand(1)), /*mean=*/GetAllocationSlice(*custom_call->operand(2)), /*inv_stddev=*/GetAllocationSlice(*custom_call->operand(3)), /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)), - /*epsilon=*/epsilon_value, - /*feature_index=*/feature_index_value, /*output_grad_data=*/output_grad_data, /*output_grad_scale=*/output_grad_scale, /*output_grad_offset=*/output_grad_offset, @@ -238,9 +310,13 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { auto conv_result_slice = GetAllocationSlice(*custom_call, {0}); auto scratch_slice = GetAllocationSlice(*custom_call, {1}); + TF_ASSIGN_OR_RETURN( + GpuConvConfig config, + GetGpuConvConfig(Cast(custom_call))); AddThunkToThunkSequence(absl::make_unique( - context_->GetThunkInfo(custom_call), std::move(operand_slices), - conv_result_slice, scratch_slice, tuple_result_slice)); + context_->GetThunkInfo(custom_call), std::move(config), + std::move(operand_slices), conv_result_slice, scratch_slice, + tuple_result_slice)); return Status::OK(); } @@ -310,11 +386,26 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { return slices; }; std::vector> operand_slices; - for (const auto* operand : custom_call->operands()) { + for (int64 i = 0; i < custom_call->operand_count(); i++) { + const auto* operand = custom_call->operand(i); operand_slices.push_back(get_slices_for_instr(operand)); + const auto& s1 = operand_slices.back().shape(); + const auto& s2 = operand->shape(); + CHECK(ShapeUtil::Equal(s1, s2)) << absl::StreamFormat( + "Shape mismatch between operand shape and " + "slice shape for operand %d: %s vs %s", + i, s1.ToString(), s2.ToString()); } ShapeTree result_slices = get_slices_for_instr(custom_call); + CHECK(ShapeUtil::Equal(custom_call->shape(), result_slices.shape())) + << absl::StreamFormat( + "Shape mismatch between instr->shape() and " + "result_slices.shape(): " + "%s vs %s.", + custom_call->shape().ToString(), + result_slices.shape().ToString()); + AddThunkToThunkSequence(absl::make_unique( context_->GetThunkInfo(custom_call), call_target, std::move(operand_slices), std::move(result_slices), @@ -385,7 +476,6 @@ Thunk::ThunkInfo ThunkEmitter::EmissionContext::GetThunkInfo( const HloInstruction* hlo) const { CHECK(hlo); Thunk::ThunkInfo info; - info.hlo_instruction = hlo; info.profile_annotation = absl::StrFormat( "Thunk:#hlo_op=%s,hlo_module=%s#", hlo->name(), hlo->GetModule()->name()); return info; diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc index 792479df4ac..6397ad3bee0 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc @@ -27,9 +27,10 @@ WhileThunk::WhileThunk( ThunkInfo thunk_info, const BufferAllocation::Slice& condition_result_buffer_index, std::unique_ptr condition_thunk_sequence, - std::unique_ptr body_thunk_sequence) + std::unique_ptr body_thunk_sequence, + absl::optional condition_profile_index, + absl::optional body_profile_index) : Thunk(Kind::kWhile, thunk_info), - hlo_instruction_(thunk_info.hlo_instruction), 
condition_result_buffer_index_(condition_result_buffer_index), // Pass nullptr as the HloInstruction* to the condition_thunk_sequence_ // and body_thunk_sequence_ constructors because these SequentialThunks @@ -38,7 +39,9 @@ WhileThunk::WhileThunk( condition_thunk_sequence_(absl::make_unique( ThunkInfo(), std::move(*condition_thunk_sequence))), body_thunk_sequence_(absl::make_unique( - ThunkInfo(), std::move(*body_thunk_sequence))) {} + ThunkInfo(), std::move(*body_thunk_sequence))), + condition_profile_index_(condition_profile_index), + body_profile_index_(body_profile_index) {} Status WhileThunk::Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) { @@ -62,7 +65,7 @@ Status WhileThunk::ExecuteOnStream(const ExecuteParams& params) { profiler.StartHloComputation(); VLOG(3) << "Executing condition computation"; TF_RETURN_IF_ERROR(condition_thunk_sequence_->ExecuteOnStream(params)); - profiler.FinishHloComputation(hlo_instruction_->while_condition()); + profiler.FinishHloComputation(condition_profile_index_); // Copy the result of condition computation and break the loop if 'false'. bool condition_result; @@ -86,7 +89,7 @@ Status WhileThunk::ExecuteOnStream(const ExecuteParams& params) { // Invoke thunk sequence for while 'body' computation, and pass on // 'profiler' to measure the timing of the thunks in 'body_thunk_sequence_'. TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(params)); - profiler.FinishHloComputation(hlo_instruction_->while_body()); + profiler.FinishHloComputation(body_profile_index_); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.h b/tensorflow/compiler/xla/service/gpu/while_thunk.h index 707bac15bb2..707edbdc192 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.h @@ -42,7 +42,9 @@ class WhileThunk : public Thunk { WhileThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& condition_result_buffer_index, std::unique_ptr condition_thunk_sequence, - std::unique_ptr body_thunk_sequence); + std::unique_ptr body_thunk_sequence, + absl::optional condition_profile_index, + absl::optional body_profile_index); WhileThunk(const WhileThunk&) = delete; WhileThunk& operator=(const WhileThunk&) = delete; @@ -51,10 +53,11 @@ class WhileThunk : public Thunk { Status ExecuteOnStream(const ExecuteParams& params) override; private: - const HloInstruction* hlo_instruction_; const BufferAllocation::Slice condition_result_buffer_index_; std::unique_ptr condition_thunk_sequence_; std::unique_ptr body_thunk_sequence_; + const absl::optional condition_profile_index_; + const absl::optional body_profile_index_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index c3a7b3a5c14..ac94b2e1d24 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -35,7 +35,7 @@ import "tensorflow/compiler/xla/xla_data.proto"; option cc_enable_arenas = true; // Serialization of HloInstruction. -// Next ID: 74 +// Next ID: 75 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -232,6 +232,11 @@ message HloInstructionProto { // kCustomCall. bool custom_call_has_side_effect = 65; + // A list of CustomCallOutputOperandAliasing pairs that specifies aliasing + // buffers between output and operands for kCustomCall. 
+ repeated xla.CustomCallOutputOperandAliasing + custom_call_output_operand_aliasing = 74; + // The delta value for kRngGetAndUpdateState. int64 delta = 66; diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index 384ae272dc1..cf09ddeec27 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -308,6 +308,39 @@ class BufferValueMap { } } + void ComputeInPlaceOperationAliasedBuffers( + const HloValue& value, std::vector* aliased_buffers) { + VLOG(3) << "Compute aliases for in-place operations (e.g. " + "kDynamicUpdateSlice and kScatter)"; + for (const HloPosition& position : value.positions()) { + HloInstruction* instruction = position.instruction; + for (const auto& operand_and_output_index : + HloDataflowAnalysis::GetInPlaceInputOutputPairs(instruction)) { + if (position.index == operand_and_output_index.second) { + const HloUse& operand = operand_and_output_index.first; + const HloValue& operand_value = dataflow_.GetUniqueValueAt( + instruction->operand(operand.operand_number), + operand.operand_index); + VLOG(3) << " operand value " << operand_value.ToShortString() + << " aliases."; + aliased_buffers->push_back(GetBufferForValue(operand_value)); + } + } + } + + for (const HloUse& use : value.uses()) { + for (const auto& operand_and_output_index : + HloDataflowAnalysis::GetInPlaceInputOutputPairs(use.instruction)) { + if (use == operand_and_output_index.first) { + const HloValue& use_value = dataflow_.GetUniqueValueAt( + use.instruction, operand_and_output_index.second); + VLOG(3) << " use value " << use_value.ToShortString() << " aliases."; + aliased_buffers->push_back(GetBufferForValue(use_value)); + } + } + } + } + // Compute and return a vector of buffers that the given value must be // contained in due to HLO aliasing rules. std::vector ComputeAliasedBuffers(const HloValue& value) { @@ -318,6 +351,7 @@ class BufferValueMap { ComputeInputOutputAliasedBuffers(value, &aliased_buffers); ComputeWhileAliasedBuffers(value, &aliased_buffers); ComputeConditionalAliasedBuffers(value, &aliased_buffers); + ComputeInPlaceOperationAliasedBuffers(value, &aliased_buffers); // Uniquify aliased buffers. 
absl::c_sort(aliased_buffers); aliased_buffers.erase( diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index 2666cb0872d..5e94f1d173e 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -1062,6 +1062,118 @@ TEST_F(HloAliasAnalysisTest, MergeBuffersReverse) { analysis.BufferLivesOut(analysis.buffers()[0]); } +TEST_F(HloAliasAnalysisTest, DynamicUpdateSlice) { + Shape shape = ShapeUtil::MakeShape(F32, {8}); + Shape update_shape = ShapeUtil::MakeShape(F32, {4}); + Shape index_shape = ShapeUtil::MakeShape(S32, {}); + auto builder = HloComputation::Builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, update_shape, "param1")); + auto param2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, index_shape, "param2")); + auto copy0 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCopy, param0)); + auto dynamic_update_slice = builder.AddInstruction( + HloInstruction::CreateDynamicUpdateSlice(shape, copy0, param1, {param2})); + + module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); + + HloAliasAnalysis& analysis = RunAnalysis(); + + EXPECT_EQ(analysis.GetUniqueBufferAt(copy0), + analysis.GetUniqueBufferAt(dynamic_update_slice)); +} + +TEST_F(HloAliasAnalysisTest, DynamicUpdateSliceMultiOutputFusion) { + absl::string_view hlo_string = R"( +HloModule Module + +fused_computation { + param0 = f32[1280,1,128] parameter(0) + param1 = f32[1280,1,128] parameter(1) + param2 = f32[1280,1,128] parameter(2) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + add.1 = f32[1280,1,128] add(param0, param0) + dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param1, broadcast.6, constant.3, constant.3, constant.3) + dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(param2, broadcast.6, constant.3, constant.3, constant.3) + ROOT tuple.1 = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add.1, dynamic-update-slice.5, dynamic-update-slice.6) +} + +ENTRY main { + param = f32[1280,1,128] parameter(0) + negate0 = f32[1280,1,128] negate(param) + negate1 = f32[1280,1,128] negate(param) + negate2 = f32[1280,1,128] negate(param) + ROOT fusion = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) fusion(negate0, negate1, negate2), kind=kLoop, calls=fused_computation +} +)"; + TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_string)); + + SCOPED_TRACE(module_->ToString()); + + HloAliasAnalysis& analysis = RunAnalysis(); + LOG(INFO) << analysis.ToString(); + + // Expect negate1 and negate2 to alias with fusion{1} and fusion{2} + // respectively (due to DUS), but not negate0 and fusion{0}. 
+ const HloInstruction* fusion = + module_->entry_computation()->GetInstructionWithName("fusion"); + const HloInstruction* negate0 = + module_->entry_computation()->GetInstructionWithName("negate0"); + const HloInstruction* negate1 = + module_->entry_computation()->GetInstructionWithName("negate1"); + const HloInstruction* negate2 = + module_->entry_computation()->GetInstructionWithName("negate2"); + EXPECT_EQ(analysis.GetUniqueBufferAt(negate1), + analysis.GetUniqueBufferAt(fusion, {1})); + EXPECT_EQ(analysis.GetUniqueBufferAt(negate2), + analysis.GetUniqueBufferAt(fusion, {2})); + EXPECT_NE(analysis.GetUniqueBufferAt(negate0), + analysis.GetUniqueBufferAt(fusion, {0})); +} + +TEST_F(HloAliasAnalysisTest, ChainedDynamicUpdateSliceFusion) { + // CPU and GPU backends may generate fusions with dynamic update slices + // feeding each other. They expect the fusion to not be in-place if that is + // the case. + absl::string_view hlo_string = R"( +HloModule Module + +fused_computation { + param0 = f32[1280,1,128] parameter(0) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3) + ROOT dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3) +} + +ENTRY main { + param = f32[1280,1,128] parameter(0) + negate0 = f32[1280,1,128] negate(param) + ROOT fusion = f32[1280,1,128] fusion(negate0), kind=kLoop, calls=fused_computation +} +)"; + TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_string)); + + SCOPED_TRACE(module_->ToString()); + + HloAliasAnalysis& analysis = RunAnalysis(); + LOG(INFO) << analysis.ToString(); + + const HloInstruction* fusion = + module_->entry_computation()->GetInstructionWithName("fusion"); + const HloInstruction* negate0 = + module_->entry_computation()->GetInstructionWithName("negate0"); + EXPECT_NE(analysis.GetUniqueBufferAt(negate0), + analysis.GetUniqueBufferAt(fusion)); +} + TEST_F(HloAliasAnalysisTest, BitcastInterference) { // A bitcast value simultaneously live with its operand should not cause // interference. diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 14daf680ac9..6323d0903a4 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -93,10 +93,13 @@ HloComputation::HloComputation( } HloInstruction* HloComputation::AddInstruction( - std::unique_ptr instruction) { + std::unique_ptr instruction, const std::string& new_name) { CHECK(instruction->opcode() != HloOpcode::kParameter) << "Parameter instructions cannot be added to a computation after " << "it has been built"; + if (!new_name.empty()) { + instruction->SetAndSanitizeName(new_name); + } return AddInstructionInternal(std::move(instruction)); } @@ -315,6 +318,8 @@ Status HloComputation::RemoveInstructionImpl(HloInstruction* instruction, (*inst_it->second)->set_parent(nullptr); to_be_deleted_.emplace_back(inst_it->second->release()); to_be_deleted_.back()->DetachFromOperandsAndUsers(); + // Clear all operands to avoid Null operands. 
+ to_be_deleted_.back()->RemoveAllOperands(); to_be_deleted_.back()->MarkAsDead(); instructions_.erase(inst_it->second); instruction_iterators_.erase(inst_it); @@ -380,6 +385,9 @@ void HloComputation::ComputeInstructionPostOrder( dfs_stack.push_back(root); while (!dfs_stack.empty()) { const auto current = dfs_stack.back(); + CHECK_EQ(current->parent(), this) + << "Instruction " << current->name() + << " is not in the current computation (" << name() << ")."; auto it = visited->find(current); if (it != visited->end()) { if (it->second == kVisited) { @@ -836,8 +844,9 @@ ProgramShape HloComputation::ComputeProgramShape(bool include_ids) const { return program_shape; } -bool HloComputation::Equal(const HloComputation& other, - bool is_layout_sensitive) const { +bool HloComputation::EqualInternal(const HloComputation& other, + bool is_layout_sensitive, + bool ignore_channel_id_values) const { if (this == &other) { return true; } @@ -855,15 +864,21 @@ bool HloComputation::Equal(const HloComputation& other, continue; } visited.emplace(pair); - // TODO(b/123082518): Avoid recursively invoking == because it may + // TODO(b/123082518): Avoid recursively invoking Equal because it may // cause a stack overflow with deeply nested subcomputations. - bool identical_ignoring_operands = pair.first->Identical( - *pair.second, - [](const HloInstruction*, const HloInstruction*) { return true; }, - [](const HloComputation* a, const HloComputation* b) { - return *a == *b; - }, - is_layout_sensitive); + auto operands_eq = [](const HloInstruction*, const HloInstruction*) { + return true; + }; + auto comp_eq = [&](const HloComputation* a, const HloComputation* b) { + return a->EqualInternal(*b, is_layout_sensitive, + ignore_channel_id_values); + }; + bool identical_ignoring_operands = + ignore_channel_id_values + ? pair.first->IdenticalIgnoringChannelIdValues( + *pair.second, operands_eq, comp_eq, is_layout_sensitive) + : pair.first->Identical(*pair.second, operands_eq, comp_eq, + is_layout_sensitive); if (!identical_ignoring_operands) { return false; } diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index d640007886c..d618a527070 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -122,7 +122,8 @@ class HloComputation { // Add an instruction to the computation. The computation takes ownership of // the instruction. - HloInstruction* AddInstruction(std::unique_ptr instruction); + HloInstruction* AddInstruction(std::unique_ptr instruction, + const std::string& new_name = ""); // Remove the param_no'th parameter from the computation. // Note this is only applicatable to the computation for the fusion @@ -310,7 +311,19 @@ class HloComputation { ProgramShape ComputeProgramShape(bool include_ids = true) const; // Return whether `*this` and `other` are functionally equivalent. - bool Equal(const HloComputation& other, bool is_layout_sensitive) const; + bool Equal(const HloComputation& other, bool is_layout_sensitive) const { + return EqualInternal(other, is_layout_sensitive, + /*ignore_channel_id_values=*/false); + } + + // Same as Equal() but ignores channel ID value mismatches on instructions, as + // long as the two instructions both have channel IDs or neither has a channel + // ID. 
+ bool EqualIgnoringChannelIdValues(const HloComputation& other, + bool is_layout_sensitive) const { + return EqualInternal(other, is_layout_sensitive, + /*ignore_channel_id_values=*/true); + } // Return whether `*this` and `other` are functionally equivalent. bool operator==(const HloComputation& other) const { @@ -489,6 +502,10 @@ class HloComputation { HloInstruction* AddInstructionInternal( std::unique_ptr instruction); + // Internal helper for comparison with different options. + bool EqualInternal(const HloComputation& other, bool is_layout_sensitive, + bool ignore_channel_id_values) const; + // Fuses HLOs in instructions_to_fuse into fusion_instruction. // // Pre-condition: fusion_instruction's opcode is kFusion. diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index 4ba67888409..4aeeb6d27ac 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -92,16 +92,17 @@ StatusOr MakeSliceHlo(HloInstruction* operand, StatusOr MakeConvolveHlo( HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count, - const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, + int64 batch_group_count, const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers, const PrecisionConfig& precision_config) { HloComputation* computation = lhs->parent(); CHECK_EQ(computation, rhs->parent()); TF_ASSIGN_OR_RETURN(Shape convolve_shape, ShapeInference::InferConvolveShape( - lhs->shape(), rhs->shape(), feature_group_count, 1, - window, dimension_numbers)); + lhs->shape(), rhs->shape(), feature_group_count, + batch_group_count, window, dimension_numbers)); return computation->AddInstruction(HloInstruction::CreateConvolve( - convolve_shape, lhs, rhs, feature_group_count, 1, window, + convolve_shape, lhs, rhs, feature_group_count, batch_group_count, window, dimension_numbers, precision_config)); } diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index 2b17ae3d967..53eeeffb858 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -61,7 +61,8 @@ StatusOr MakeSliceHlo(HloInstruction* operand, // containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation). StatusOr MakeConvolveHlo( HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count, - const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, + int64 batch_group_count, const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers, const PrecisionConfig& precision_config); // Creates a transpose HLO instruction and adds it to the computation containing diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index a46d20d5808..bc1063f9d48 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/algorithm/container.h" @@ -32,6 +33,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -42,7 +44,45 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" namespace xla { +namespace { +// CalculatePostOrderSchedule traverses a module and assign a ordinal to each +// instruction based the postorder dependency. +int64 CalculatePostOrderScheduleHelper( + const HloComputation* comp, int64 start_ordinal, + absl::flat_hash_map* ordinal_map) { + int64 ordinal = start_ordinal; + for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) { + if (instruction->opcode() == HloOpcode::kCall || + instruction->opcode() == HloOpcode::kConditional) { + for (const HloComputation* called_computation : + instruction->called_computations()) { + ordinal = CalculatePostOrderScheduleHelper(called_computation, ordinal, + ordinal_map); + } + } + if (instruction->opcode() == HloOpcode::kWhile) { + ordinal = CalculatePostOrderScheduleHelper(instruction->while_condition(), + ordinal, ordinal_map); + ordinal = CalculatePostOrderScheduleHelper(instruction->while_body(), + ordinal, ordinal_map); + } + // It's possible that in some unit tests the computation graph is not + // flatten (meaning we could have multiple callers for one computation). In + // that case the oridinal_map will see the instruction multiple times. We + // consider that case to be ok as it only shows up in unit tests. + ordinal_map->insert({instruction, ordinal++}); + } + return ordinal; +} +absl::flat_hash_map CalculatePostOrderSchedule( + const HloModule& module) { + absl::flat_hash_map map; + CalculatePostOrderScheduleHelper(module.entry_computation(), 0, &map); + return map; +} + +} // namespace using absl::StrAppend; using absl::StrCat; @@ -392,6 +432,23 @@ bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) { return changed; } +bool HloDataflowAnalysis::UpdateCustomCallValueSet( + HloInstruction* custom_call) { + CHECK_EQ(custom_call->opcode(), HloOpcode::kCustomCall); + bool changed = false; + for (const auto& aliasing : Cast(custom_call) + ->output_to_operand_aliasing()) { + const HloValueSet& operand_value_set = GetValueSet( + custom_call->operand(aliasing.second.first), aliasing.second.second); + HloValueSet& value_set = GetValueSet(custom_call, aliasing.first); + if (value_set != operand_value_set) { + value_set = operand_value_set; + changed = true; + } + } + return changed; +} + bool HloDataflowAnalysis::UpdateCopyStartValueSet(HloInstruction* copy_start) { CHECK_EQ(copy_start->opcode(), HloOpcode::kCopyStart); bool changed = false; @@ -717,6 +774,8 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( return UpdateAddDependencyValueSet(instruction); case HloOpcode::kBitcast: return UpdateBitcastValueSet(instruction); + case HloOpcode::kCustomCall: + return UpdateCustomCallValueSet(instruction); case HloOpcode::kSetDimensionSize: return UpdateSetDimensionSizeValueSet(instruction); case HloOpcode::kDomain: @@ -757,27 +816,35 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( } void HloDataflowAnalysis::Propagate() { - std::queue worklist; + using Work = std::pair; + // Avoid duplicating work by preferring work items early in the post order + // schedule. 
Intuitively, we start from entry parameters and propagate buffers + // updates throughout the module only once. + std::priority_queue, std::greater> worklist; absl::flat_hash_set workset; - auto add_to_worklist = [&worklist, &workset](HloInstruction* instruction) { + auto priority_map = CalculatePostOrderSchedule(module_); + auto add_to_worklist = [&priority_map, &worklist, + &workset](HloInstruction* instruction) { if (workset.insert(instruction).second) { - worklist.push(instruction); + worklist.emplace(priority_map[instruction], instruction); } }; - for (HloComputation* computation : module_.computations()) { - for (HloInstruction* instruction : computation->instructions()) { + auto comps = module_.MakeComputationPostOrder(); + for (HloComputation* computation : comps) { + for (HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { add_to_worklist(instruction); } } VLOG(1) << "SSA_FORM_: " << ssa_form_; while (!worklist.empty()) { - HloInstruction* instruction = worklist.front(); + HloInstruction* instruction = worklist.top().second; auto add_to_worklist = [&](HloInstruction* todo) { if (workset.insert(todo).second) { VLOG(1) << " Adding todo : " << todo->name(); - worklist.push(todo); + worklist.emplace(priority_map[todo], todo); } }; worklist.pop(); @@ -970,6 +1037,22 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { define_value_at(/*index=*/{1}); define_value_at(/*index=*/{2}); break; + case HloOpcode::kCustomCall: { + absl::flat_hash_set aliasing_indices; + for (const auto& aliasing : + Cast(instruction) + ->output_to_operand_aliasing()) { + aliasing_indices.insert(aliasing.first); + } + ShapeUtil::ForEachSubshape( + instruction->shape(), + [&](const Shape& /*subshape*/, const ShapeIndex& index) { + if (!aliasing_indices.contains(index)) { + define_value_at(index); + } + }); + break; + } default: define_all_values(); break; @@ -1130,69 +1213,49 @@ bool HloDataflowAnalysis::DoesNotUseOperandBuffer( return true; } -// Given a fusion whose root is a dynamic-update-slice op, determines whether -// the fusion's output buffer can be shared with the buffer of fusion_param, -// which must be a fused parameter of the fusion. -// -// Preconditions: -// -// - fusion's root is a dynamic-update-slice op. -// - fusion_param is a parameter within the fusion. -// -// fusion_param may point to a subelement of the actual parameter instruction if -// the param is a tuple; i.e. fusion_param->index() need not be the empty list. -// -// Returns true if: -// -// * fusion_param is used by the root of dynamic-update-slice as the "base" of -// the update, i.e. the thing being updated, AND -// * all other uses of fusion_param are dynamic-slices that slice the same -// indices as are overwritten in the dynamic-update-slice. -// -// In the case that there are no other uses of fusion_param (last bullet point -// is vacuously true) it's easy to see why an in-place DUS is safe; this is just -// the "natural" implementation of DUS. If there are other users, in-place DUS -// is safe on the assumption that the thread which writes element i of the -// output will be the only one to read element i of fusion_param (via the -// dynamic-slice ops). 
-static bool CanDoInPlaceDynamicUpdateSlice(HloInstruction* fusion, - const HloValue& fusion_param_value) { - auto* root = - Cast(fusion->fused_expression_root()); - auto* fusion_param = fusion_param_value.instruction(); - CHECK_EQ(fusion_param->opcode(), HloOpcode::kParameter); - CHECK_EQ(fusion_param->parent(), fusion->fused_instructions_computation()); +/*static*/ bool HloDataflowAnalysis::IsInPlaceOperation(HloOpcode opcode) { + return opcode == HloOpcode::kDynamicUpdateSlice || + opcode == HloOpcode::kScatter; +} - // fusion_param must be used by the root as the "base" of the - // dynamic-update-slice. The natural way to check this would be - // - // `if (root->operand(0) != fusion_param)` - // - // but we also have to handle the case where the fusion parameter is - // tuple-shaped and we're considering just one element of that tuple, i.e. - // fusion_param.index() != {}. - if (absl::c_count_if(fusion_param_value.uses(), [&](const HloUse& use) { - return use.instruction == root; - }) != 1) { - return false; +/*static*/ std::vector> +HloDataflowAnalysis::GetInPlaceInputOutputPairs(HloInstruction* instruction) { + if (IsInPlaceOperation(instruction->opcode())) { + return {{HloUse{instruction, 0, {}}, {}}}; + } else if (instruction->opcode() != HloOpcode::kFusion) { + return {}; } - - // All other uses of fusion_param must be dynamic-slices that slice the same - // indices as are overwritten by the dynamic-update-slice. - for (const HloUse& use : fusion_param_value.uses()) { - auto* user = use.instruction; - if (user == root) { - continue; + std::vector> input_output_pairs; + for (auto& indexed_shape : ShapeUtil::GetLeafShapes(instruction->shape())) { + const HloInstruction* hlo_generating_output = + instruction->fused_expression_root(); + for (int64 i = 0; i < indexed_shape.index.size(); ++i) { + if (hlo_generating_output->opcode() == HloOpcode::kTuple) { + hlo_generating_output = + hlo_generating_output->operand(indexed_shape.index[i]); + } else { + CHECK_EQ(i, indexed_shape.index.size() - 1); + } } - // Check that `user` is a dynamic-slice op and has the same slice indices as - // `root`. - auto* ds = DynCast(user); - if (!ds || ds->index_operands() != root->index_operands()) { - return false; + if (IsInPlaceOperation(hlo_generating_output->opcode())) { + ShapeIndex operand_index; + const HloInstruction* fusion_parameter = + hlo_generating_output->operand(0); + while (fusion_parameter->opcode() == HloOpcode::kGetTupleElement) { + operand_index.push_front(fusion_parameter->tuple_index()); + fusion_parameter = fusion_parameter->operand(0); + } + + if (fusion_parameter->opcode() == HloOpcode::kParameter) { + input_output_pairs.emplace_back( + HloUse{instruction, fusion_parameter->parameter_number(), + operand_index}, + indexed_shape.index); + } } } - return true; + return input_output_pairs; } bool HloDataflowAnalysis::CanShareOperandBufferWithUser( @@ -1213,24 +1276,17 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( return false; } - if (user->opcode() == HloOpcode::kFusion) { - // Get the parameter associated with 'operand'; - HloInstruction* fusion_param = - user->fused_parameter(user->operand_index(operand)); - - const HloValue& fusion_param_value = - GetValueDefinedAt(fusion_param, operand_index); - - // TODO(b/80315712): This code is in a bit of a weird intermediate state - // at the moment. The in-place DUS check really needs to be common to all - // backends, so it runs first. 
Then we run the backend-specific check if - // provided, or go through the target-independent check if not. - // Unfortunately, the notionally "target-independent" path actually contains - // some target-specific code, so we can't run all of it *in addition* to the - // target-specific function, like the interface documentation says. - if (user->fused_expression_root()->opcode() == - HloOpcode::kDynamicUpdateSlice) { - return CanDoInPlaceDynamicUpdateSlice(user, fusion_param_value); + // Must-alias relationship returns true for in-place operations (DUS and DUS + // fusions), regardless of the backend. + for (const auto& operand_and_output_index : + GetInPlaceInputOutputPairs(user)) { + if (operand_and_output_index.second != user_index) { + continue; + } + for (const HloUse& use : GetUniqueValueAt(operand, operand_index).uses()) { + if (use == operand_and_output_index.first) { + return true; + } } } diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index bec592aeb20..c3aad04023f 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -49,6 +49,9 @@ class HloDataflowAnalysis { // Infrastructure for passing may-alias hints: HLO passes can populate the // may-alias table. If an empty optional is returned, default rules are used. // + // Must-alias rules (as defined by GetInPlaceInputOutputPairs) cannot be + // overriden using backend-specific overrides. + // // The first parameter of the function should be the instruction, the // second parameter should be an operand of the instruction. The third // parameter should be the output index of the instruction. @@ -160,6 +163,15 @@ class HloDataflowAnalysis { const HloModule& module() const { return module_; } + // Returns true if the operation is an in-place operation and its operand 0 + // must alias with the output. + static bool IsInPlaceOperation(HloOpcode opcode); + + // Returns a vector consisting of the HloUse (operand number and shape index) + // and output shape index of the in-place operations within this HLO. 
+ static std::vector> GetInPlaceInputOutputPairs( + HloInstruction* instruction); + protected: HloDataflowAnalysis(const HloModule& module, bool ssa_form, bool bitcast_defines_value = false, @@ -204,6 +216,7 @@ class HloDataflowAnalysis { bool UpdateCallValueSet(HloInstruction* call); bool UpdateConditionalValueSet(HloInstruction* conditional); bool UpdateCopyValueSet(HloInstruction* copy); + bool UpdateCustomCallValueSet(HloInstruction* custom_call); bool UpdateDomainValueSet(HloInstruction* domain); bool UpdateGetTupleElementValueSet(HloInstruction* gte); bool UpdateParameterValueSet(HloInstruction* parameter); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 551ffb52031..1fa6fe95c40 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -2324,36 +2324,6 @@ TEST_F(CanShareOperandBufferWithUserTest, dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, fusion, {})); } -TEST_F(CanShareOperandBufferWithUserTest, DUSWithSliceWithDifferentIndices) { - const char* kModule = R"( - HloModule test - - fused_computation { - p0 = f32[10,20,30] parameter(0) - p1 = s32[] parameter(1) - p2 = s32[] parameter(2) - p3 = s32[] parameter(3) - slice = f32[1,1,30] dynamic-slice(p0, p1, p2, p3), dynamic_slice_sizes={1,1,30} - ROOT dus = f32[10,20,30] dynamic-update-slice(p0, slice, p1, p3, p2) - } - - ENTRY test { - p0 = f32[10,20,30] parameter(0) - p1 = s32[] parameter(1) - p2 = s32[] parameter(2) - p3 = s32[] parameter(3) - ROOT fusion = f32[10,20,30] fusion(p0, p1, p2, p3), kind=kLoop, calls=fused_computation - } - )"; - TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(kModule)); - auto* fusion = module_->entry_computation()->root_instruction(); - auto* param = module_->entry_computation()->parameter_instruction(0); - - RunAnalysis(); - EXPECT_FALSE( - dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, fusion, {})); -} - TEST_F(CanShareOperandBufferWithUserTest, DUSWithSliceWithSameIndices) { const char* kModule = R"( HloModule test diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index d5f0c62adc1..4fb7edd0104 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -1157,7 +1157,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const int64 feature_group_index = out_index[output_z_dim] / output_feature_group_size; - const int64 batch_group_index = out_index[output_z_dim]; + const int64 depthwise_multiplier = + batch_group_count > 1 ? 
output_z_size / input_batch_size : 1; + const int64 batch_group_index = + out_index[output_z_dim] / depthwise_multiplier; ElementwiseT result_val = static_cast(0); DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(), @@ -1218,7 +1221,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { feature_group_index * input_feature_group_size + rhs_iz; int64 lhs_linear_index = lhs_linear_spatial_index; - lhs_linear_index += out_index[output_batch_dim] * lhs_dim_multipliers[input_batch_dim]; @@ -1233,7 +1235,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { lhs_dim_multipliers[input_batch_dim]; lhs_linear_index += iz * lhs_dim_multipliers[input_z_dim]; - int64 rhs_linear_index = rhs_linear_spatial_index; rhs_linear_index += out_index[output_z_dim] * @@ -2299,8 +2300,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::vector input_index(operand_shape.dimensions_size()); std::vector update_index(updates_shape.dimensions_size()); - std::vector input_scatter_index_clamped( - operand_shape.dimensions_size()); UpdateScatterIndexToInputIndex update_scatter_index_to_input_index( &scatter->scatter_dimension_numbers(), /*input_shape=*/operand_shape, @@ -2789,7 +2788,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // bound, call `f` with the base index. static void IterateThroughWindow( const Shape& window_shape, const Window& window, const Shape& base_shape, - const absl::Span& window_count_index, + const absl::Span window_count_index, const std::function&)>& f) { const int64 rank = base_shape.rank(); DimensionVector window_index(rank); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index bb01fdd0e15..41488dcdaaa 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -568,6 +568,19 @@ StatusOr> HloInstruction::CreateFromProto( std::max(static_cast(proto.batch_group_count()), int64{1})); custom_call_instr->set_custom_call_has_side_effect( proto.custom_call_has_side_effect()); + std::vector>> + output_to_operand_aliasing; + for (const auto& aliasing : proto.custom_call_output_operand_aliasing()) { + output_to_operand_aliasing.emplace_back( + ShapeIndex(aliasing.output_shape_index().begin(), + aliasing.output_shape_index().end()), + std::pair{ + aliasing.operand_index(), + ShapeIndex(aliasing.operand_shape_index().begin(), + aliasing.operand_shape_index().end())}); + } + custom_call_instr->set_output_to_operand_aliasing( + std::move(output_to_operand_aliasing)); break; } case HloOpcode::kPad: @@ -1942,6 +1955,56 @@ Status HloInstruction::CopyAllControlDepsFrom(const HloInstruction* inst) { return Status::OK(); } +bool HloInstruction::IdenticalInternal( + const HloInstruction& other, + const std::function& + eq_operands, + const std::function& + eq_computations, + bool layout_sensitive, bool ignore_channel_id_values) const { + // An instruction is always identical to itself. + if (this == &other) { + return true; + } + + // Identical instruction must have the same opcode, shape, and identical + // operands. + if (opcode() != other.opcode()) { + return false; + } + if (!(layout_sensitive ? ShapeUtil::Equal(shape(), other.shape()) + : ShapeUtil::Compatible(shape(), other.shape()))) { + return false; + } + if (operands().size() != other.operands().size()) { + return false; + } + + // Two AllReduces are Identical if they have the same channel_id. 
+ // Their operands don't have to be Identical. + if (!IsCrossModuleAllReduce()) { + // Use an explicit loop rather than ContainerEquals, because copying + // around std::functions may be too expensive in some cases. + for (size_t i = 0; i < operands().size(); ++i) { + if (!eq_operands(operand(i), other.operand(i))) { + return false; + } + } + } + + if (backend_config_ != other.backend_config_) { + return false; + } + + if (ignore_channel_id_values) { + if (auto channel_inst = DynCast(this)) { + return channel_inst->IdenticalSlowPathIgnoringChannelIdValues( + other, eq_computations); + } + } + return IdenticalSlowPath(other, eq_computations); +} + void HloInstruction::AppendOperand(HloInstruction* operand) { if (operand->parent() != nullptr) { DCHECK(!operand->parent()->IsMarkedAsDead(operand)) @@ -3370,6 +3433,11 @@ class HloInstruction::FusionReusesParamElements { // that. value_it = cache->find(&hlo); value_it->second = new_val; + // Fold() minimizes the UseKind value. If it is already minimum, we can + // break the loop early. + if (new_val == UseKind::kReuse) { + break; + } } } return value_it->second; @@ -3991,6 +4059,10 @@ const Shape& HloInstruction::outfeed_shape() const { return Cast(this)->outfeed_shape(); } +Shape* HloInstruction::mutable_outfeed_shape() { + return Cast(this)->mutable_outfeed_shape(); +} + const string& HloInstruction::outfeed_config() const { return Cast(this)->outfeed_config(); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 7db128b4d34..9675a2f0f0d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -1122,41 +1122,23 @@ class HloInstruction { const std::function& eq_computations = std::equal_to(), bool layout_sensitive = true) const { - // An instruction is always identical to itself. - if (this == &other) { - return true; - } + return IdenticalInternal(other, eq_operands, eq_computations, + layout_sensitive, + /*ignore_channel_id_values=*/false); + } - // Identical instruction must have the same opcode, shape, and identical - // operands. - if (opcode() != other.opcode()) { - return false; - } - if (!(layout_sensitive ? ShapeUtil::Equal(shape(), other.shape()) - : ShapeUtil::Compatible(shape(), other.shape()))) { - return false; - } - if (operands().size() != other.operands().size()) { - return false; - } - - // Two AllReduces are Identical if they have the same channel_id. - // Their operands don't have to be Identical. - if (!IsCrossModuleAllReduce()) { - // Use an explicit loop rather than ContainerEquals, because copying - // around std::functions may be too expensive in some cases. - for (size_t i = 0; i < operands().size(); ++i) { - if (!eq_operands(operand(i), other.operand(i))) { - return false; - } - } - } - - if (backend_config_ != other.backend_config_) { - return false; - } - - return IdenticalSlowPath(other, eq_computations); + // Same as Identical() but ignores channel ID value mismatches, as long as + // both have channel IDs or neither has a channel ID. + bool IdenticalIgnoringChannelIdValues( + const HloInstruction& other, + const std::function& + eq_operands = std::equal_to(), + const std::function& + eq_computations = std::equal_to(), + bool layout_sensitive = true) const { + return IdenticalInternal(other, eq_operands, eq_computations, + layout_sensitive, + /*ignore_channel_id_values=*/true); } // Generates a hash value of an HLO instruction. 
Hash considers @@ -1787,6 +1769,9 @@ class HloInstruction { // Returns the shape for the Outfeed instruction. const Shape& outfeed_shape() const; + // Returns the mutable shape for the Outfeed instruction. + Shape* mutable_outfeed_shape(); + // Delegates to HloCollectiveInstruction::replica_groups. const std::vector& replica_groups() const; @@ -1926,6 +1911,8 @@ class HloInstruction { // by factory methods. HloInstruction(HloOpcode opcode, const Shape& shape); + void RemoveAllOperands() { operands_.clear(); } + void RemoveOperandAt(int index) { operands_.erase(operands_.begin() + index); } @@ -1962,6 +1949,14 @@ class HloInstruction { private: friend class HloComputation; + bool IdenticalInternal( + const HloInstruction& other, + const std::function& + eq_operands, + const std::function& + eq_computations, + bool layout_sensitive, bool ignore_channel_id_values) const; + // Implementation for non-common logic of CloneWithNewOperands. virtual std::unique_ptr CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index df225e27aad..45b2d885d8e 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -447,7 +447,10 @@ std::vector HloChannelInstruction::ExtraAttributesToStringImpl( bool HloChannelInstruction::IdenticalSlowPath( const HloInstruction& other, const std::function& - /*eq_computations*/) const { + eq_computations) const { + if (!IdenticalSlowPathIgnoringChannelIdValues(other, eq_computations)) { + return false; + } const auto& casted_other = static_cast(other); return channel_id() == casted_other.channel_id(); } @@ -475,7 +478,7 @@ std::vector HloSendRecvInstruction::ExtraAttributesToStringImpl( return attrs; } -bool HloSendRecvInstruction::IdenticalSlowPath( +bool HloSendRecvInstruction::IdenticalSlowPathIgnoringChannelIdValues( const HloInstruction& other, const std::function& eq_computations) const { @@ -596,13 +599,14 @@ std::vector HloCollectiveInstruction::ExtraAttributesToStringImpl( return result; } -bool HloCollectiveInstruction::IdenticalSlowPath( +bool HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues( const HloInstruction& other, const std::function& eq_computations) const { const auto& casted_other = static_cast(other); - return HloChannelInstruction::IdenticalSlowPath(other, eq_computations) && + return HloChannelInstruction::IdenticalSlowPathIgnoringChannelIdValues( + other, eq_computations) && constrain_layout() == casted_other.constrain_layout() && absl::c_equal(replica_groups(), casted_other.replica_groups(), [](const ReplicaGroup& a, const ReplicaGroup& b) { @@ -645,12 +649,13 @@ HloInstructionProto HloAllGatherInstruction::ToProto() const { return proto; } -bool HloAllGatherInstruction::IdenticalSlowPath( +bool HloAllGatherInstruction::IdenticalSlowPathIgnoringChannelIdValues( const HloInstruction& other, const std::function& eq_computations) const { const auto& casted_other = static_cast(other); - return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) && + return HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues( + other, eq_computations) && all_gather_dimension_ == casted_other.all_gather_dimension() && use_global_device_ids() == casted_other.use_global_device_ids(); } @@ -691,12 +696,13 @@ std::vector HloAllReduceInstruction::ExtraAttributesToStringImpl( return result; } -bool 
HloAllReduceInstruction::IdenticalSlowPath( +bool HloAllReduceInstruction::IdenticalSlowPathIgnoringChannelIdValues( const HloInstruction& other, const std::function& eq_computations) const { const auto& casted_other = static_cast(other); - return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) && + return HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues( + other, eq_computations) && constrain_layout() == casted_other.constrain_layout() && use_global_device_ids() == casted_other.use_global_device_ids() && eq_computations(to_apply(), casted_other.to_apply()); @@ -747,12 +753,13 @@ std::vector HloAllToAllInstruction::ExtraAttributesToStringImpl( return result; } -bool HloAllToAllInstruction::IdenticalSlowPath( +bool HloAllToAllInstruction::IdenticalSlowPathIgnoringChannelIdValues( const HloInstruction& other, const std::function& eq_computations) const { const auto& casted_other = static_cast(other); - return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) && + return HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues( + other, eq_computations) && split_dimension_ == casted_other.split_dimension(); } @@ -788,7 +795,7 @@ HloCollectivePermuteInstruction::ExtraAttributesToStringImpl( return result; } -bool HloCollectivePermuteInstruction::IdenticalSlowPath( +bool HloCollectivePermuteInstruction::IdenticalSlowPathIgnoringChannelIdValues( const HloInstruction& other, const std::function& eq_computations) const { @@ -797,7 +804,8 @@ bool HloCollectivePermuteInstruction::IdenticalSlowPath( } const auto& casted_other = static_cast(other); - return HloChannelInstruction::IdenticalSlowPath(other, eq_computations) && + return HloChannelInstruction::IdenticalSlowPathIgnoringChannelIdValues( + other, eq_computations) && absl::c_equal(source_target_pairs(), casted_other.source_target_pairs(), [](const std::pair& a, @@ -2387,6 +2395,16 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const { } } proto.set_custom_call_has_side_effect(custom_call_has_side_effect_); + for (const auto& pair : output_to_operand_aliasing_) { + auto aliasing = proto.add_custom_call_output_operand_aliasing(); + aliasing->set_operand_index(pair.second.first); + for (int64 index : pair.first) { + aliasing->add_output_shape_index(index); + } + for (int64 index : pair.second.second) { + aliasing->add_operand_shape_index(index); + } + } return proto; } @@ -2424,6 +2442,16 @@ std::vector HloCustomCallInstruction::ExtraAttributesToStringImpl( if (custom_call_has_side_effect_) { extra.push_back("custom_call_has_side_effect=true"); } + if (!output_to_operand_aliasing_.empty()) { + std::vector pair_strings; + for (const auto& pair : output_to_operand_aliasing_) { + pair_strings.push_back(StrCat(pair.first.ToString(), ": (", + pair.second.first, ", ", + pair.second.second.ToString(), ")")); + } + extra.push_back(StrCat("output_to_operand_aliasing={", + StrJoin(pair_strings, ", "), "}")); + } return extra; } @@ -2467,6 +2495,10 @@ bool HloCustomCallInstruction::IdenticalSlowPath( casted_other.custom_call_has_side_effect()) { return false; } + if (output_to_operand_aliasing_ != + casted_other.output_to_operand_aliasing()) { + return false; + } // Note: backend_config comparison is done in Identical, which is the // intended/exposed way to compare computations, and so not repeated here. 
return custom_call_target_ == casted_other.custom_call_target_; @@ -2491,6 +2523,7 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl( cloned->set_feature_group_count(feature_group_count_); cloned->set_batch_group_count(batch_group_count_); cloned->set_custom_call_has_side_effect(custom_call_has_side_effect_); + cloned->set_output_to_operand_aliasing(output_to_operand_aliasing_); return std::move(cloned); } diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 17368e8b714..88e874347bd 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -244,6 +244,15 @@ class HloChannelInstruction : public HloInstruction { absl::optional channel_id() const { return channel_id_; } void set_channel_id(const absl::optional& channel_id); + // Whether this instruction is identical to `other` except for the values of + // channel IDs, as long as both have channel IDs or neither has a channel ID. + virtual bool IdenticalSlowPathIgnoringChannelIdValues( + const HloInstruction& other, + const std::function& + eq_computations) const { + return channel_id_.has_value() == other.channel_id().has_value(); + } + protected: explicit HloChannelInstruction(HloOpcode opcode, const Shape& shape, const absl::optional& channel_id); @@ -252,10 +261,13 @@ class HloChannelInstruction : public HloInstruction { std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; + + // Do not override IdenticalSlowPath(). Override + // IdenticalSlowPathIgnoringChannelIdValues() instead. bool IdenticalSlowPath( const HloInstruction& other, const std::function& - eq_computations) const override; + eq_computations) const final; absl::optional channel_id_; }; @@ -275,7 +287,7 @@ class HloSendRecvInstruction : public HloChannelInstruction { private: std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; - bool IdenticalSlowPath( + bool IdenticalSlowPathIgnoringChannelIdValues( const HloInstruction& other, const std::function& eq_computations) const override; @@ -363,7 +375,7 @@ class HloCollectiveInstruction : public HloChannelInstruction { std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; - bool IdenticalSlowPath( + bool IdenticalSlowPathIgnoringChannelIdValues( const HloInstruction& other, const std::function& eq_computations) const override; @@ -390,7 +402,7 @@ class HloAllGatherInstruction : public HloCollectiveInstruction { HloInstructionProto ToProto() const override; private: - bool IdenticalSlowPath( + bool IdenticalSlowPathIgnoringChannelIdValues( const HloInstruction& other, const std::function& eq_computations) const override; @@ -434,7 +446,7 @@ class HloAllReduceInstruction : public HloCollectiveInstruction { HloInstructionProto ToProto() const override; private: - bool IdenticalSlowPath( + bool IdenticalSlowPathIgnoringChannelIdValues( const HloInstruction& other, const std::function& eq_computations) const override; @@ -471,7 +483,7 @@ class HloAllToAllInstruction : public HloCollectiveInstruction { HloInstructionProto ToProto() const override; private: - bool IdenticalSlowPath( + bool IdenticalSlowPathIgnoringChannelIdValues( const HloInstruction& other, const std::function& eq_computations) const override; @@ -501,7 +513,7 @@ class HloCollectivePermuteInstruction : public HloChannelInstruction { private: std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const 
override; - bool IdenticalSlowPath( + bool IdenticalSlowPathIgnoringChannelIdValues( const HloInstruction& other, const std::function& eq_computations) const override; @@ -1182,6 +1194,8 @@ class HloOutfeedInstruction : public HloInstruction { absl::string_view outfeed_config); // Returns the shape for the Outfeed instruction. const Shape& outfeed_shape() const { return outfeed_shape_; } + // Returns the mutable shape for the Outfeed instruction. + Shape* mutable_outfeed_shape() { return &outfeed_shape_; } // Returns the config for the Outfeed instruction. const string& outfeed_config() const { return outfeed_config_; } void set_outfeed_config(const string& config) { outfeed_config_ = config; } @@ -1416,6 +1430,20 @@ class HloCustomCallInstruction : public HloInstruction { CHECK(layout_constrained()); return operand_shapes_with_layout_; } + // Gets a list of output/operand buffer pairs that alias each other, where the + // output buffer is represented as a ShapeIndex, and the operand buffer is + // represented as the operand index and the ShapeIndex. By default this list + // is empty. + const std::vector>>& + output_to_operand_aliasing() const { + return output_to_operand_aliasing_; + } + // Sets the list of output/operand buffer pairs that alias each other. + void set_output_to_operand_aliasing( + std::vector>> + aliasing) { + output_to_operand_aliasing_ = std::move(aliasing); + } private: std::vector ExtraAttributesToStringImpl( @@ -1444,6 +1472,10 @@ class HloCustomCallInstruction : public HloInstruction { std::vector operand_shapes_with_layout_; // Whether this custom call has a side-effect. bool custom_call_has_side_effect_; + // A list of output/operand buffer pairs that alias each other. See comment of + // output_to_operand_aliasing(). + std::vector>> + output_to_operand_aliasing_; }; class HloPadInstruction : public HloInstruction { diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc index 749193a83ef..3c44b390969 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -387,6 +387,12 @@ TokKind HloLexer::LexNumberOrPattern() { return TokKind::kNegInf; } + static LazyRE2 neg_nan = {"-nan"}; + if (RE2::Consume(&consumable, *neg_nan)) { + current_ptr_ = consumable.begin(); + return TokKind::kNegNan; + } + return TokKind::kError; } @@ -502,6 +508,8 @@ string TokKindToString(TokKind kind) { return "kw_nan"; case TokKind::kw_inf: return "kw_inf"; + case TokKind::kNegNan: + return "kNegNan"; case TokKind::kNegInf: return "kNegInf"; case TokKind::kPrimitiveType: diff --git a/tensorflow/compiler/xla/service/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h index b8c7debaab4..4068ad76581 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.h +++ b/tensorflow/compiler/xla/service/hlo_lexer.h @@ -65,6 +65,7 @@ enum class TokKind { kw_nan, kw_inf, + kNegNan, // -nan kNegInf, // -inf // Typed tokens. 
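
// NOTE: the lexer change above adds a kNegNan token so HLO text such as
// constant({-nan, -nan}) can round-trip through the parser. A simplified,
// self-contained sketch of the idea, using plain prefix matching instead of
// the RE2-based lexer and a hypothetical ParseSpecialFloat helper:
#include <cmath>
#include <limits>
#include <string>

bool ParseSpecialFloat(std::string token, double* result) {
  bool negative = false;
  if (!token.empty() && token.front() == '-') {
    negative = true;
    token.erase(0, 1);
  }
  if (token == "inf") {
    *result = std::numeric_limits<double>::infinity();
  } else if (token == "nan") {
    *result = std::numeric_limits<double>::quiet_NaN();
  } else {
    return false;  // not a special float token
  }
  // copysign sets the sign bit even for NaN, matching how "-nan" should print
  // back out when the literal is re-serialized.
  if (negative) *result = std::copysign(*result, -1.0);
  return true;
}
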
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc index eaed707607d..8158d198799 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.cc +++ b/tensorflow/compiler/xla/service/hlo_module_config.cc @@ -51,12 +51,14 @@ string HloModuleConfig::compilation_cache_key() const { string key = absl::StrCat("profiling=", hlo_profiling_enabled()); StrAppend(&key, "::("); std::vector params; - for (const ShapeLayout& param_layout : - entry_computation_layout_->parameter_layouts()) { - params.push_back(param_layout.shape().DebugString()); + if (entry_computation_layout_.has_value()) { + for (const ShapeLayout& param_layout : + entry_computation_layout_->parameter_layouts()) { + params.push_back(param_layout.shape().DebugString()); + } + StrAppend(&key, absl::StrJoin(params, ", "), ") => ", + entry_computation_layout_->result_shape().SerializeAsString()); } - StrAppend(&key, absl::StrJoin(params, ", "), ") => ", - entry_computation_layout_->result_shape().SerializeAsString()); if (seed() != 0) { // TODO(b/32083678): force recompilation to reset global state. static std::atomic counter{0}; diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index e2bbda3a607..d04a7695f3c 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -212,6 +212,7 @@ class HloParserImpl : public HloParser { kEnum, kRandomAlgorithm, kAliasing, + kInstructionAliasing, }; struct AttrConfig { @@ -346,6 +347,12 @@ class HloParserImpl : public HloParser { // fails. bool ParseAliasing(AliasingData* data); + // Parses the per-instruction aliasing information from string `s`, returns + // `false` if it fails. + bool ParseInstructionOutputOperandAliasing( + std::vector>>* + aliasing_output_operand_pairs); + bool ParseShapeIndex(ShapeIndex* out); // Returns true if the current token is the beginning of a shape. 
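
// NOTE: the next hunk parses the new custom-call attribute, e.g.
//   output_to_operand_aliasing={{0}: (1, {}), {1}: (0, {0})}
// into (output ShapeIndex) -> (operand number, operand ShapeIndex) pairs.
// A self-contained sketch of the resulting data, with std::vector<int64_t>
// standing in for ShapeIndex, might look like:
#include <cstdint>
#include <utility>
#include <vector>

using ShapeIndexLike = std::vector<int64_t>;
using OutputOperandAliasing =
    std::vector<std::pair<ShapeIndexLike, std::pair<int64_t, ShapeIndexLike>>>;

OutputOperandAliasing ExampleAliasing() {
  // Output tuple element {0} shares its buffer with operand 1 at index {},
  // and output element {1} shares its buffer with operand 0 at index {0}.
  return {
      {{0}, {1, {}}},
      {{1}, {0, {0}}},
  };
}
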
@@ -598,6 +605,58 @@ bool HloParserImpl::ParseAliasing(AliasingData* data) { return true; } +bool HloParserImpl::ParseInstructionOutputOperandAliasing( + std::vector>>* + aliasing_output_operand_pairs) { + if (!ParseToken( + TokKind::kLbrace, + "Expects '{' at the start of instruction aliasing description")) { + return false; + } + + while (lexer_.GetKind() != TokKind::kRbrace) { + ShapeIndex out; + if (!ParseShapeIndex(&out)) { + return false; + } + std::string errmsg = + "Expected format: : (, " + ")"; + if (!ParseToken(TokKind::kColon, errmsg)) { + return false; + } + + if (!ParseToken(TokKind::kLparen, errmsg)) { + return false; + } + int64 operand_index; + ParseInt64(&operand_index); + if (!ParseToken(TokKind::kComma, errmsg)) { + return false; + } + ShapeIndex operand_shape_index; + if (!ParseShapeIndex(&operand_shape_index)) { + return false; + } + + aliasing_output_operand_pairs->emplace_back( + out, std::pair{operand_index, operand_shape_index}); + if (!ParseToken(TokKind::kRparen, errmsg)) { + return false; + } + + if (!EatIfPresent(TokKind::kComma)) { + break; + } + } + if (!ParseToken( + TokKind::kRbrace, + "Expects '}' at the end of instruction aliasing description")) { + return false; + } + return true; +} + // ::= 'HloModule' name computations bool HloParserImpl::ParseHloModule(HloModule* module) { if (lexer_.GetKind() != TokKind::kw_HloModule) { @@ -1777,6 +1836,8 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, optional> operand_layout_constraints; optional custom_call_has_side_effect; optional to_apply; + optional>>> + output_to_operand_aliasing; attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString, &custom_call_target}; attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; @@ -1792,6 +1853,9 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, &custom_call_has_side_effect}; attrs["to_apply"] = {/*required=*/false, AttrTy::kHloComputation, &to_apply}; + attrs["output_to_operand_aliasing"] = {/*required=*/false, + AttrTy::kInstructionAliasing, + &output_to_operand_aliasing}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } @@ -1861,6 +1925,10 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, custom_call_instr->set_custom_call_has_side_effect( *custom_call_has_side_effect); } + if (output_to_operand_aliasing.has_value()) { + custom_call_instr->set_output_to_operand_aliasing( + std::move(*output_to_operand_aliasing)); + } break; } case HloOpcode::kDot: { @@ -2649,6 +2717,7 @@ bool HloParserImpl::ParseDenseLiteral(Literal* literal, const Shape& shape) { case TokKind::kInt: case TokKind::kDecimal: case TokKind::kw_nan: + case TokKind::kNegNan: case TokKind::kw_inf: case TokKind::kNegInf: { add_one_elem_seen(); @@ -3223,6 +3292,19 @@ bool HloParserImpl::ParseAttributeHelper( ->emplace(aliasing_data); return true; } + case AttrTy::kInstructionAliasing: { + std::vector>> + aliasing_output_operand_pairs; + if (!ParseInstructionOutputOperandAliasing( + &aliasing_output_operand_pairs)) { + return false; + } + static_cast>>>*>( + attr_out_ptr) + ->emplace(std::move(aliasing_output_operand_pairs)); + return true; + } } }(); if (!success) { @@ -4293,6 +4375,9 @@ bool HloParserImpl::ParseDouble(double* result) { case TokKind::kw_nan: *result = std::numeric_limits::quiet_NaN(); break; + case TokKind::kNegNan: + *result = -std::numeric_limits::quiet_NaN(); + break; case TokKind::kw_inf: *result = std::numeric_limits::infinity(); break; diff --git 
a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 620e67c3a2f..3cb9a1c564b 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -991,6 +991,19 @@ ENTRY %CustomCallWithHasSideEffect (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4]) ROOT %custom-call = (f32[1,2,3]{0,2,1}, f32[1,2,3]{1,2,0}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", custom_call_has_side_effect=true } +)" +}, +// CustomCallWithAliasing +{ +"CustomCallWithAliasing", +R"(HloModule CustomCallWithAliasing + +ENTRY %CustomCallWithAliasing (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4]) -> (f32[123,4], f32[2,2], f32[1,2,3]) { + %p0 = (f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) parameter(0) + %p1 = f32[123,4]{0,1} parameter(1) + ROOT %custom-call = (f32[123,4]{0,1}, f32[2,2]{0,1}, f32[1,2,3]{0,1,2}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", output_to_operand_aliasing={{0}: (1, {}), {1}: (0, {0})} +} + )" }, // Parse c64 literal @@ -2107,6 +2120,19 @@ ENTRY %ShortConstant.v4 () -> f32[67,89] { EXPECT_EQ(result.ValueOrDie()->ToString(HloPrintOptions()), original); } +TEST_F(HloParserTest, NegativeNan) { + const string original = R"(HloModule NegativeNan_module + +ENTRY %NegativeNan () -> bf16[2] { + ROOT %constant = bf16[2]{0} constant({-nan, -nan}) +} + +)"; + auto result = ParseAndReturnUnverifiedModule(original); + EXPECT_EQ(Status::OK(), result.status()); + EXPECT_EQ(result.ValueOrDie()->ToString(HloPrintOptions()), original); +} + TEST_F(HloParserTest, AttributesAnyOrder) { const string original = R"(HloModule any_order_module diff --git a/tensorflow/compiler/xla/service/hlo_pass_fix.h b/tensorflow/compiler/xla/service/hlo_pass_fix.h index a22a394c6a4..1533c53ba45 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_fix.h +++ b/tensorflow/compiler/xla/service/hlo_pass_fix.h @@ -43,11 +43,13 @@ class HloPassFix : public Pass { while (changed_this_iteration) { TF_ASSIGN_OR_RETURN(changed_this_iteration, Pass::Run(module)); changed |= changed_this_iteration; - VLOG(3) << "changed_this_iteration: " << changed_this_iteration; + VLOG(3) << Pass::name() << " iteration " << iteration_count + << " changed_this_iteration: " << changed_this_iteration; ++iteration_count; if (iteration_count == kLimit) { - VLOG(1) << "Unexpectedly high number of iterations in HLO passes, " - "exiting fixed point loop."; + VLOG(1) << "Unexpectedly high number of iterations in HLO passes '" + << Pass::name() << "' for module '" << module->name() + << "'. Exiting fixed point loop."; // Return false in case this is fixed point is nested. 
return false; } diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index b07ab10827a..74c385f16bd 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -58,6 +58,7 @@ StatusOr HloPassPipeline::RunPassesInternal( TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, last_pass_name)); bool changed = false; for (HloPassInterface* pass : passes) { + XLA_SCOPED_LOGGING_TIMER(absl::StrCat("HLO pass: ", pass->name())); absl::string_view pass_name = pass->name(); VLOG(1) << " HLO pass " << pass_name; VLOG(2) << " Module hash " << hlo->Hash(); @@ -69,6 +70,9 @@ StatusOr HloPassPipeline::RunPassesInternal( } TF_ASSIGN_OR_RETURN(bool pass_changed, RunHelper(pass, hlo)); changed |= pass_changed; + if (pass_changed) { + VLOG(3) << " Pass caused changes" << pass->name(); + } TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, pass_name)); last_pass_name = string(pass_name); if (!pass->IsPassPipeline()) { diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 7f974a618a8..59b1ac31e9b 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -135,12 +135,17 @@ struct Item { // The buffers used by this instruction. BufferIdList buffers_used; + bool is_skip_node = false; + private: friend class InstructionList; // Items are arranged in a doubly linked list. - Item* next; - Item* prev; + Item* next = nullptr; + Item* prev = nullptr; + + Item* prev_skip_node = nullptr; + Item* next_skip_node = nullptr; // List is ordered by position, which can however be duplicated as // new instructions are inserted. See InsertBeforeInstructions @@ -152,11 +157,23 @@ using ItemList = absl::InlinedVector; // Class which maintains an ordered list of instructions with fast insertion // before arbitrary elements. +// +// This is a skip list structure that has two lanes: express lane and slow lane. +// All nodes are presented on the slow lane but a node can be promoted into +// express lane for fast iteration. +// +// In the following case, node 2 and node + 1 are connected via an express lane. +// +--------------------------+----------->: Express lane +// | | +// node1<-> node 2 <-> .. <-> node n <-> node n+1 <->...: Slow lane +// class InstructionList { public: explicit InstructionList(const HloInstructionSequence& order) { int64 position = 0; Item* last = nullptr; + last_skip_node_ = nullptr; + first_skip_node_ = nullptr; for (HloInstruction* inst : order.instructions()) { // Add a new item to the linked list. Item* item = new Item; @@ -198,6 +215,9 @@ class InstructionList { Item* first() const { return first_; } Item* next(Item* item) const { return item->next; } + Item* first_skip_node() const { return first_skip_node_; } + Item* next_skip_node(Item* item) const { return item->next_skip_node; } + // Creates an Item for the given instruction, but doesn't add it to the list. // (Use InsertBeforeInstructions to add the Item to the list.) 
Item* CreateItem(HloInstruction* inst) { @@ -266,6 +286,27 @@ class InstructionList { return InsertBefore(to_insert, min_position_item); } + // Scan the list and promote nodes to express lane if should_promote(Item) + // returns true; + void PromoteNodesToSkip(std::function should_promote) { + int64 count = 0; + for (auto* item = first(); item != nullptr; item = next(item)) { + if (should_promote(item)) { + count += 1; + if (first_skip_node_ == nullptr) { + first_skip_node_ = item; + } + item->is_skip_node = true; + item->prev_skip_node = last_skip_node_; + if (last_skip_node_ != nullptr) { + last_skip_node_->next_skip_node = item; + } + last_skip_node_ = item; + } + } + VLOG(1) << " Rematerialization has " << count << " items in express lane"; + } + void InsertAfterInstructions(Item* to_insert, absl::Span after_instructions) { VLOG(3) << "InsertAfterInstructions: " << to_insert->instruction->name() @@ -301,6 +342,44 @@ class InstructionList { void InsertBefore(Item* item, Item* before) { VLOG(3) << "InsertBefore: " << item->instruction->name() << " before " << before->instruction->name(); + // Always place new nodes on express lane for the ease of implementation. + item->is_skip_node = true; + // Find the next express node starting from 'before'. Set up the node's + // express pointers. + Item* cursor = before; + while (cursor != nullptr && !cursor->is_skip_node) { + cursor = cursor->next; + } + CHECK(cursor == nullptr || cursor->is_skip_node); + if (cursor == nullptr) { + // + // last_skip_node_<---+ : express lane + // | + // ...<->`item`<-> .. <-> `cursor`(null) : slow lane + // + // Reached the end. Set the prev_express to last_skip_node, and reset + // last_skip. + item->prev_skip_node = last_skip_node_; + item->next_skip_node = nullptr; + last_skip_node_ = item; + } else { + // + // <-+------------+----------------+---------> : express lane + // | | | + // prev_express..<->`item`<-> .. <-> `cursor` <-> ...: slow lane + // + // Reached the next skip node, sets up express pointers accordingly. + CHECK(cursor->is_skip_node); + item->prev_skip_node = cursor->prev_skip_node; + if (item->prev_skip_node != nullptr) { + item->prev_skip_node->next_skip_node = item; + } + item->next_skip_node = cursor; + cursor->prev_skip_node = item; + } + if (first_skip_node_ == cursor) { + first_skip_node_ = item; + } // Insert new item into linked list. item->prev = before->prev; item->next = before; @@ -319,6 +398,12 @@ class InstructionList { Item* first_; + // First skip node of this list. + Item* first_skip_node_; + + // Last skip node of this list. + Item* last_skip_node_; + // Item for each instruction. absl::flat_hash_map item_map_; }; @@ -460,6 +545,15 @@ class MemoryUsageTracker { // values. int64 memory_usage() const { return memory_usage_; } + // + int64 AllocatedSize(Item* item) const { + int64 size = 0; + for (auto buffer_id : item->buffers_defined) { + size += AllocatedSize(buffer_id); + } + return size; + } + // Check invariants of the data structure. This is expensive to call. 
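
// NOTE: the hunks above thread an "express lane" through the instruction
// list: every node stays on the slow lane, and PromoteNodesToSkip links only
// the interesting nodes (e.g. those defining at least min_remat_size bytes)
// so candidate scans can walk just the promoted nodes. A minimal stand-alone
// sketch of that two-lane structure (simplified types, no insertion logic):
#include <functional>

struct Node {
  int payload = 0;
  Node* next = nullptr;       // slow lane: every node
  Node* next_skip = nullptr;  // express lane: promoted nodes only
  bool is_skip_node = false;
};

class TwoLaneList {
 public:
  void Append(Node* n) {
    if (last_ != nullptr) last_->next = n;
    if (first_ == nullptr) first_ = n;
    last_ = n;
  }
  // Promote nodes for which should_promote() is true onto the express lane.
  void PromoteNodes(const std::function<bool(Node*)>& should_promote) {
    Node* last_skip = nullptr;
    for (Node* n = first_; n != nullptr; n = n->next) {
      if (!should_promote(n)) continue;
      n->is_skip_node = true;
      if (last_skip != nullptr) last_skip->next_skip = n;
      if (first_skip_ == nullptr) first_skip_ = n;
      last_skip = n;
    }
  }
  Node* first() const { return first_; }
  Node* first_skip() const { return first_skip_; }

 private:
  Node* first_ = nullptr;
  Node* last_ = nullptr;
  Node* first_skip_ = nullptr;
};

// Usage: scan only the express lane when looking for candidates.
//   for (Node* n = list.first_skip(); n != nullptr; n = n->next_skip) { ... }
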
bool Check() const; @@ -652,7 +746,6 @@ MemoryUsageTracker::MemoryUsageTracker( .CreateFlattenedSet(); absl::flat_hash_map logical_buffer_to_buffer_id; - for (auto* item = instruction_list_.first(); item != nullptr; item = instruction_list_.next(item)) { const HloInstruction* const instruction = item->instruction; @@ -1186,8 +1279,9 @@ MemoryUsageTracker::PickRematerializationCandidates( VLOG(5) << "Picking candidate block with size in [" << min_block_size << ", " << max_block_size << "]"; - for (auto* start_item = instruction_list.first(); start_item != nullptr; - start_item = instruction_list.next(start_item)) { + for (auto* start_item = instruction_list.first_skip_node(); + start_item != nullptr; + start_item = instruction_list.next_skip_node(start_item)) { std::vector block = GetInitialBlock(instruction_list, *this, start_item, min_block_size); if (block.size() < min_block_size) { @@ -1427,12 +1521,13 @@ StatusOr CompressInstruction(MemoryUsageTracker* memory_tracker, << ") to" << compact_shape.ToString(true); HloComputation* computation = best->parent(); - HloInstruction* compressed = computation->AddInstruction( - HloInstruction::CreateUnary(compact_shape, HloOpcode::kCopy, best)); + HloInstruction::CreateUnary(compact_shape, HloOpcode::kCopy, best), + /*new_name=*/best->name() + ".remat_compressed"); HloInstruction* uncompressed = computation->AddInstruction( - HloInstruction::CreateUnary(best->shape(), HloOpcode::kCopy, compressed)); + HloInstruction::CreateUnary(best->shape(), HloOpcode::kCopy, compressed), + /*new_name=*/best->name() + ".remat_uncompressed"); Item* compressed_item = instruction_list->CreateItem(compressed); compressed_item->placed = true; @@ -1566,7 +1661,7 @@ StatusOr HloRematerialization::CalledComputationsMemoryUsage( StatusOr HloRematerialization::RematerializeComputation( HloComputation* computation, HloSchedule* schedule, - int64 memory_limit_bytes) { + int64 memory_limit_bytes, int64 min_remat_size) { VLOG(1) << "Rematerializing computation " << computation->name() << " with limit " << HumanReadableNumBytes(memory_limit_bytes); VLOG(1) << "peak memory usage is " @@ -1577,6 +1672,10 @@ StatusOr HloRematerialization::RematerializeComputation( MemoryUsageTracker memory_tracker( computation, size_function_, compact_shape_function_, *points_to_analysis_, instruction_list, mode_); + + instruction_list.PromoteNodesToSkip([&](Item* item) { + return memory_tracker.AllocatedSize(item) >= min_remat_size; + }); bool changed = false; // If the rematerialization makes the source instruction dead, then the @@ -1622,43 +1721,46 @@ StatusOr HloRematerialization::RematerializeComputation( // single instruction rematerialization is considered first. int min_block_size = 1; int max_block_size = 1; + // Only trigger rematerialization when the memory usage changes. 
+ if (memory_tracker.AllocatedSize(item) + callee_usage > 0) { + while (memory_tracker.memory_usage() + callee_usage > + memory_limit_bytes) { + VLOG(2) << "Over memory limit at instruction " << instruction->name() + << ", using " + << HumanReadableNumBytes(memory_tracker.memory_usage() + + callee_usage) + << ", limit is " << HumanReadableNumBytes(memory_limit_bytes); - while (memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) { - VLOG(2) << "Over memory limit at instruction " << instruction->name() - << ", using " - << HumanReadableNumBytes(memory_tracker.memory_usage() + - callee_usage) - << ", limit is " << HumanReadableNumBytes(memory_limit_bytes); + TF_ASSIGN_OR_RETURN( + InstructionsAdded instructions_added, + RematerializeBestBlock(min_block_size, max_block_size, + &memory_tracker, &instruction_list, + memory_limit_bytes, &rematerializable_map, + &remat_move_instructions)); + net_instructions_added += instructions_added.net_instructions_added; + remat_count += instructions_added.remat_count; - TF_ASSIGN_OR_RETURN(InstructionsAdded instructions_added, - RematerializeBestBlock( - min_block_size, max_block_size, &memory_tracker, - &instruction_list, memory_limit_bytes, - &rematerializable_map, &remat_move_instructions)); - net_instructions_added += instructions_added.net_instructions_added; - remat_count += instructions_added.remat_count; - - VLOG(1) << "memory_usage after rematerialization = " - << HumanReadableNumBytes(memory_tracker.memory_usage()); - if (instructions_added.remat_count == 0) { - // Unable to find a block to rematerialize. - // Consider doubling the block size. - min_block_size = max_block_size + 1; - max_block_size = 2 * max_block_size; - } else { - // Found a valid block. Reset to start looking for single instructions - // again. - max_rematerialized_block_size_ = - std::max(max_rematerialized_block_size_, max_block_size); - changed = true; - min_block_size = 1; - max_block_size = 1; - } - if (max_block_size > block_size_limit_) { - break; + VLOG(1) << "memory_usage after rematerialization = " + << HumanReadableNumBytes(memory_tracker.memory_usage()); + if (instructions_added.remat_count == 0) { + // Unable to find a block to rematerialize. + // Consider doubling the block size. + min_block_size = max_block_size + 1; + max_block_size = 2 * max_block_size; + } else { + // Found a valid block. Reset to start looking for single instructions + // again. + max_rematerialized_block_size_ = + std::max(max_rematerialized_block_size_, max_block_size); + changed = true; + min_block_size = 1; + max_block_size = 1; + } + if (max_block_size > block_size_limit_) { + break; + } } } - const CallSite* callsite = call_graph_node.GetCallSite(instruction); if (callsite != nullptr && callsite->context() == CallContext::kSequential && @@ -1683,10 +1785,12 @@ StatusOr HloRematerialization::RematerializeComputation( TF_ASSIGN_OR_RETURN( bool subcomputation_changed, RematerializeComputation(called_computation, schedule, - subcomputation_memory_limit_bytes)); + subcomputation_memory_limit_bytes, + min_remat_size)); changed |= subcomputation_changed; } } + TF_ASSIGN_OR_RETURN(callee_usage, CalledComputationsMemoryUsage(instruction)); } @@ -1786,14 +1890,12 @@ StatusOr HloRematerialization::Run(HloModule* module) { module_output_size; VLOG(1) << "Peak memory usage of module (before): " << HumanReadableNumBytes(before_peak_memory); - // Subcomputations called by the entry computation will also be // rematerialized. 
TF_ASSIGN_OR_RETURN( bool changed, RematerializeComputation(module->entry_computation(), &module->schedule(), - adjusted_memory_limit_bytes)); - + adjusted_memory_limit_bytes, min_remat_size_)); // Rematerialization can introduce dead code. This occurs if all uses of an // instruction are replaced with rematerializations of the instruction. @@ -1838,7 +1940,6 @@ StatusOr HloRematerialization::Run(HloModule* module) { HumanReadableNumBytes(memory_limit_bytes_), memory_limit_bytes_, HumanReadableNumBytes(current_peak_memory), current_peak_memory); } - return changed; } diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 72221fa8a32..878bb2a8eef 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -85,7 +85,8 @@ class HloRematerialization : public HloModulePass { RematerializationSizes* sizes, RematerializationPass pass_location, int block_size_limit, CompactShapeFunction compact_shape_function = nullptr, - RematerializationMode mode = RematerializationMode::kRecomputeAndCompress) + RematerializationMode mode = RematerializationMode::kRecomputeAndCompress, + int64 min_remat_size = 0) : size_function_(size_function), memory_limit_bytes_(memory_limit_bytes), sizes_(sizes), @@ -94,7 +95,8 @@ class HloRematerialization : public HloModulePass { compact_shape_function_(compact_shape_function == nullptr ? DefaultCompactShapeFunction : std::move(compact_shape_function)), - mode_(mode) {} + mode_(mode), + min_remat_size_(min_remat_size) {} ~HloRematerialization() override = default; absl::string_view name() const override { return "rematerialization"; } @@ -114,7 +116,8 @@ class HloRematerialization : public HloModulePass { // and inserted into 'order'. virtual StatusOr RematerializeComputation(HloComputation* computation, HloSchedule* schedule, - int64 memory_limit_bytes); + int64 memory_limit_bytes, + int64 min_remat_size); // Computes and returns the peak memory used by the given computation. 
The // peak memory is the maximum total size of all live HLO instruction values at @@ -185,6 +188,8 @@ class HloRematerialization : public HloModulePass { int max_rematerialized_block_size_ = 0; RematerializationMode mode_; + + int64 min_remat_size_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 5176e2f99e5..35f39e9a342 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -41,7 +41,8 @@ using ::testing::_; class HloRematerializationTest : public RematerializationTestBase { protected: StatusOr RunHloRematerialization(int64 memory_limit_bytes, - HloModule* module) { + HloModule* module, + int64 min_remat_size = 0) { TF_EXPECT_OK(verifier().Run(module).status()); HloMemoryScheduler scheduler( [](const BufferValue& buffer) { return ByteSizeOf(buffer.shape()); }, @@ -51,7 +52,9 @@ class HloRematerializationTest : public RematerializationTestBase { ByteSizeOf, memory_limit_bytes, /*sizes=*/nullptr, HloRematerialization::RematerializationPass::kPreFusion, - /*block_size_limit=*/1); + /*block_size_limit=*/1, nullptr, + HloRematerialization::RematerializationMode::kRecomputeAndCompress, + min_remat_size); return remat.Run(module); } }; @@ -96,6 +99,26 @@ TEST_F(HloRematerializationTest, SingleComputation) { remat_bcast); } +// Test rematerialization of a single computation that contains nodes that +// doesn't contain node worth using remat. +TEST_F(HloRematerializationTest, SingleComputationNoWorthRemat) { + auto module = CreateNewVerifiedModule(); + HloComputation* computation = + module->AddEntryComputation(MakeRematerializableComputation()); + + // Find and save the original broadcast instruction which should be + // rematerialized. + const HloInstruction* slice = computation->root_instruction(); + ASSERT_THAT(slice, op::Slice(op::Concatenate(op::Broadcast(_), _))); + + // Set the minimum remat size to 14KiB, meaning no nodes should be remat. + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/14 * 1024, module.get(), + /*min_remat_size=*/14 * 1024)); + EXPECT_FALSE(changed); +} + // Test rematerialization of a single computation produced by // MakeRematerializableComputation but with a sufficiently high memory limit // such that no instructions are rematerialized. @@ -577,17 +600,67 @@ class CompressingRematerializationTest : public RematerializationTestBase { } StatusOr RunHloRematerialization(int64 memory_limit_bytes, - HloModule* module) { + HloModule* module, + int64 min_remat_size = 0) { TF_EXPECT_OK(verifier().Run(module).status()); HloRematerialization remat( ShapeSizePadMinorTo64, memory_limit_bytes, /*sizes=*/nullptr, HloRematerialization::RematerializationPass::kPreFusion, - /*block_size_limit=*/1, ChooseCompactLayoutForShape); + /*block_size_limit=*/1, ChooseCompactLayoutForShape, + HloRematerialization::RematerializationMode::kCompressOnly, + min_remat_size); return remat.Run(module); } }; +// Test rematerialization only remats big buffer that pass certain limits. 
+TEST_F(CompressingRematerializationTest, OnlyRematBigBuffer) { + const string& hlo_string = R"( +HloModule fusion, is_scheduled=true + +%add_float { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %x, f32[] %y) +} + +ENTRY %entry { + %param.0 = f32[] parameter(0) + %constant = f32[] constant(0) + %broadcast.0 = f32[64,2]{1,0} broadcast(f32[] %param.0), dimensions={} + %broadcast.1 = f32[10,2]{1,0} broadcast(f32[] %param.0), dimensions={} + %negate = f32[64,2]{1,0} negate(f32[64,2]{1,0} broadcast.0) + %reduce.0 = f32[] reduce(f32[64,2]{1,0} %negate, f32[] %constant), dimensions={1, 0}, to_apply=%add_float + %reduce.1 = f32[] reduce(f32[64,2]{1,0} %broadcast.0, f32[] %constant), dimensions={1, 0}, to_apply=%add_float + %reduce.2 = f32[] reduce(f32[10,2]{1,0} %broadcast.1, f32[] %constant), dimensions={1, 0}, to_apply=%add_float + %add = f32[] add(f32[] %reduce.0, f32[] %reduce.1) + ROOT %add.2 = f32[] add(f32[] %add, f32[] %reduce.2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + // Only rematerialize buffers which have shaep f32[64, 2]. Buffers with shape + // f32[10, 2] are ignored. + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/30 * 1024, + module.get(), 10 * 1024)); + EXPECT_TRUE(changed); + HloInstruction* broadcast = + module->entry_computation()->GetInstructionWithName("broadcast.0"); + HloInstruction* broadcast_2 = + module->entry_computation()->GetInstructionWithName("broadcast.1"); + HloInstruction* reduce = + module->entry_computation()->GetInstructionWithName("reduce.1"); + HloInstruction* reduce_2 = + module->entry_computation()->GetInstructionWithName("reduce.2"); + EXPECT_THAT(reduce, + op::Reduce(op::Copy(op::Copy(broadcast)), op::Constant())); + EXPECT_THAT(reduce_2, op::Reduce(broadcast_2, op::Constant())); +} + // Test rematerialization of a single instruction. 
TEST_F(CompressingRematerializationTest, SingleRemat) { const string& hlo_string = R"( diff --git a/tensorflow/compiler/xla/service/hlo_replication_analysis.cc b/tensorflow/compiler/xla/service/hlo_replication_analysis.cc index dec119d8aba..2d1edbefd97 100644 --- a/tensorflow/compiler/xla/service/hlo_replication_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_replication_analysis.cc @@ -129,6 +129,13 @@ bool DetermineHloInstructionIsReplicated( return true; } + if (hlo->opcode() == HloOpcode::kCustomCall && + (hlo->custom_call_target() == "X64SplitLow" || + hlo->custom_call_target() == "X64SplitHigh" || + hlo->custom_call_target() == "X64Combine")) { + return all_operands_replicated(hlo); + } + if (hlo->IsElementwise() || // hlo->opcode() == HloOpcode::kConcatenate || // hlo->opcode() == HloOpcode::kConvolution || // diff --git a/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc index c2d86e808c2..cc0f4c86f4d 100644 --- a/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc @@ -501,6 +501,36 @@ ENTRY entry { FindInstruction(module.get(), "conditional"), {1})); } +TEST_F(HloReplicationAnalysisTest, X64SplitCombine) { + const string module_str = R"( +HloModule SimpleTupleSelect + +ENTRY entry { + param = (f64[]) parameter(0) + gte = f64[] get-tuple-element(param), index=0 + param-low = f32[] custom-call(gte), custom_call_target="X64SplitLow" + param-high = f32[] custom-call(gte), custom_call_target="X64SplitHigh" + ROOT result-combine = f64[] custom-call(param-low, param-high), custom_call_target="X64Combine" +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(module_str)); + auto param = module->entry_computation()->parameter_instruction(0); + param->set_parameter_replicated_at_leaf_buffers(absl::Span{true}); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, + HloReplicationAnalysis::Run( + module.get(), /*cross_partition_spmd=*/false)); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "gte"), {})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "param-low"), {})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "param-high"), {})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "result-combine"), {})); +} + TEST_F(HloReplicationAnalysisTest, SimpleTupleSelect) { const string module_str = R"( HloModule SimpleTupleSelect diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 3a5e7ca6f40..0d71c6d49ed 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -211,8 +211,7 @@ static std::vector ExecutionInputsFromScopedShapedBuffers( *buffer_tree.mutable_element(index) = execution_input_buffer; } }); - execution_inputs.emplace_back(std::move(buffer_tree), - input_buffer.on_host_shape()); + execution_inputs.emplace_back(std::move(buffer_tree)); } return execution_inputs; } diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 4244cdaceea..977f6ee8ea6 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -56,6 +56,13 @@ HloSharding HloSharding::PartialTile( HloSharding HloSharding::PartialTile( const Array& 
tile_assignment_last_dim_replicate) { + if (tile_assignment_last_dim_replicate.dimensions().back() == 1) { + auto new_tile_dims = tile_assignment_last_dim_replicate.dimensions(); + new_tile_dims.pop_back(); + auto fully_tiled = tile_assignment_last_dim_replicate; + fully_tiled.Reshape(new_tile_dims); + return HloSharding(fully_tiled); + } std::vector> sorted_groups( tile_assignment_last_dim_replicate.num_elements() / tile_assignment_last_dim_replicate.dimensions().back()); diff --git a/tensorflow/compiler/xla/service/hlo_sharding_util.cc b/tensorflow/compiler/xla/service/hlo_sharding_util.cc index e1e506b2892..18f76c5253b 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_util.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_util.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/array.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" @@ -347,13 +348,21 @@ HloSharding GatherOutputSharding(const HloSharding& index_sharding, index_dim++; } } + + if (index_sharding.ReplicateOnLastTileDim()) { + output_tile_assignment_dims.push_back( + index_sharding.tile_assignment().dimensions().back()); + } + Array new_tile_assignment = index_sharding.tile_assignment(); if (new_tile_assignment.num_elements() != Product(output_tile_assignment_dims)) { return HloSharding::Replicate(); } new_tile_assignment.Reshape(output_tile_assignment_dims); - return HloSharding::Tile(new_tile_assignment); + return index_sharding.ReplicateOnLastTileDim() + ? HloSharding::PartialTile(new_tile_assignment) + : HloSharding::Tile(new_tile_assignment); } HloSharding GatherIndexSharding(const HloSharding& output_sharding, @@ -379,13 +388,20 @@ HloSharding GatherIndexSharding(const HloSharding& output_sharding, index_tile_assignment_dims.begin() + dnums.index_vector_dim(), 1); } + if (output_sharding.ReplicateOnLastTileDim()) { + index_tile_assignment_dims.push_back( + output_sharding.tile_assignment().dimensions().back()); + } + Array new_tile_assignment = output_sharding.tile_assignment(); if (new_tile_assignment.num_elements() != Product(index_tile_assignment_dims)) { return HloSharding::Replicate(); } new_tile_assignment.Reshape(index_tile_assignment_dims); - return HloSharding::Tile(new_tile_assignment); + return output_sharding.ReplicateOnLastTileDim() + ? HloSharding::PartialTile(new_tile_assignment) + : HloSharding::Tile(new_tile_assignment); } HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo) { @@ -455,13 +471,19 @@ HloSharding ScatterIndexSharding(const HloSharding& data_sharding, if (index_tile_assignment_dims.size() < hlo->operand(1)->shape().rank()) { index_tile_assignment_dims.push_back(1); } + if (data_sharding.ReplicateOnLastTileDim()) { + index_tile_assignment_dims.push_back( + data_sharding.tile_assignment().dimensions().back()); + } Array new_tile_assignment = data_sharding.tile_assignment(); if (new_tile_assignment.num_elements() != Product(index_tile_assignment_dims)) { return HloSharding::Replicate(); } new_tile_assignment.Reshape(index_tile_assignment_dims); - return HloSharding::Tile(new_tile_assignment); + return data_sharding.ReplicateOnLastTileDim() + ? 
HloSharding::PartialTile(new_tile_assignment) + : HloSharding::Tile(new_tile_assignment); } HloSharding ScatterDataSharding(const HloSharding& index_sharding, @@ -481,13 +503,19 @@ HloSharding ScatterDataSharding(const HloSharding& index_sharding, index_dim++; } } + if (index_sharding.ReplicateOnLastTileDim()) { + data_tile_assignment_dims.push_back( + index_sharding.tile_assignment().dimensions().back()); + } Array new_tile_assignment = index_sharding.tile_assignment(); if (new_tile_assignment.num_elements() != Product(data_tile_assignment_dims)) { return HloSharding::Replicate(); } new_tile_assignment.Reshape(data_tile_assignment_dims); - return HloSharding::Tile(new_tile_assignment); + return index_sharding.ReplicateOnLastTileDim() + ? HloSharding::PartialTile(new_tile_assignment) + : HloSharding::Tile(new_tile_assignment); } HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding, @@ -614,9 +642,15 @@ absl::optional PassthroughOperandToGatherOutputOrScatterUpdate( } passthrough_tile[offset_dim] = dim_partitions; } + if (operand_sharding.ReplicateOnLastTileDim()) { + passthrough_tile.push_back( + operand_sharding.tile_assignment().dimensions().back()); + } Array tile_assignment = operand_sharding.tile_assignment(); tile_assignment.Reshape(passthrough_tile); - return HloSharding::Tile(tile_assignment); + return operand_sharding.ReplicateOnLastTileDim() + ? HloSharding::PartialTile(tile_assignment) + : HloSharding::Tile(tile_assignment); } // Inverse of PassthroughOperandToGatherOutputOrScatterUpdate. @@ -650,12 +684,19 @@ absl::optional PassthroughGatherOutputOrScatterUpdateToOperand( } passthrough_tile[i] = dim_partitions; } + + if (update_or_gather_sharding.ReplicateOnLastTileDim()) { + passthrough_tile.push_back( + update_or_gather_sharding.tile_assignment().dimensions().back()); + } Array tile_assignment = update_or_gather_sharding.tile_assignment(); if (tile_assignment.num_elements() != Product(passthrough_tile)) { return absl::nullopt; } tile_assignment.Reshape(passthrough_tile); - return HloSharding::Tile(tile_assignment); + return update_or_gather_sharding.ReplicateOnLastTileDim() + ? 
HloSharding::PartialTile(tile_assignment) + : HloSharding::Tile(tile_assignment); } } // namespace @@ -776,29 +817,51 @@ IdentityValueAndHloOpcodeForScatterReduceComputation( "add/or/multiply/add/min/max"); } -std::vector DevicesForSharding( - const HloSharding& sharding, const std::vector& available_devices) { - std::vector devices; - if (sharding.IsReplicated()) { - for (int64 d : available_devices) { - if (!HloSharding::IsReservedDevice(d)) { - devices.push_back(d); - } +namespace { + +void DevicesForShardingInternal( + const HloSharding& sharding, + const absl::flat_hash_set& available_devices, + absl::flat_hash_set* used) { + if (sharding.IsTuple()) { + for (const auto& subsharding : sharding.tuple_elements()) { + DevicesForShardingInternal(subsharding, available_devices, used); } - return devices; + return; } - for (int64 i : available_devices) { - if (sharding.UsesDevice(i)) { - devices.push_back(i); + if (sharding.IsReplicated()) { + for (int64 device : available_devices) { + if (!HloSharding::IsReservedDevice(device)) { + used->insert(device); + } + } + return; + } + + DCHECK(std::all_of( + sharding.tile_assignment().begin(), sharding.tile_assignment().end(), + [&](int64 device) { return available_devices.contains(device); })); + sharding.tile_assignment().Each([&](absl::Span /*indices*/, + int64 device) { used->insert(device); }); +} + +} // namespace + +std::vector DevicesForSharding( + const HloSharding& sharding, const std::vector& available_devices) { + absl::flat_hash_set available_set; + for (int64 device : available_devices) { + available_set.insert(device); + } + absl::flat_hash_set used_set; + DevicesForShardingInternal(sharding, available_set, &used_set); + std::vector devices; + for (int64 device : available_devices) { + if (used_set.contains(device)) { + devices.push_back(device); } } - DCHECK(std::all_of(sharding.tile_assignment().begin(), - sharding.tile_assignment().end(), [&](int64 device) { - return std::find(available_devices.begin(), - available_devices.end(), - device) != available_devices.end(); - })); return devices; } diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 0af2a45bfc7..4be0c5259cc 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -801,6 +801,28 @@ Status ShapeVerifier::HandleCustomCall(HloInstruction* instruction) { TF_RET_CHECK(LayoutUtil::HasLayout(operand_shape_with_layout)); } } + for (const auto& pair : custom_call->output_to_operand_aliasing()) { + TF_RET_CHECK(pair.second.first < custom_call->operand_count()) + << "Invalid aliasing operand index."; + TF_RET_CHECK(ShapeUtil::IndexIsValid( + custom_call->operand(pair.second.first)->shape(), pair.second.second)) + << "Invalid aliasing operand shape index."; + TF_RET_CHECK(ShapeUtil::IndexIsValid(custom_call->shape(), pair.first)) + << "Invalid aliasing output shape index."; + const Shape& output_subshape = + ShapeUtil::GetSubshape(custom_call->shape(), pair.first); + const Shape& operand_subshape = ShapeUtil::GetSubshape( + custom_call->operand(pair.second.first)->shape(), pair.second.second); + if (layout_sensitive_) { + TF_RET_CHECK(operand_subshape == output_subshape) + << "Different aliasing shapes: " << operand_subshape.ToString() + << " vs " << output_subshape.ToString(); + } else { + TF_RET_CHECK(ShapeUtil::Compatible(output_subshape, operand_subshape)) + << "Different aliasing shapes: " << operand_subshape.ToString() + << " vs " << 
output_subshape.ToString(); + } + } return Status::OK(); } @@ -1037,7 +1059,7 @@ namespace { // inputs. Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { switch (instruction->opcode()) { - // White list the following opcodes for mixed-precision check, because + // Allow-list the following opcodes for mixed-precision check, because // they involve data pass through or grouping via tuples, where the // precisions of buffers can be different. case HloOpcode::kCall: @@ -1160,6 +1182,7 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, case HloOpcode::kCopyDone: case HloOpcode::kCopyStart: case HloOpcode::kCustomCall: + case HloOpcode::kDynamicUpdateSlice: case HloOpcode::kGetTupleElement: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index 1f71c9586d5..0df30166a1c 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -494,6 +494,28 @@ TEST_F(HloVerifierTest, ScalarIndexDynamicUpdateSlice) { ASSERT_TRUE(status.ok()); } +TEST_F(HloVerifierTestAllowMixedPrecision, DynamicUpdateSliceMixedPrecision) { + const char* const kDynamicUpdateSliceMixedPrecision = R"( + HloModule kDynamicUpdateSliceMixedPrecision + + ENTRY %entry (parameter.0: f32[32,511,2048], parameter.1: bf16[32,511,512], parameter.2: s32[], parameter.3: s32[], parameter.4: s32[]) -> bf16[32,511,2048] { + %parameter.0 = f32[32,511,2048] parameter(0) + %parameter.1 = bf16[32,511,512] parameter(1) + %parameter.2 = s32[] parameter(2) + %parameter.3 = s32[] parameter(3) + %parameter.4 = s32[] parameter(4) + ROOT %dus = bf16[32,511,2048] dynamic-update-slice(f32[32,511,2048] %parameter.0, bf16[32,511,512] %parameter.1, s32[] %parameter.2, s32[] %parameter.3, s32[] %parameter.4) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule( + kDynamicUpdateSliceMixedPrecision)); + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Expected instruction to have shape equal to " + "f32[32,511,2048], actual shape is bf16[32,511,2048]")); +} + TEST_F(HloVerifierTestLayoutSensitive, AddWithLayoutChangeNotAllowed) { TF_ASSERT_OK_AND_ASSIGN( auto module, ParseAndReturnUnverifiedModule(kAddWithLayoutChangeHlo)); diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index b290b1bd68b..11472f55792 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -516,11 +516,12 @@ StatusOr InstructionFusion::Run(HloModule* module) { continue; } - VLOG(5) << "Considering fusion of: " << instruction->ToString(); std::vector& sorted_operand_numbers = next_entry.second; for (int64 i : sorted_operand_numbers) { HloInstruction* operand = instruction->mutable_operand(i); + VLOG(5) << "Considering fusion of: " << instruction->ToString() + << " with operand " << operand->name(); if (!operand->IsFusible()) { VLOG(3) << "Operand (" << operand->ToString() << ") is not fusible"; @@ -601,6 +602,9 @@ StatusOr InstructionFusion::Run(HloModule* module) { VLOG(1) << FusionConfigToString(*fusion_config); module->set_config(module_config); } + + reachability_.reset(); + VLOG(1) << "Fusion count: " << fuse_count; return changed; @@ -710,4 +714,23 @@ HloInstruction::FusionKind InstructionFusion::ChooseKind( 
return HloInstruction::FusionKind::kLoop; } +bool InstructionFusion::ReusesOperandElements(const HloInstruction* consumer, + int64 operand_index) { + auto operand = consumer->operand(operand_index); + auto it = reused_fusion_operands_.find(consumer); + if (it != reused_fusion_operands_.end() && it->second.contains(operand)) { + return true; + } + bool reuses = consumer->ReusesOperandElements(operand_index); + // If a parameter was reused, we can cache this information. Fusion + // computations only ever grow, so it becomes more likely that a parameter is + // reused, but a reused parameter will never become *not* reused. + if (reuses) { + // We cache the operand corresponding to the fusion parameter, because the + // parameter pointers would be invalidated after the next fusion. + reused_fusion_operands_[consumer].insert(operand); + } + return reuses; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index 90d9da48e33..d51bf700371 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -1,4 +1,3 @@ -#include "absl/container/flat_hash_map.h" /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,6 +19,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/fusion_queue.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -138,6 +139,11 @@ class InstructionFusion : public HloModulePass { return config_collection_mode_; } + // Returns whether 'consumer' may reuse elements of its `operand_index`th + // operand. + bool ReusesOperandElements(const HloInstruction* consumer, + int64 operand_index); + private: // The set of producers whose consumers we cannot fuse into. using HloInstructionSet = std::unordered_set; @@ -172,6 +178,11 @@ class InstructionFusion : public HloModulePass { // Configuration mode. FusionConfigCollection config_collection_mode_; + // Caches which operands are reused inside fusion computations. 
+ absl::flat_hash_map> + reused_fusion_operands_; + TF_DISALLOW_COPY_AND_ASSIGN(InstructionFusion); }; diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 3444d4cae42..c134b7ba6a6 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow/core/platform:build_config_root.bzl", "if_static", @@ -52,6 +53,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", "//tensorflow/compiler/xla/service:layout_assignment", "//tensorflow/compiler/xla/service:map_inliner", + "//tensorflow/compiler/xla/service:qr_expander", "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:triangular_solve_expander", "//tensorflow/compiler/xla/service:while_loop_simplifier", @@ -119,9 +121,9 @@ cc_library( "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/platform:macros", "//tensorflow/core/platform:mutex", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/core/platform:types", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index a059482d832..3f3e74dbb62 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/interpreter/executable.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" #include "tensorflow/compiler/xla/service/map_inliner.h" +#include "tensorflow/compiler/xla/service/qr_expander.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/triangular_solve_expander.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" @@ -82,6 +83,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { pipeline.AddPass(); pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass( diff --git a/tensorflow/compiler/xla/service/interpreter/executable_base.cc b/tensorflow/compiler/xla/service/interpreter/executable_base.cc index 745750bffe1..00998994c0a 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable_base.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable_base.cc @@ -56,7 +56,7 @@ StatusOr InterpreterExecutableBase::ExecuteAsyncOnStream( } for (auto& argument : arguments) { const ShapeTree& buffers = argument.Buffers(); - argument_buffers.push_back(ShapedBuffer(buffers.shape(), buffers.shape(), + argument_buffers.push_back(ShapedBuffer(buffers.shape(), /*platform=*/nullptr, /*device_ordinal=*/device_ordinal)); auto in_it = buffers.begin(); diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 67bfb7da20a..9940b032558 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -1,6 +1,8 @@ # Description: # Libraries for helping construct LLVM IR for XLA backends. 
+load("//tensorflow:tensorflow.bzl", "filegroup") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", "tf_cc_test", @@ -156,10 +158,10 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/service:elemental_ir_emitter", + "//tensorflow/compiler/xla/service:fusion_node_indexing_evaluation", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Core", diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index 0371ce71874..f8514a6cba3 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -17,13 +17,13 @@ limitations under the License. #include #include +#include -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" +#include "tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -44,37 +44,32 @@ using llvm_ir::IrArray; Status FusedIrEmitter::DefaultAction(const HloInstruction* hlo) { indexed_generators_[hlo] = [=](const IrArray::Index& index) -> StatusOr { - if (llvm::Value* generated_value = FindOrDefault( - generated_value_cache_[hlo], index.multidim(), nullptr)) { - llvm::BasicBlock* generated_value_bb = nullptr; - if (auto* generated_instruction = - llvm::dyn_cast(generated_value)) { - generated_value_bb = generated_instruction->getParent(); - } - // Ideally, we should be able to reuse the cached generated value if it - // dominates the current insertion block. However, the check for dominance - // can be expensive and unreliable when the function is being constructed. - // - // It's also worth experimenting what if we don't do caching at all. - // LLVM's CSE or GVN should be able to easily merge common subexpressions - // that would be regenerated without caching. But this might increase the - // JIT compilation time. 
- if (generated_value_bb == nullptr || - generated_value_bb == b_->GetInsertBlock()) { + auto cache = generated_value_cache_.find(hlo); + if (cache != generated_value_cache_.end()) { + auto key = std::make_pair(b_->GetInsertBlock(), index.multidim()); + if (llvm::Value* generated_value = + FindOrDefault(cache->second, key, nullptr)) { + VLOG(3) << "The cached generated value is reused."; + return generated_value; + } + auto null_key = std::make_pair(nullptr, index.multidim()); + if (llvm::Value* generated_value = + FindOrDefault(cache->second, null_key, nullptr)) { VLOG(3) << "The cached generated value is reused."; return generated_value; } - VLOG(3) << "The cached generated value can't be reused, because it is in " - "a different BB (" - << generated_value_bb->getName().str() - << ") from the current insertion block (" - << b_->GetInsertBlock()->getName().str() << ")."; } TF_ASSIGN_OR_RETURN(llvm::Value* const generated_value, elemental_emitter_->MakeElementGenerator( hlo, indexed_generators_)(index)); - generated_value_cache_[hlo][index.multidim()] = generated_value; + llvm::BasicBlock* generated_value_bb = nullptr; + if (auto* generated_instruction = + llvm::dyn_cast(generated_value)) { + generated_value_bb = generated_instruction->getParent(); + } + generated_value_cache_[hlo][std::make_pair( + generated_value_bb, index.multidim())] = generated_value; return generated_value; }; return Status::OK(); @@ -214,99 +209,15 @@ bool FusedIrEmitter::IsFusedIrEmitterInefficient( if (consumer->opcode() != HloOpcode::kFusion) { return false; } - // Collects for each instruction in the fusion node from which (indirect) - // users newly created index values are passed. Roughly speaking, we reuse - // index values if the shapes are equal when ignoring the element type (we may - // reuse also if the shape change is a bitcast, but we don't consider that - // here). By ignoring potential reuses our estimate whether the fusion emitter - // is inefficient is a bit more conservative than necessary. - absl::flat_hash_map> - indexing_users; - // Stores the number of different index accesses for each instruction in the - // fusion node. The fusion emitter caches access with the same index, so this - // value indicates how many times a specific instruction will be emitted. - absl::flat_hash_map index_usage_count; - index_usage_count[consumer] = 1; - - auto evaluate_fusion_computation = [&indexing_users, &index_usage_count]( - const HloInstruction* fusion) { - auto postorder = - fusion->fused_instructions_computation()->MakeInstructionPostOrder(); - std::reverse(postorder.begin(), postorder.end()); - for (const auto* instruction : postorder) { - if (instruction->opcode() == HloOpcode::kParameter) { - continue; - } - int64& total = index_usage_count[instruction]; - if (indexing_users[instruction].empty()) { - total = index_usage_count[fusion]; - } else { - total = 0; - for (const auto* user : indexing_users[instruction]) { - int64 weight = 1; - // Concatenate is special: the index differs for each operand, so - // in the worst case we have to deal with as many index values as - // the number of operands of Concatenate. By considering the worst - // case, we are more conservative than necessary regarding - // refusing to fuse. 
- if (user->opcode() == HloOpcode::kConcatenate) { - weight = user->operand_count(); - } - total += index_usage_count[user] * weight; - } - } - for (const auto* operand : instruction->operands()) { - // For simplicity we assume that all shape and layout changing - // operations except Transposes invalidate index reuse. Transposes are - // special: although they are shape changing, we can reuse the - // multi-dimensional index for the operand by permuting it. - if (instruction->opcode() == HloOpcode::kTranspose || - Shape::Equal().IgnoreElementType()(operand->shape(), - instruction->shape())) { - // If the index is reused, it means the operand gets index values - // from the same set of (indirect) users as 'instruction' itself. - indexing_users[operand].insert(indexing_users[instruction].begin(), - indexing_users[instruction].end()); - } else { - // If the index is not reused, it means 'instruction' computes a - // new index derived from the index it gets. - indexing_users[operand].insert(instruction); - } - } - } - }; - evaluate_fusion_computation(consumer); - - // Also account for the 'producer' if it would be fused. Find the operand it - // corresponds to. - for (int64 operand_num = 0; operand_num < consumer->operand_count(); - ++operand_num) { - if (consumer->operand(operand_num) == producer) { - auto instruction = consumer->fused_parameter(operand_num); - int64& total = index_usage_count[producer]; - total = 0; - for (const auto* user : indexing_users[instruction]) { - total += index_usage_count[user]; - } - break; - } + FusionNodeIndexingEvaluation eval_consumer(consumer); + if (producer->opcode() != HloOpcode::kFusion) { + return eval_consumer.CodeDuplicationTooHigh(producer); } - - // If 'producer' is a fusion node as well, also evaluate it. - if (producer->opcode() == HloOpcode::kFusion) { - evaluate_fusion_computation(producer); - } - - // Sum up the total number of emitted ops. - int64 total = 0; - for (const auto& entry : index_usage_count) { - total += entry.second; - } - - // Check that the code duplication has at most a factor of 15 (where 15 is an - // arbitrary constant that seems to work). - return total > 15 * index_usage_count.size(); + // If 'producer' is a fusion node as well, also evaluate it. Pass the + // evaluated duplication of the fusion node if it is merged into consumer. + FusionNodeIndexingEvaluation eval_producer( + producer, eval_consumer.EvaluateEmittedInstructions(producer)); + return eval_producer.MaxCodeDuplicationTooHigh(); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h index d13b0262180..e19e970cb24 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h @@ -18,6 +18,7 @@ limitations under the License. 
#include #include +#include #include "absl/container/flat_hash_map.h" #include "absl/types/optional.h" @@ -153,9 +154,10 @@ class FusedIrEmitter : public ConstDfsHloVisitorWithDefault { // Cache of generated values, lest we regenerate an element of a node with // multiple outgoing edges - absl::flat_hash_map< - const HloInstruction*, - absl::flat_hash_map, llvm::Value*>> + absl::flat_hash_map>, + llvm::Value*>> generated_value_cache_; }; diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index b6b3b2dd8b3..9d7f06f4f68 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -130,8 +130,14 @@ IrArray::Index LoopEmitter::EmitDynamicIndex(ForLoopNest* loop_nest, } std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( - absl::string_view loop_name, llvm::Type* index_type) { + absl::string_view loop_name, llvm::Type* index_type, + llvm::Value* base_index) { CHECK_NE(index_type, nullptr); + CHECK_EQ(base_index, nullptr) + << "XLA CPU implementation of" + << " LoopEmitter::EmitIndexAndSetExitBasicBlock doesn't support" + << " base_index, but it was requested."; + if (ShapeUtil::IsScalar(shape_)) { // No loop needed, so set exit_bb_ to nullptr. exit_bb_ = nullptr; @@ -164,7 +170,8 @@ Status LoopEmitter::EmitLoop(absl::string_view loop_name, } for (const IrArray::Index& array_index : - EmitIndexAndSetExitBasicBlock(loop_name, index_type)) { + EmitIndexAndSetExitBasicBlock(loop_name, index_type, + /*base_index*/ nullptr)) { TF_RETURN_IF_ERROR(body_emitter_(array_index)); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h index 008205a642a..a356741f74b 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h @@ -71,11 +71,13 @@ class LoopEmitter { // specifies the element, will return multiple indices if the loop is // unrolled. std::vector EmitIndexAndSetExitBasicBlock() { - return EmitIndexAndSetExitBasicBlock(/*loop_name=*/"", b_->getInt64Ty()); + return EmitIndexAndSetExitBasicBlock(/*loop_name=*/"", b_->getInt64Ty(), + /*base_index*/ nullptr); } virtual std::vector EmitIndexAndSetExitBasicBlock( - absl::string_view loop_name, llvm::Type* index_type); + absl::string_view loop_name, llvm::Type* index_type, + llvm::Value* base_index); // Emits a complete loop nest for every element in the given shape. Status EmitLoop(absl::string_view loop_name = "", diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc index c53f2c19695..5b133a521e3 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc @@ -29,6 +29,105 @@ const HeapSimulator::Chunk kDummyChunk{-1, -1}; // pow(kWhileExecutionCount, nesting_level) times. 
const int kWhileExecutionCount = 5; +bool LooksLikeAnActivation(const HloInstruction* inst) { + for (HloInstruction* user : inst->users()) { + switch (user->opcode()) { + case HloOpcode::kConvolution: + case HloOpcode::kDot: + if (user->operand(0) == inst) { + return true; + } + break; + case HloOpcode::kGather: + if (user->operand(1) == inst) { + return true; + } + break; + case HloOpcode::kFusion: + for (int i = 0; i < user->operand_count(); ++i) { + if (user->operand(i) == inst && + LooksLikeAnActivation(user->fused_parameter(i))) { + return true; + } + } + break; + case HloOpcode::kBitcast: + return LooksLikeAnActivation(user); + default: + return true; + } + } + return false; +} + +bool IsCrossProgramPrefetchCandidate( + const HloValue& value, const MemorySpaceAssignment::Options& options) { + return value.instruction()->parent() == + value.instruction()->GetModule()->entry_computation() && + value.instruction()->opcode() == HloOpcode::kParameter && + (!value.shape().has_layout() || + value.shape().layout().memory_space() != + options.alternate_memory_space) && + value.index().size() == 1 && value.shape().IsArray() && + !value.uses().empty() && + options.size_fn(value) <= options.max_size_in_bytes && + absl::c_all_of(value.uses(), [&](const HloUse& use) { + const HloInstruction* inst = + use.instruction->operand(use.operand_number); + + // Skip the LooksLikeAnActivation test since we're testing the + // parent GTE and its children below. + if (inst->opcode() == HloOpcode::kBitcast && + inst->operand(0)->opcode() == HloOpcode::kGetTupleElement && + inst->operand(0)->operand(0)->opcode() == + HloOpcode::kParameter) { + return true; + } + + return inst->opcode() == HloOpcode::kGetTupleElement && + !LooksLikeAnActivation(inst); + }); +} + +absl::optional +FindCrossProgramPrefetchCandidate( + const HloAliasAnalysis& alias_analysis, const HloLiveRange& hlo_live_range, + const MemorySpaceAssignment::Options& options) { + std::vector candidates; + for (const HloBuffer& buffer : alias_analysis.buffers()) { + CHECK_GE(buffer.values().size(), 1); + const HloValue* value = buffer.values().at(0); + if (IsCrossProgramPrefetchCandidate(*value, options)) { + MemorySpaceAssignment::BufferInterval interval; + interval.buffer = value; + interval.size = options.size_fn(*value); + interval.start = 0; + interval.end = hlo_live_range.schedule_end_time(); + interval.need_allocation = true; + interval.colocations = {++buffer.values().begin(), buffer.values().end()}; + candidates.emplace_back(interval); + } + } + + // The buffer_interval_compare ought to do a good job picking the most + // appropriate buffer to cross program prefetch, but empirically, it makes + // worse choices than just picking the largest buffer. + // TODO(b/152421603): Investigate. + auto size_compare = [](const auto& x, const auto& y) { + return x.size < y.size; + }; + auto& compare = options.default_cross_program_prefetch_heuristic && + options.buffer_interval_compare + ? 
*options.buffer_interval_compare + : size_compare; + + auto best_candidate = absl::c_max_element(candidates, compare); + if (best_candidate == candidates.end()) { + return absl::nullopt; + } + return *best_candidate; +} + } // namespace /*static*/ StatusOr> @@ -64,12 +163,16 @@ float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit( while_nest_multiplier = it->second; } else { while_nest_multiplier = tensorflow::MathUtil::IPow( - kWhileExecutionCount, CalculateWhileLoopNestLevel(&instruction)); + kWhileExecutionCount, + CalculateComputationNestLevel(&instruction, + /*while_only=*/true)); cache->while_nest_multiplier[&instruction] = while_nest_multiplier; } } else { while_nest_multiplier = tensorflow::MathUtil::IPow( - kWhileExecutionCount, CalculateWhileLoopNestLevel(&instruction)); + kWhileExecutionCount, + CalculateComputationNestLevel(&instruction, + /*while_only=*/true)); } return (elapsed_time_due_to_memory - elapsed_time_due_to_alternate_mem) * while_nest_multiplier; @@ -119,18 +222,14 @@ float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness( } } - // Get performance slowdown in seconds of prefetching current BufferInterval - // causing to other BufferIntervals. - float alternate_mem_slowdown = - GetInstructionElapsedDueToMemorySlowdown(interval.size); - - // Divide by the size of the buffer to prioritize smaller buffers that will - // give the largest alternate memory benefit. - return (alternate_mem_benefit - alternate_mem_slowdown) / interval.size; + // Penalize larger buffers by dividing the benefit by the square root of the + // size. Empirically, we observed this resulted in better performance compared + // to dividing by the size. + return alternate_mem_benefit / std::sqrt(interval.size); } -int MemorySpaceAssignmentCostAnalysis::CalculateWhileLoopNestLevel( - const HloInstruction* instruction) const { +int MemorySpaceAssignmentCostAnalysis::CalculateComputationNestLevel( + const HloInstruction* instruction, bool while_only) const { int nest_level = 0; const HloComputation* computation = instruction->parent(); while (!computation->IsEntryComputation()) { @@ -138,7 +237,7 @@ int MemorySpaceAssignmentCostAnalysis::CalculateWhileLoopNestLevel( auto callsites = node.caller_callsites(); CHECK_EQ(callsites.size(), 1) << "The module is not flattened!"; auto callsite = callsites[0]; - if (callsite.instruction()->opcode() == HloOpcode::kWhile) { + if (!while_only || callsite.instruction()->opcode() == HloOpcode::kWhile) { ++nest_level; } computation = callsite.instruction()->parent(); @@ -284,6 +383,8 @@ CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker( float preferred_async_copy_to_overlap_ratio) : while_nest_level_( cost_analysis.hlo_live_range().instruction_schedule().size(), 0), + computation_nest_level_( + cost_analysis.hlo_live_range().instruction_schedule().size(), 0), cost_analysis_(cost_analysis), min_async_copy_to_overlap_ratio_(min_async_copy_to_overlap_ratio), max_async_copy_to_overlap_ratio_(max_async_copy_to_overlap_ratio), @@ -307,9 +408,12 @@ CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker( instructions_elapsed_time.resize(logical_time + 1, 0.0); while_nest_level_.resize(logical_time + 1, 0); } - int nest_level = cost_analysis_.CalculateWhileLoopNestLevel( - instruction_and_logical_time.first); - while_nest_level_[logical_time] = nest_level; + int while_nest_level = cost_analysis_.CalculateComputationNestLevel( + instruction_and_logical_time.first, /*while_only=*/true); + while_nest_level_[logical_time] = 
while_nest_level; + int computation_nest_level = cost_analysis_.CalculateComputationNestLevel( + instruction_and_logical_time.first, /*while_only=*/false); + computation_nest_level_[logical_time] = computation_nest_level; if (instruction->opcode() == HloOpcode::kWhile || instruction->opcode() == HloOpcode::kConditional) { continue; @@ -317,8 +421,8 @@ CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker( float elapsed_time = cost_analysis_.GetInstructionElapsed( *instruction_and_logical_time.first); instructions_elapsed_time[logical_time] = - elapsed_time * - tensorflow::MathUtil::IPow(kWhileExecutionCount, nest_level); + elapsed_time * tensorflow::MathUtil::IPow(kWhileExecutionCount, + while_nest_level); } // As an optimization, create a cumulative sum vector of elapsed time. float cumsum = 0.0; @@ -388,14 +492,14 @@ int64 CostAnalysisPrefetchIntervalPicker::LatestPrefetchStartTime( /*output_in_alternate_mem=*/false); inst_elapsed_reduction = elapsed_time - elapsed_time_in_alternate_mem; } - int end_nest_level = while_nest_level_[end_time]; + int end_nest_level = computation_nest_level_[end_time]; // Find the latest time we're allowed to start prefetching. float min_interval = min_async_copy_to_overlap_ratio_ * async_copy_elapsed; int latest_prefetch_time; for (latest_prefetch_time = end_time - 1; latest_prefetch_time >= start_time && - (while_nest_level_[latest_prefetch_time] != end_nest_level || + (computation_nest_level_[latest_prefetch_time] != end_nest_level || min_interval > GetLogicalIntervalElapsed(latest_prefetch_time, end_time) + inst_elapsed_reduction); @@ -416,13 +520,13 @@ int64 CostAnalysisPrefetchIntervalPicker::PreferredPrefetchStartTime( preferred_async_copy_to_overlap_ratio_ * async_copy_elapsed; float best_interval = GetLogicalIntervalElapsed(earliest_prefetch_start_time, prefetch_end_time); - int end_nest_level = while_nest_level_[prefetch_end_time]; + int end_nest_level = computation_nest_level_[prefetch_end_time]; for (int64 prefetch_start_time = earliest_prefetch_start_time + 1; prefetch_start_time <= latest_prefetch_start_time; ++prefetch_start_time) { float interval = GetLogicalIntervalElapsed(prefetch_start_time, prefetch_end_time); - if (while_nest_level_[prefetch_start_time] == end_nest_level && + if (computation_nest_level_[prefetch_start_time] == end_nest_level && std::abs(preferred_interval - interval) < std::abs(preferred_interval - best_interval)) { best_interval = interval; @@ -436,10 +540,11 @@ int64 CostAnalysisPrefetchIntervalPicker::LatestPrefetchEndTime( int64 original_prefetch_end_time, int64 proposed_prefetch_end_time) const { // Iterate towards the beginning until we find a suitable end time that is the // same while nest level as the original prefetch end time. 
- int64 original_nest_level = while_nest_level_[original_prefetch_end_time]; + int64 original_nest_level = + computation_nest_level_[original_prefetch_end_time]; int64 new_prefetch_end_time; for (new_prefetch_end_time = proposed_prefetch_end_time; - while_nest_level_[new_prefetch_end_time] != original_nest_level; + computation_nest_level_[new_prefetch_end_time] != original_nest_level; --new_prefetch_end_time) { } return new_prefetch_end_time; @@ -460,7 +565,7 @@ void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use, /*output_in_alternate_mem=*/false); inst_elapsed_reduction_ = elapsed_time - elapsed_time_in_alternate_mem; end_logical_time_ = end_time; - int end_nest_level = while_nest_level_[end_logical_time_]; + int end_nest_level = computation_nest_level_[end_logical_time_]; // Find the latest time we're allowed to start prefetching. float min_interval = min_async_copy_to_overlap_ratio_ * async_copy_elapsed_; @@ -472,7 +577,7 @@ void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use, max_overlap_multiplier_ * async_copy_elapsed_; for (earliest_prefetch_time_ = start_time; earliest_prefetch_time_ <= end_logical_time_ && - (while_nest_level_[earliest_prefetch_time_] != end_nest_level || + (computation_nest_level_[earliest_prefetch_time_] != end_nest_level || max_interval < GetLogicalIntervalElapsed(earliest_prefetch_time_, end_logical_time_)); ++earliest_prefetch_time_) { @@ -510,8 +615,8 @@ int64 CostAnalysisPrefetchIntervalPicker::Next() { if (using_increasing_prefetch_time_iterator_) { int64 prefetch_time = increasing_prefetch_time_iterator_++; while (increasing_prefetch_time_iterator_ <= latest_prefetch_time_ && - while_nest_level_[increasing_prefetch_time_iterator_] != - while_nest_level_[end_logical_time_]) { + computation_nest_level_[increasing_prefetch_time_iterator_] != + computation_nest_level_[end_logical_time_]) { ++increasing_prefetch_time_iterator_; } if (decreasing_prefetch_time_iterator_ >= earliest_prefetch_time_) { @@ -521,8 +626,8 @@ int64 CostAnalysisPrefetchIntervalPicker::Next() { } else { int64 prefetch_time = decreasing_prefetch_time_iterator_--; while (decreasing_prefetch_time_iterator_ >= earliest_prefetch_time_ && - while_nest_level_[decreasing_prefetch_time_iterator_] != - while_nest_level_[end_logical_time_]) { + computation_nest_level_[decreasing_prefetch_time_iterator_] != + computation_nest_level_[end_logical_time_]) { --decreasing_prefetch_time_iterator_; } if (increasing_prefetch_time_iterator_ <= latest_prefetch_time_) { @@ -566,11 +671,11 @@ float CostAnalysisPrefetchIntervalPicker::GetLogicalIntervalElapsed( // Since elapsed_time_cumsum_ is already weighed by the while loop nesting // level, normalize the elapsed time by dividing with the nesting factor of // the interval (start and end times). 
- int interval_nest_level = GetMinWhileNestLevel(start_time, end_time); + int interval_while_nest_level = GetMinWhileNestLevel(start_time, end_time); return (elapsed_time_cumsum_[end_time - 1] - elapsed_time_cumsum_[start_time]) / tensorflow::MathUtil::IPow(kWhileExecutionCount, - interval_nest_level); + interval_while_nest_level); } std::string CostAnalysisPrefetchIntervalPicker::ToDebugString() const { @@ -713,12 +818,13 @@ void AlternateMemoryBestFitHeap::CreateAllocationValues( } void AlternateMemoryBestFitHeap::FindAliases( - std::vector* allocation_values) const { + std::vector* allocation_values, + bool skip_values_with_no_uses) const { absl::flat_hash_map values_by_defining_inst; for (AllocationValue& value : *allocation_values) { // Skip the value if it doesn't have any uses. - if (value.uses().empty()) { + if (value.uses().empty() && skip_values_with_no_uses) { continue; } CHECK_EQ(values_by_defining_inst.count(value.defining_instruction()), 0); @@ -985,6 +1091,17 @@ void AlternateMemoryBestFitHeap::DumpDebugStringsIfEnabled() const { } HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { + if (options_.enable_cross_program_prefetch) { + absl::optional + prefetch_candidate = FindCrossProgramPrefetchCandidate( + alias_analysis_, hlo_live_range_, options_); + if (prefetch_candidate) { + HloModule* module = + prefetch_candidate->buffer->instruction()->GetModule(); + AllocateCrossProgramPrefetchBuffer(module, prefetch_candidate); + } + } + std::vector sorted_buffer_intervals = GetSortedBufferIntervals(); @@ -1036,6 +1153,12 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { continue; } + if (interval.size > available_heap_size()) { + VLOG(3) << "Skip " << interval.buffer->ToShortString() + << " because the buffer is larger than the heap size."; + continue; + } + auto colocated_intervals = GetSortedColocatedIntervals(interval); if (AreIntervalsReservedInAlternateMemory(colocated_intervals)) { @@ -1084,6 +1207,7 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { bool repacked = false; for (int retry_number = 0; retry_number < options_.max_retries; retry_number++) { + AddRequiredAssignmentsForColocatedIntervals(colocated_intervals); bool final_retry = (retry_number == options_.max_retries - 1); options_.prefetch_interval_picker->SetRetryNumber(retry_number); Result result = @@ -1094,7 +1218,8 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { (!final_retry && result_failed_because_of_async_copy(result))) { UncommitPendingChunks(absl::MakeSpan(allocation_values)); VLOG(2) << "Couldn't allocate. Retry number " << retry_number; - } else if (result_is(result, Result::kFailOutOfMemory) && + } else if ((result_is(result, Result::kFailOutOfMemory) || + options_.repack_after_every_allocation) && num_repacks_ < options_.max_repacks && !repacked) { UncommitPendingChunks(absl::MakeSpan(allocation_values)); ++num_repacks_; @@ -1128,10 +1253,9 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { return result_; } -void AlternateMemoryBestFitHeap::CreateAllocationValuesFromColocatedIntervals( +void AlternateMemoryBestFitHeap::AddRequiredAssignmentsForColocatedIntervals( absl::Span - colocated_intervals, - std::vector& allocation_values) { + colocated_intervals) { // TODO(berkin): For now, place the phi values due to conditionals in // default memory. 
for (const BufferInterval* colocated_interval : colocated_intervals) { @@ -1150,12 +1274,17 @@ void AlternateMemoryBestFitHeap::CreateAllocationValuesFromColocatedIntervals( } } } +} +void AlternateMemoryBestFitHeap::CreateAllocationValuesFromColocatedIntervals( + absl::Span + colocated_intervals, + std::vector& allocation_values) { // Create AllocationValues for all the colocated intervals. for (const auto& colocated_interval : colocated_intervals) { CreateAllocationValues(*colocated_interval, allocation_values); } - FindAliases(&allocation_values); + FindAliases(&allocation_values, /*skip_values_with_no_uses=*/true); } AlternateMemoryBestFitHeap::Result @@ -1166,7 +1295,7 @@ AlternateMemoryBestFitHeap::AllocateAllocationValues( // Data structure to contain the preferred offset for a given computation. // We ensure that the same offset will be allocated outside the while loop // as well as inside the while loop. - absl::flat_hash_map + absl::flat_hash_map preferred_offset_for_computation; Result result = Result::kSuccess; @@ -1174,7 +1303,7 @@ AlternateMemoryBestFitHeap::AllocateAllocationValues( int64 definition_time = instruction_schedule.at(allocation_value.defining_instruction()); - absl::optional preferred_offset; + AliasedOffset* preferred_offset = nullptr; auto preferred_offset_it = preferred_offset_for_computation.find(allocation_value.computation()); if (preferred_offset_it != preferred_offset_for_computation.end()) { @@ -1273,10 +1402,13 @@ AlternateMemoryBestFitHeap::AllocateAllocationValues( } } - // Bitcasts don't define buffers and don't directly consume buffers. Skip - // allocating buffers for bitcast uses. The uses that feed from bitcasts - // will be handled specially. - if (hlo_use.instruction->opcode() != HloOpcode::kBitcast) { + // Bitcasts don't define buffers and don't directly consume buffers. Skip + // allocating buffers for bitcast uses (unless they are the root + // instruction). The uses that feed from bitcasts will be handled + // specially. 
+ if (hlo_use.instruction->opcode() != HloOpcode::kBitcast || + hlo_use.instruction == + hlo_use.instruction->parent()->root_instruction()) { AllocationRequest request; // Rarely, (e.g., when conditional true and false parameters are the // same), definition time can be the time of the conditional and use @@ -1319,7 +1451,7 @@ AlternateMemoryBestFitHeap::AllocateAllocationValues( if (hlo_use.instruction->opcode() == HloOpcode::kWhile && aliased_allocation->memory_space() == MemorySpace::kAlternate) { preferred_offset_for_computation[hlo_use.instruction->while_body()] = - aliased_allocation->chunk().offset; + GetAliasedOffset(*aliased_allocation); } } } @@ -1360,6 +1492,28 @@ absl::optional AsynchronousCopyOrdering::ViolatesOrdering( return absl::nullopt; } +AlternateMemoryBestFitHeap::AliasedOffset* +AlternateMemoryBestFitHeap::GetAliasedOffset( + const MemorySpaceAssignment::Allocation& allocation) { + auto aliased_offset_it = aliased_offset_map_.find(&allocation); + CHECK(aliased_offset_it != aliased_offset_map_.end()); + return aliased_offset_it->second; +} + +void AlternateMemoryBestFitHeap::CreateOrAddToAliasedOffset( + const MemorySpaceAssignment::Allocation& allocation, + AlternateMemoryBestFitHeap::AliasedOffset* aliased_offset) { + CHECK(allocation.memory_space() == MemorySpace::kAlternate); + CHECK(!aliased_offset_map_.contains(&allocation)); + if (!aliased_offset) { + aliased_offsets_.push_back({allocation.chunk().offset}); + aliased_offset = &aliased_offsets_.back(); + } + CHECK_EQ(allocation.chunk().offset, aliased_offset->offset); + CHECK(aliased_offset->allocations.insert(&allocation).second); + aliased_offset_map_[&allocation] = aliased_offset; +} + /*static*/ MemorySpaceAssignment::Allocation* AlternateMemoryBestFitHeap::GetLiveAllocationAt( const MemorySpaceAssignment::AllocationSequence& allocations, int64 time) { @@ -1435,10 +1589,11 @@ void AlternateMemoryBestFitHeap::AllocateCrossProgramPrefetchBuffer( AddAsyncCopy(*allocations.back(), MemorySpace::kAlternate, chunk_candidate.chunk, prefetch_candidate->start, cross_program_prefetch_end_time, latest_prefetch_time, - &allocations, + &allocations, /*aliased_offset=*/nullptr, /*is_cross_program_prefetch=*/true); absl::c_for_each(uses, [&](auto& use) { allocations.back()->AddUse(use); }); - int64 cross_program_prefetch_offset = allocations.back()->chunk().offset; + AliasedOffset* cross_program_prefetch_offset = + GetAliasedOffset(*allocations.back()); if (free_buffer) { VLOG(2) << "Adding an end-of-program prefetch for freed " @@ -1446,8 +1601,10 @@ void AlternateMemoryBestFitHeap::AllocateCrossProgramPrefetchBuffer( AddAsyncCopy(*allocations.front(), MemorySpace::kAlternate, chunk_candidate.chunk, end_of_program_prefetch_start_time, end_of_program_prefetch_end_time, - end_of_program_prefetch_end_time, &allocations); - CHECK_EQ(cross_program_prefetch_offset, allocations.back()->chunk().offset); + end_of_program_prefetch_end_time, &allocations, + cross_program_prefetch_offset); + CHECK_EQ(cross_program_prefetch_offset->offset, + allocations.back()->chunk().offset); } for (auto& allocation : allocations) { @@ -1477,7 +1634,7 @@ void AlternateMemoryBestFitHeap::AllocateCrossProgramPrefetchBuffer( ClearPendingChunks(); } -absl::optional +absl::optional AlternateMemoryBestFitHeap::RequiredMemoryAssignmentAt(const HloValue* buffer, int64 time) const { auto required_assignment_it = required_assignments_.find(buffer); @@ -1495,7 +1652,7 @@ AlternateMemoryBestFitHeap::RequiredMemoryAssignmentAt(const HloValue* buffer, return 
required_assignment_at_time; } -absl::optional +absl::optional AlternateMemoryBestFitHeap::AliasedRequiredAssignmentForUse( const AllocationValue::Use& use) const { absl::optional required_assignment; @@ -1521,26 +1678,26 @@ AlternateMemoryBestFitHeap::AliasedRequiredAssignmentForUse( void AlternateMemoryBestFitHeap::AddAliasedRequiredAssignment( const HloInstruction* instruction, ShapeIndex index, const MemorySpaceAssignment::Allocation* aliased_allocation) { - absl::optional chunk; + AliasedOffset* offset = nullptr; if (aliased_allocation->memory_space() == MemorySpace::kAlternate) { - chunk = aliased_allocation->chunk(); + offset = GetAliasedOffset(*aliased_allocation); } AddRequiredAssignment(instruction, index, aliased_allocation->memory_space(), - chunk); + offset); } void AlternateMemoryBestFitHeap::AddRequiredAssignment( const HloValue* value, const HloInstruction* instruction, MemorySpaceAssignment::MemorySpace memory_space, int64 time, - absl::optional chunk) { + AliasedOffset* offset) { // Check for existing required assignment at this time and make sure it is the // same as this if there is one. auto existing_required_assignment = RequiredMemoryAssignmentAt(value, time); if (existing_required_assignment) { CHECK(memory_space == existing_required_assignment->memory_space) << "inst = " << instruction->ToString() << " at " << time; - CHECK((!chunk && !existing_required_assignment->chunk) || - chunk->offset == existing_required_assignment->chunk->offset); + CHECK((!offset && !existing_required_assignment->offset) || + offset == existing_required_assignment->offset); VLOG(3) << "Not adding required assignment because there is one already: " << value->ToShortString() << " at " << time << " at " << (memory_space == MemorySpace::kDefault ? "def" : "alt"); @@ -1548,7 +1705,7 @@ void AlternateMemoryBestFitHeap::AddRequiredAssignment( VLOG(3) << "Adding required assignment: " << value->ToShortString() << " at " << time << " at " << (memory_space == MemorySpace::kDefault ? "def" : "alt"); - RequiredMemoryAssignment required_assignment{memory_space, time, chunk}; + RequiredMemoryAssignment required_assignment{memory_space, time, offset}; required_assignments_[value].push_back(required_assignment); pending_required_assignments_.push_back({value, required_assignment}); } @@ -1556,13 +1713,13 @@ void AlternateMemoryBestFitHeap::AddRequiredAssignment( void AlternateMemoryBestFitHeap::AddRequiredAssignment( const HloInstruction* instruction, ShapeIndex index, - MemorySpace memory_space, absl::optional chunk) { + MemorySpace memory_space, AliasedOffset* offset) { const HloValue* value = &alias_analysis_.dataflow_analysis().GetUniqueValueAt(instruction, index); int64 instruction_time = hlo_live_range_.instruction_schedule().at(instruction); AddRequiredAssignment(value, instruction, memory_space, instruction_time, - chunk); + offset); } void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() { @@ -1711,8 +1868,8 @@ void AlternateMemoryBestFitHeap::UncommitPendingChunks( ? "def" : "alt") << " time = " << required_assignment.time << " off = " - << (required_assignment.chunk ? required_assignment.chunk->offset - : -1); + << (required_assignment.offset ? 
required_assignment.offset->offset + : -1); for (auto it = required_assignment_vector.begin(); it != required_assignment_vector.end(); ++it) { if (*it == value_and_required_assignment.second) { @@ -1726,7 +1883,8 @@ void AlternateMemoryBestFitHeap::UncommitPendingChunks( void AlternateMemoryBestFitHeap::FinalizeAllocations( absl::Span allocation_values) { - absl::flat_hash_map> + absl::flat_hash_map> colocation_map; for (AllocationValue& allocation_value : allocation_values) { for (auto& allocation : *allocation_value.allocation_sequence()) { @@ -1736,12 +1894,12 @@ void AlternateMemoryBestFitHeap::FinalizeAllocations( MemorySpaceAssignment::Allocation* inserted_allocation = allocations_->back().get(); if (inserted_allocation->memory_space() == MemorySpace::kAlternate) { - colocation_map[inserted_allocation->chunk().offset].push_back( + colocation_map[GetAliasedOffset(*inserted_allocation)].push_back( inserted_allocation); } } } - // Assume allocations that received the same offset need to be colocated. + // The allocations that have the same AliasedOffset need to be colocated. // Export these to repack_allocation_blocks_ so that we can repack them to // reduce fragmentation. for (auto& colocation : colocation_map) { @@ -1768,6 +1926,8 @@ void AlternateMemoryBestFitHeap::ClearPendingChunks() { pending_chunks_.clear(); pending_async_copies_.clear(); pending_required_assignments_.clear(); + aliased_offset_map_.clear(); + aliased_offsets_.clear(); } void AlternateMemoryBestFitHeap::AddToPendingChunks( @@ -1843,15 +2003,25 @@ AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::AllocateSegment( const auto& prev_allocation = allocation_sequence->back(); CHECK(prev_allocation->memory_space() == required_assignment_at_start->memory_space); - CHECK_EQ(prev_allocation->chunk().offset, - required_assignment_at_start->chunk->offset); + CHECK_EQ(GetAliasedOffset(*prev_allocation), + required_assignment_at_start->offset); prev_allocation->Extend(request.start_time); } else { + absl::optional aliased_chunk = absl::nullopt; + if (required_assignment_at_start->memory_space == + MemorySpace::kAlternate) { + aliased_chunk = + Chunk{required_assignment_at_start->offset->offset, request.size}; + } allocation_sequence->push_back( absl::make_unique( defining_position, required_assignment_at_start->memory_space, - required_assignment_at_start->chunk, request.start_time, - request.start_time)); + aliased_chunk, request.start_time, request.start_time)); + if (required_assignment_at_start->memory_space == + MemorySpace::kAlternate) { + CreateOrAddToAliasedOffset(*allocation_sequence->back(), + required_assignment_at_start->offset); + } } } @@ -1935,7 +2105,7 @@ void AlternateMemoryBestFitHeap::AddAsyncCopy( MemorySpace memory_space, absl::optional chunk, int64 start_time, int64 end_time, int64 copy_done_schedule_before_time, MemorySpaceAssignment::AllocationSequence* allocations, - bool is_cross_program_prefetch) { + AliasedOffset* aliased_offset, bool is_cross_program_prefetch) { VLOG(3) << "Copy to " << (memory_space == MemorySpaceAssignment::MemorySpace::kDefault ? 
"default" @@ -1957,6 +2127,7 @@ void AlternateMemoryBestFitHeap::AddAsyncCopy( prefetch_interval_tree_.Add(start_time, copy_done_schedule_before_time, kDummyChunk); async_copy_ordering_.AddCopy(pending_async_copies_.back()); + CreateOrAddToAliasedOffset(*allocations->back(), aliased_offset); } else { eviction_interval_tree_.Add(start_time, copy_done_schedule_before_time, kDummyChunk); @@ -2033,9 +2204,9 @@ AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy( alternate_mem_interval.start = request.start_time; // Prefer the offset that was previously used for the previous allocation. - absl::optional preferred_offset; + AliasedOffset* preferred_offset = nullptr; if (prev_allocation != nullptr) { - preferred_offset = prev_allocation->chunk().offset; + preferred_offset = GetAliasedOffset(*prev_allocation); // If there is a previous allocation, set the start time one after the end // of the previous allocation's end. alternate_mem_interval.start = prev_allocation->end_time() + 1; @@ -2045,13 +2216,13 @@ AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy( // Sanity check that if there is a preferred offset provided in the request, // it matches with the previous allocation. CHECK(!preferred_offset || request.preferred_offset == preferred_offset) - << "preferred_offset = " << *preferred_offset - << ", request.preferred_offset = " << *request.preferred_offset; + << "preferred_offset = " << preferred_offset->offset + << ", request.preferred_offset = " << request.preferred_offset->offset; preferred_offset = request.preferred_offset; } VLOG(3) << "We can eliminate copy to alternate memory. Preferred offset = " - << (preferred_offset ? *preferred_offset : -1); + << (preferred_offset ? preferred_offset->offset : -1); // In case there are additional uses after this use, we rely on the last use // time to try to reserve a chunk in the heap simulator. 
This is to prevent // the following scenario: @@ -2099,6 +2270,9 @@ AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy( absl::make_unique( defining_position, MemorySpace::kAlternate, chunk_candidate->chunk, request.start_time, request.end_time)); + CreateOrAddToAliasedOffset( + *request.allocation_value->allocation_sequence()->back(), + preferred_offset); } request.allocation_value->allocation_sequence()->back()->AddUse( request.use->hlo_use); @@ -2162,7 +2336,8 @@ AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::Evict( AddAsyncCopy(*prev_allocation, MemorySpace::kDefault, /*chunk=*/absl::nullopt, eviction_start_time, prev_allocation->end_time(), eviction_end_time, - request.allocation_value->allocation_sequence()); + request.allocation_value->allocation_sequence(), + /*aliased_offset=*/nullptr); } else { if (eviction_violates_outstanding_copies) { VLOG(3) << "This violates the maximum async copies."; @@ -2180,7 +2355,8 @@ AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::Evict( VLOG(3) << "Eviction successful."; AddAsyncCopy(*prev_allocation, MemorySpace::kDefault, /*chunk=*/absl::nullopt, time, time + 1, time + 1, - request.allocation_value->allocation_sequence()); + request.allocation_value->allocation_sequence(), + /*aliased_offset=*/nullptr); eviction_scheduled = true; break; } @@ -2332,7 +2508,8 @@ AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::Prefetch( AddAsyncCopy(prev_allocation_in_default_mem, MemorySpace::kAlternate, chunk_candidate->chunk, alternate_mem_interval.start, request.end_time, prefetch_end_time, - request.allocation_value->allocation_sequence()); + request.allocation_value->allocation_sequence(), + request.preferred_offset); request.allocation_value->allocation_sequence()->back()->AddUse( request.use->hlo_use); @@ -2351,7 +2528,7 @@ AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::Prefetch( absl::optional AlternateMemoryBestFitHeap::FindBestChunkCandidate( - const AllocationRequest& request, absl::optional preferred_offset, + const AllocationRequest& request, const AliasedOffset* preferred_offset, BufferInterval* alternate_mem_interval) const { int64 end_time = request.end_time; if (!preferred_offset) { @@ -2397,8 +2574,8 @@ AlternateMemoryBestFitHeap::FindBestChunkCandidate( // only. 
alternate_mem_interval->end = end_time; ChunkCandidate chunk_candidate = - FindChunkCandidate(*alternate_mem_interval, *preferred_offset); - if (chunk_candidate.chunk.offset == *preferred_offset) { + FindChunkCandidate(*alternate_mem_interval, preferred_offset->offset); + if (chunk_candidate.chunk.offset == preferred_offset->offset) { return chunk_candidate; } return absl::nullopt; @@ -2457,107 +2634,6 @@ MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare( }; } -namespace { - -bool LooksLikeAnActivation(const HloInstruction* inst) { - for (HloInstruction* user : inst->users()) { - switch (user->opcode()) { - case HloOpcode::kConvolution: - case HloOpcode::kDot: - if (user->operand(0) == inst) { - return true; - } - break; - case HloOpcode::kGather: - if (user->operand(1) == inst) { - return true; - } - break; - case HloOpcode::kFusion: - for (int i = 0; i < user->operand_count(); ++i) { - if (user->operand(i) == inst && - LooksLikeAnActivation(user->fused_parameter(i))) { - return true; - } - } - break; - case HloOpcode::kBitcast: - return LooksLikeAnActivation(user); - default: - return true; - } - } - return false; -} - -bool IsCrossProgramPrefetchCandidate( - const HloValue& value, const MemorySpaceAssignment::Options& options) { - return value.instruction()->parent() == - value.instruction()->GetModule()->entry_computation() && - value.instruction()->opcode() == HloOpcode::kParameter && - (!value.shape().has_layout() || - value.shape().layout().memory_space() != - options.alternate_memory_space) && - value.index().size() == 1 && value.shape().IsArray() && - !value.uses().empty() && - options.size_fn(value) <= options.max_size_in_bytes && - absl::c_all_of(value.uses(), [&](const HloUse& use) { - const HloInstruction* inst = - use.instruction->operand(use.operand_number); - - // Skip the LooksLikeAnActivation test since we're testing the - // parent GTE and its children below. - if (inst->opcode() == HloOpcode::kBitcast && - inst->operand(0)->opcode() == HloOpcode::kGetTupleElement && - inst->operand(0)->operand(0)->opcode() == - HloOpcode::kParameter) { - return true; - } - - return inst->opcode() == HloOpcode::kGetTupleElement && - !LooksLikeAnActivation(inst); - }); -} - -absl::optional -FindCrossProgramPrefetchCandidate( - const HloAliasAnalysis& alias_analysis, const HloLiveRange& hlo_live_range, - const MemorySpaceAssignment::Options& options) { - std::vector candidates; - for (const HloBuffer& buffer : alias_analysis.buffers()) { - CHECK_GE(buffer.values().size(), 1); - const HloValue* value = buffer.values().at(0); - if (IsCrossProgramPrefetchCandidate(*value, options)) { - MemorySpaceAssignment::BufferInterval interval; - interval.buffer = value; - interval.size = options.size_fn(*value); - interval.start = 0; - interval.end = hlo_live_range.schedule_end_time(); - interval.need_allocation = true; - interval.colocations = {++buffer.values().begin(), buffer.values().end()}; - candidates.emplace_back(interval); - } - } - - // The buffer_interval_compare ought to do a good job picking the most - // appropriate buffer to cross program prefetch, but empirically, it makes - // worse choices than just picking the largest buffer. - // TODO(b/152421603): Investigate. - auto size_compare = [](const auto& x, const auto& y) { - return x.size < y.size; - }; - auto& compare = options.default_cross_program_prefetch_heuristic && - options.buffer_interval_compare - ? 
*options.buffer_interval_compare - : size_compare; - - auto best_candidate = absl::c_max_element(candidates, compare); - if (best_candidate == candidates.end()) { - return absl::nullopt; - } - return *best_candidate; -} -} // namespace /*static*/ StatusOr> MemorySpaceAssignment::Run(HloModule* module, @@ -2608,13 +2684,6 @@ Status MemorySpaceAssignment::FindAllocationSequence( auto algorithm = absl::make_unique( &allocations_, options_, alias_analysis, hlo_live_range); - if (options_.enable_cross_program_prefetch) { - absl::optional - prefetch_candiate = FindCrossProgramPrefetchCandidate( - alias_analysis, hlo_live_range, options_); - algorithm->AllocateCrossProgramPrefetchBuffer(module_, prefetch_candiate); - } - HeapSimulator::Options heap_simulator_options; heap_simulator_options.may_reuse_operand_buffers = false; TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module_, @@ -2747,15 +2816,21 @@ HloInstruction* MemorySpaceAssignment::Allocation::AddGetTupleElements() { } std::string MemorySpaceAssignment::Allocation::ToString() const { - return absl::StrCat("Allocation in ", - memory_space_ == MemorySpace::kDefault ? "def" : "alt", - " defined at ", defining_position_.ToString()); + std::string memory_space_str = "def"; + if (memory_space_ == MemorySpace::kAlternate) { + memory_space_str = absl::StrCat("alt (off: ", chunk_->offset, ")"); + } + return absl::StrCat("Allocation in ", memory_space_str, " defined at ", + defining_position_.ToString()); } std::string MemorySpaceAssignment::CopyAllocation::ToString() const { - return absl::StrCat("Copy Allocation in ", - memory_space_ == MemorySpace::kDefault ? "def" : "alt", - " from ", prev_allocation_.ToString()); + std::string memory_space_str = "def"; + if (memory_space_ == MemorySpace::kAlternate) { + memory_space_str = absl::StrCat("alt (off: ", chunk_->offset, ")"); + } + return absl::StrCat("Copy Allocation in ", memory_space_str, " from ", + prev_allocation_.ToString()); } Status MemorySpaceAssignment::CopyAllocation::Process( @@ -3285,6 +3360,7 @@ Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() { last_use_instruction, parameter_time, last_use_time, absl::StrCat(indent_string, " "))); } else { + last_use_time = std::min(last_use_time, end_time); TF_RETURN_IF_ERROR(add_allocation_and_verify( parameter_time, last_use_time, chunk, value)); } @@ -3303,12 +3379,13 @@ Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() { TF_RETURN_IF_ERROR(split_conditional_buffer( last_use_instruction, time_bound.start, time_bound.end, " ")); } else if (!value->uses().empty()) { + last_use_time = std::min(last_use_time, time_bound.end); VLOG(3) << " buffer: " << buffer.ToString() << " value: " << value->ToShortString() << ": (" - << time_bound.start << ", " << time_bound.end + << time_bound.start << ", " << last_use_time << ") off: " << chunk.offset << ", size: " << chunk.size; TF_RETURN_IF_ERROR(add_allocation_and_verify( - time_bound.start, time_bound.end, chunk, value)); + time_bound.start, last_use_time, chunk, value)); } } } diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h index 04737663424..cb459c68be1 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.h +++ b/tensorflow/compiler/xla/service/memory_space_assignment.h @@ -149,9 +149,11 @@ class MemorySpaceAssignmentCostAnalysis { int64 GetScheduleEndTime() const; - // Returns the number of nested while loop levels this instruction resides in. 
- // 0 means it is not in a while loop. - int CalculateWhileLoopNestLevel(const HloInstruction* instruction) const; + // Returns the number of nested computation levels this instruction resides + // in. If while_only is true, it returns the while loop nest level and 0 + // means the instruction is not in a while loop. + int CalculateComputationNestLevel(const HloInstruction* instruction, + bool while_only) const; const HloLiveRange& hlo_live_range() const { return *hlo_live_range_; } @@ -360,6 +362,7 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker { // (in cumulative sum) and while nesting level. std::vector elapsed_time_cumsum_; std::vector while_nest_level_; + std::vector computation_nest_level_; // Maintain the index of the most recent (before this instruction) nest level // change in order to efficiently determine the minimum nest level in an // interval. @@ -376,7 +379,7 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker { int64 end_logical_time_; int64 earliest_prefetch_time_; int64 latest_prefetch_time_; - bool using_increasing_prefetch_time_iterator_; + bool using_increasing_prefetch_time_iterator_ = true; int64 increasing_prefetch_time_iterator_; int64 decreasing_prefetch_time_iterator_; }; @@ -459,6 +462,9 @@ class MemorySpaceAssignment { // max_repacks is greater than 0. MemorySpaceAssignmentRepacker* repacker = nullptr; + // This is only useful for testing, repack after every allocation. + bool repack_after_every_allocation = false; + // If true, tries allocating buffers across (e.g., before and inside a while // loop body) sequential calls (kWhile, kCall, and kConditional). bool allocate_across_sequential_calls = false; @@ -728,6 +734,16 @@ class MemorySpaceAssignment { // All the positions where this use aliases with. The aliased positions // must get the same allocation. std::vector aliases; + + bool operator==(const Use& other) const { + return hlo_use == other.hlo_use && time == other.time && + aliases == other.aliases; + } + + template + friend H AbslHashValue(H h, const Use& s) { + return H::combine(std::move(h), s.hlo_use, s.time, s.aliases); + } }; AllocationValue(const HloValue* value, const HloPosition& position, @@ -823,6 +839,8 @@ class MemorySpaceAssignment { AllocationSequence allocations_; + HloModule* module() { return module_; } + private: // Process calls Process methods of the allocations after the allocations have // been finalized. @@ -871,29 +889,6 @@ class MemorySpaceAssignment { absl::flat_hash_map> schedule_before_; }; -// This struct contains mandatory memory assignments at a given time. E.g., an -// input's required memory assignment time would correspond to the definition -// time of the parameter instruction, and an output's time would correspond to -// the time of last use. -struct RequiredMemoryAssignment { - MemorySpaceAssignment::MemorySpace memory_space; - int64 time; - absl::optional chunk; - - bool equals_ignoring_time(const RequiredMemoryAssignment& other) const { - return memory_space == other.memory_space && chunk == other.chunk; - } - - bool operator==(const RequiredMemoryAssignment& other) const { - return memory_space == other.memory_space && time == other.time && - chunk == other.chunk; - } - - bool operator!=(const RequiredMemoryAssignment& other) const { - return !(*this == other); - } -}; - // A struct representing an asynchronous copy with its logical start and end // time and its destination memory space. 
struct AsynchronousCopy { @@ -972,6 +967,38 @@ class AlternateMemoryBestFitHeap HeapSimulator::Result Finish() override; + protected: + // Given a buffer interval, returns the colocated intervals. Unlike the + // similar GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations, it + // returns the colocated intervals sorted by scheduled time. + std::vector GetSortedColocatedIntervals( + const BufferInterval& interval) const; + + // Given a BufferInterval, creates AllocationValue objects and corresponding + // AllocationSequences and appends them into allocation_sequence_list_. + void CreateAllocationValues( + const BufferInterval& buffer_interval, + std::vector& allocation_values) const; + + // Given colocated intervals, populates allocation_values with the + // corresponding AllocationValue objects. + void CreateAllocationValuesFromColocatedIntervals( + absl::Span + colocated_intervals, + std::vector& allocation_values); + + // Go through all the uses in the AllocationValues and find the aliasing + // positions. + void FindAliases(std::vector* allocation_values, + bool skip_values_with_no_uses) const; + + MemorySpaceAssignment::AllocationSequence* allocations() { + return allocations_; + } + const MemorySpaceAssignment::Options& options() { return options_; } + const HloAliasAnalysis& alias_analysis() { return alias_analysis_; } + const HloLiveRange& hlo_live_range() { return hlo_live_range_; } + private: // We inherit AllocationBlock struct to attach the Allocation information to // make importing repacked offsets easier. @@ -980,6 +1007,13 @@ class AlternateMemoryBestFitHeap MemorySpaceAssignment::Allocation* allocation; }; + // A data structure we use to associate Allocation objects that are aliased + // and must get the same offset. + struct AliasedOffset { + int64 offset; + absl::flat_hash_set allocations; + }; + // An allocation request for a use segment. A use segment is the time segment // between the definition and the first use, and the time segment between the // uses of a buffer. For example, the time between the definition and Use1, is @@ -1007,11 +1041,34 @@ class AlternateMemoryBestFitHeap int64 size; bool allow_no_copy_alternate_mem_allocation; absl::optional earliest_prefetch_time; - absl::optional preferred_offset; + AliasedOffset* preferred_offset; const MemorySpaceAssignment::AllocationValue::Use* use; MemorySpaceAssignment::AllocationValue* allocation_value; }; + // This struct contains mandatory memory assignments at a given time. E.g., an + // input's required memory assignment time would correspond to the definition + // time of the parameter instruction, and an output's time would correspond to + // the time of last use. + struct RequiredMemoryAssignment { + MemorySpaceAssignment::MemorySpace memory_space; + int64 time; + AliasedOffset* offset; + + bool equals_ignoring_time(const RequiredMemoryAssignment& other) const { + return memory_space == other.memory_space && offset == other.offset; + } + + bool operator==(const RequiredMemoryAssignment& other) const { + return memory_space == other.memory_space && time == other.time && + offset == other.offset; + } + + bool operator!=(const RequiredMemoryAssignment& other) const { + return !(*this == other); + } + }; + // Result of an allocation, prefetch, eviction etc. request. The result is // either kSuccess or a bitwise OR of one or more failures. The values are // unique powers of two. 
To check if a result contains a particular failure, @@ -1068,6 +1125,17 @@ class AlternateMemoryBestFitHeap result_is(result, Result::kFailViolatesAsyncCopyOrdering); } + // Returns the AliasedOffset object associated with the allocation. + AliasedOffset* GetAliasedOffset( + const MemorySpaceAssignment::Allocation& allocation); + + // If aliased_offset is non-null, this method adds the allocation to + // aliased_offset. Otherwise, it creates a new AliasedOffset object and adds + // the allocation to this new AliasedOffset. + void CreateOrAddToAliasedOffset( + const MemorySpaceAssignment::Allocation& allocation, + AliasedOffset* aliased_offset); + // Given an allocation sequence, returns the live allocation at time with a // preference towards allocations in alternate memory. Returns nullptr if no // allocation is alive at that time. @@ -1078,18 +1146,6 @@ class AlternateMemoryBestFitHeap bool IsUseAllowedInAlternateMemory(const AllocationValue& value, const HloUse& use) const; - // Given a BufferInterval, creates AllocationValue objects and corresponding - // AllocationSequences and appends them into allocation_sequence_list_. - void CreateAllocationValues( - const BufferInterval& buffer_interval, - std::vector& allocation_values) const; - - // Given colocated intervals, populates allocation_values with the - // corresponding AllocationValue objects. - void CreateAllocationValuesFromColocatedIntervals( - absl::Span colocated_intervals, - std::vector& allocation_values); - // Finds allocations for allocation values generated from colocated intervals. // All of the allocation values have a must-alias relationship with each // other. Returns either kSuccess if all of the sites could be placed in the @@ -1097,10 +1153,6 @@ class AlternateMemoryBestFitHeap Result AllocateAllocationValues( absl::Span allocation_values); - // Go through all the uses in the AllocationValues and find the aliasing - // positions. - void FindAliases(std::vector* allocation_values) const; - // Finds an allocation for an allocation request for a segment (see the // documentation for AllocationRequest above how a segment is defined). // @@ -1140,7 +1192,7 @@ class AlternateMemoryBestFitHeap // availability if no preferred offset is given, or at the preferred_offset if // it is given. absl::optional FindBestChunkCandidate( - const AllocationRequest& request, absl::optional preferred_offset, + const AllocationRequest& request, const AliasedOffset* preferred_offset, BufferInterval* alternate_mem_interval) const; // Returns the required assignment at a particular time, if available. @@ -1152,6 +1204,11 @@ class AlternateMemoryBestFitHeap absl::optional AliasedRequiredAssignmentForUse( const AllocationValue::Use& use) const; + // Goes through the colocated intervals and adds any required assignment. + void AddRequiredAssignmentsForColocatedIntervals( + absl::Span + colocated_intervals); + // Propagates aliased required assignment for a given position. void AddAliasedRequiredAssignment( const HloInstruction* instruction, ShapeIndex index, @@ -1162,10 +1219,10 @@ class AlternateMemoryBestFitHeap void AddRequiredAssignment(const HloValue* value, const HloInstruction* instruction, MemorySpace memory_space, int64 time, - absl::optional chunk = absl::nullopt); + AliasedOffset* offset = nullptr); void AddRequiredAssignment(const HloInstruction* instruction, ShapeIndex index, MemorySpace memory_space, - absl::optional chunk = absl::nullopt); + AliasedOffset* offset = nullptr); // Adds input and outputs as required assignments. 
void AddInputAndOutputRequiredAssignments(); @@ -1176,12 +1233,6 @@ class AlternateMemoryBestFitHeap bool AreIntervalsReservedInAlternateMemory( absl::Span colocated_intervals) const; - // Given a buffer interval, returns the colocated intervals. Unlike the - // similar GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations, it - // returns the colocated intervals sorted by scheduled time. - std::vector GetSortedColocatedIntervals( - const BufferInterval& interval) const; - // Since the allocations are recorded to the AllocationSequence, we don't // maintain result_ in GlobalDecreasingSizeBestFitHeap. Override AddToChunkMap // to avoid unnecessarily adding the chunk to the chunk map. @@ -1216,6 +1267,7 @@ class AlternateMemoryBestFitHeap int64 start_time, int64 end_time, int64 copy_done_schedule_before_time, MemorySpaceAssignment::AllocationSequence* allocations, + AliasedOffset* aliased_offset, bool is_cross_program_prefetch = false); // This method is used for committing the chunk candidate but adding it to @@ -1284,6 +1336,11 @@ class AlternateMemoryBestFitHeap std::vector pending_async_copies_; std::vector> pending_required_assignments_; + // The data structure that contains AliasedOffset objects and Allocation to + // AliasedOffset map for efficient lookup. + std::list aliased_offsets_; + absl::flat_hash_map + aliased_offset_map_; // This map contains required memory assignments for HloValues (e.g., input // and outputs). absl::flat_hash_map> diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc index cc4f740bc25..187076abe8a 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc @@ -232,6 +232,24 @@ class MemorySpaceAssignmentTest : public HloTestBase, return copies; } + int64 GetAlternateMemoryOffset(const PresetAssignments& preset_assignments, + const HloInstruction* instruction, + const ShapeIndex& index = {}) const { + // Returns the offset of the assignment, -1 if it's not in the alternate + // memory. + const HloModule* module = instruction->parent()->parent(); + auto alias_analysis = HloAliasAnalysis::Run(module).ValueOrDie(); + HloBuffer& buffer = alias_analysis->GetUniqueBufferAt(instruction, index); + for (auto& pos_and_chunk : preset_assignments.chunks()) { + for (auto& value : buffer.values()) { + if (pos_and_chunk.first == value->defining_position()) { + return pos_and_chunk.second.offset; + } + } + } + return -1; + } + std::unique_ptr CreateEvictAndPrefetchModule() { HloComputation::Builder builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); @@ -4066,22 +4084,78 @@ TEST_P(MemorySpaceAssignmentTest, MoveCopyDoneEarlier) { find_schedule_index(cos->operand(0))); } +TEST_P(MemorySpaceAssignmentTest, BitcastRoot) { + // Tests against a bug where the root of entry computation is a bitcast + // instruction and it ends up getting an allocation in the alternate memory. 
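  // A bitcast aliases its operand, so an alternate-memory chunk assigned to
  // %get-tuple-element.7 would show up in the root's layout even though entry
  // computation outputs are required to stay in the default memory space;
  // hence the layout check on the root at the end of this test.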
+  absl::string_view hlo_string = R"(
+HloModule primitive_computation_gather.4, is_scheduled=true
+
+%while_body {
+  %param.1 = (s32[], f32[3,3,3]) parameter(0)
+  %get-tuple-element.32 = s32[] get-tuple-element(%param.1), index=0
+  %copy.6 = s32[] copy(s32[] %get-tuple-element.32)
+  %constant.8 = s32[] constant(1)
+  %add = s32[] add(s32[] %copy.6, s32[] %constant.8)
+  %get-tuple-element.35 = f32[3,3,3] get-tuple-element(%param.1), index=1
+  negate = f32[3,3,3] negate(get-tuple-element.35)
+  ROOT %tuple.10 = (s32[], f32[3,3,3]) tuple(s32[] %add, f32[3,3,3] negate)
+}
+
+%while_cond {
+  %param.0 = (s32[], f32[3,3,3]) parameter(0)
+  %get-tuple-element = s32[] get-tuple-element(%param.0), index=0
+  %constant.3 = s32[] constant(3)
+  ROOT %compare = pred[] compare(s32[] %get-tuple-element, s32[] %constant.3), direction=LT
+}
+
+ENTRY %primitive_computation_gather.4 (parameter.1: f32[3,10,5], parameter.2: s32[3,1]) -> f32[3,3,3] {
+  %constant.1 = s32[] constant(0)
+  %copy.11 = s32[] copy(s32[] %constant.1)
+  %constant = f32[] constant(0)
+  %broadcast = f32[3,3,3] broadcast(f32[] %constant), dimensions={}
+  %tuple.8 = (s32[], f32[3,10,5], s32[3,1], f32[3,3,3]) tuple(s32[] %copy.11, f32[3,3,3] %broadcast)
+  %while = (s32[], f32[3,3,3]) while(%tuple.8), condition=%while_cond, body=%while_body
+  %get-tuple-element.7 = f32[3,3,3] get-tuple-element(%while), index=1
+  ROOT %bitcast.1 = f32[3,3,3] bitcast(f32[3,3,3] %get-tuple-element.7)
+}
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  AssignMemorySpace(module.get());
+
+  const HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_TRUE(!root->shape().has_layout() ||
+              root->shape().layout().memory_space() == kDefaultMemorySpace);
+}
+
 // A mock MemorySpaceAssignmentRepacker class that accepts a map of
 // (start_time,offset) -> new_offset values. Using this map, the repacker
 // repacks the allocations to the new_offset.
class FakeMemorySpaceAssignmentRepacker : public MemorySpaceAssignmentRepacker { public: explicit FakeMemorySpaceAssignmentRepacker( - absl::flat_hash_map, int64>& repack_map) + absl::flat_hash_map, int64>& repack_map, + std::function)> check_fun = nullptr, + bool always_return_modified = false) : MemorySpaceAssignmentRepacker(/*max_size=*/128, /*alignment=*/8), - repack_map_(repack_map) {} + repack_map_(repack_map), + check_fun_(check_fun), + always_return_modified_(always_return_modified) {} StatusOr Repack(absl::Span allocations) override { bool modified = false; for (AllocationBlock* block : allocations) { - VLOG(1) << "Alloc time: [" << block->start_time << ", " << block->end_time - << "] size: " << block->size - << " init offset: " << block->initial_offset; + absl::flat_hash_set colocations; + std::string colocations_str; + for (const AllocationBlock* colocation : block->colocations) { + absl::StrAppend(&colocations_str, colocation->id, ", "); + colocations.insert(colocation->id); + } + VLOG(1) << "Alloc id: " << block->id << " time: [" << block->start_time + << ", " << block->end_time << "] size: " << block->size + << " init offset: " << block->initial_offset << " colocations: {" + << colocations_str << "}"; auto it = repack_map_.find({block->start_time, block->initial_offset}); if (it != repack_map_.end()) { modified = true; @@ -4090,8 +4164,6 @@ class FakeMemorySpaceAssignmentRepacker : public MemorySpaceAssignmentRepacker { block->offset = block->initial_offset; } for (AllocationBlock* colocation : block->colocations) { - VLOG(1) << " [" << colocation->start_time << ", " - << colocation->end_time << "]"; if (it != repack_map_.end()) { colocation->offset = it->second; } else { @@ -4099,13 +4171,18 @@ class FakeMemorySpaceAssignmentRepacker : public MemorySpaceAssignmentRepacker { } } } + if (check_fun_) { + check_fun_(allocations); + } - return modified; + return always_return_modified_ || modified; } private: // A map from (start_time, offset) to new_offset. absl::flat_hash_map, int64> repack_map_; + std::function)> check_fun_; + bool always_return_modified_; }; TEST_P(MemorySpaceAssignmentTest, Repack) { @@ -4229,6 +4306,181 @@ TEST_P(MemorySpaceAssignmentTest, Repack) { EXPECT_EQ(d->shape().layout().memory_space(), kAlternateMemorySpace); } +TEST_P(MemorySpaceAssignmentTest, RepackExportsAliasedOffsets) { + // This test is that we are correctly exporting aliased offsets for repacking. + // In this example, the buffer produced at HLO "a" will be allocated first, + // and will consist of four allocations: + // 1) a produced in the alternate memory (and then evicted to the default + // memory). 2) a prefetched to the alternate memory to be used by q and + // while HLOs. 3) a used within the while loop body. 4) the output of while + // HLO, used by u. + // + // Since a will be allocated first (the test is crafted to prioritize sine + // HLO), all four allocations should get the same (zero) offsets. However, + // while allocations 2, 3, and 4 need to be colocated with each other, + // allocation 1 doesn't need to be colocated with the other three. 
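  // The colocation constraint that the check_fun defined later in this test
  // verifies comes from FinalizeAllocations: finalized alternate-memory
  // allocations are grouped by their AliasedOffset, and each group is exported
  // to repack_allocation_blocks_ as mutually colocated AllocationBlocks.
  // Schematically (a sketch, not the verbatim implementation; the
  // finalized_allocations loop below is hypothetical):
  //
  //   absl::flat_hash_map<const AliasedOffset*,
  //                       std::vector<MemorySpaceAssignment::Allocation*>>
  //       colocation_map;
  //   for (MemorySpaceAssignment::Allocation* allocation :
  //        finalized_allocations) {  // hypothetical container name
  //     if (allocation->memory_space() == MemorySpace::kAlternate) {
  //       colocation_map[GetAliasedOffset(*allocation)].push_back(allocation);
  //     }
  //   }
  //   // One colocation group per map entry, so allocation 1 (which has its
  //   // own AliasedOffset) can be repacked independently of allocations 2-4.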
+ absl::string_view hlo_string = R"( + HloModule bug, is_scheduled=true + + while_condition { + param1 = (f32[2,4], f32[2,4]) parameter(0) + ROOT cond = pred[] constant(true) + } + + while_body { + param2 = (f32[2,4], f32[2,4]) parameter(0) + gte2 = f32[2,4] get-tuple-element(param2), index=0 + gte3 = f32[2,4] get-tuple-element(param2), index=1 + add = f32[2,4] add(gte2, gte3) + ROOT tuple2 = (f32[2,4], f32[2,4]) tuple(add, gte3) + } + + ENTRY Entry { + param0 = f32[2,4] parameter(0) + a = f32[2,4] sine(param0) + b = f32[2,4] negate(a) + c = f32[2,4] negate(b) + d = f32[2,4] negate(c) + e = f32[2,4] negate(d) + f = f32[2,4] negate(e) + g = f32[2,4] negate(f) + h = f32[2,4] negate(g) + i = f32[2,4] negate(h) + j = f32[2,4] negate(i) + k = f32[2,4] negate(j) + l = f32[2,4] negate(k) + m = f32[2,4] negate(l) + n = f32[2,4] negate(m) + o = f32[2,4] negate(n) + p = f32[2,4] negate(o) + q = f32[2,4] add(p, a) + tuple = (f32[2,4], f32[2,4]) tuple(q, a) + while = (f32[2,4], f32[2,4]) while(tuple), condition=while_condition, body=while_body + gte0 = f32[2,4] get-tuple-element(while), index=0 + gte1 = f32[2,4] get-tuple-element(while), index=1 + r = f32[2,4] negate(gte0) + s = f32[2,4] negate(r) + t = f32[2,4] negate(s) + constant = f32[] constant(0) + broadcast = f32[8,4] broadcast(constant), dimensions={} + cos = f32[8,4] cosine(broadcast) + u = f32[2,4] add(t, gte1) + v = f32[2,4] add(u, param0) + w = f32[8,4] negate(cos) + ROOT tuple3 = (f32[2,4], f32[8,4]) tuple(v, w) + } + )"; + + MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare = + [](const MemorySpaceAssignment::BufferInterval& a, + const MemorySpaceAssignment::BufferInterval& b) { + auto get_opcode_priority = [](const HloOpcode& opcode) { + switch (opcode) { + case HloOpcode::kSin: + return 0; + case HloOpcode::kCos: + return 1; + case HloOpcode::kTanh: + return 2; + default: + return 3; + } + }; + + return get_opcode_priority(a.buffer->defining_instruction()->opcode()) < + get_opcode_priority(b.buffer->defining_instruction()->opcode()); + }; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10); + absl::flat_hash_map, int64> repack_map; + + // Expect that of the four separate allocations for the "a" buffer, the first + // and the next three are in separate colocations. + auto check_fun = + [](absl::Span + allocations) { + EXPECT_TRUE(allocations.at(0)->colocations.size() == 1 || + allocations.at(0)->colocations.size() == 3); + EXPECT_EQ(allocations.at(1)->colocations.size(), 3); + EXPECT_EQ(allocations.at(2)->colocations.size(), 3); + EXPECT_TRUE(allocations.at(3)->colocations.size() == 1 || + allocations.at(3)->colocations.size() == 3); + }; + FakeMemorySpaceAssignmentRepacker repacker = + FakeMemorySpaceAssignmentRepacker(repack_map, check_fun); + MemorySpaceAssignment::Options options; + options.max_size_in_bytes = 128; + options.alignment_in_bytes = 8; + options.verify = true; + options.max_repacks = 1; + options.repacker = &repacker; + AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1, + buffer_interval_compare, &prefetch_interval_picker, + options); +} + +TEST_P(MemorySpaceAssignmentTest, + RepackShouldntEraseRequiredAssignmentForConditionalOutput) { + // This is a test case for b/171040271. 
Repacks erase the required assignments + // (since some required assignments are inserted conditionally based on + // allocation decisions), including the fact that conditional outputs are + // always required to get assignments in the default memory. After repacking, + // this required assignment was never added back, causing conditionals to get + // alternate-memory allocations. + absl::string_view hlo_string = R"( + HloModule CondAllocation, is_scheduled=true + + true_computation { + p0 = (f32[3]) parameter(0) + gte = f32[3] get-tuple-element(p0), index=0 + neg1 = f32[3] negate(gte) + ROOT tuple1 = (f32[3]) tuple(neg1) + } + + false_computation { + p0 = (f32[3]) parameter(0) + gte = f32[3] get-tuple-element(p0), index=0 + neg2 = f32[3] negate(gte) + ROOT tuple2 = (f32[3]) tuple(neg2) + } + + ENTRY entry { + p0 = f32[3] parameter(0) + p1 = pred[] parameter(1) + copy = f32[3] copy(p0) + tuple = (f32[3]) tuple(copy) + conditional = (f32[3]) conditional(p1, tuple, tuple), true_computation=true_computation, false_computation=false_computation + ROOT gte = f32[3] get-tuple-element(conditional), index=0 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + absl::flat_hash_map, int64> repack_map; + FakeMemorySpaceAssignmentRepacker repacker = + FakeMemorySpaceAssignmentRepacker(repack_map, nullptr, + /*always_return_modified=*/true); + MemorySpaceAssignment::Options options; + options.max_size_in_bytes = 128; + options.alignment_in_bytes = 8; + options.verify = true; + options.max_repacks = 10; + options.repacker = &repacker; + options.repack_after_every_allocation = true; + InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10); + AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1, + /*buffer_interval_compare=*/{}, &prefetch_interval_picker, + options); + // Make sure the root of the entry computation is in the default memory space. + EXPECT_EQ(module->entry_computation() + ->root_instruction() + ->shape() + .layout() + .memory_space(), + kDefaultMemorySpace); +} + TEST_P(MemorySpaceAssignmentTest, Determinism) { // Run memory space assignment a few times to make sure every time it compiles // to the same thing. @@ -4244,6 +4496,47 @@ TEST_P(MemorySpaceAssignmentTest, Determinism) { } } +TEST_P(MemorySpaceAssignmentTest, InPlaceOp) { + // Tests that in-place ops like DynamicUpdateSlice get the same allocation as + // its input. 
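  // The fusion root below is a dynamic-update-slice that updates its
  // parameter 0 in place, so the fusion's output buffer aliases the "negate"
  // operand. Memory space assignment must therefore give both the same
  // offset, which is what the GetAlternateMemoryOffset comparison at the end
  // of this test asserts.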
+ absl::string_view hlo_string = R"( +HloModule Module, is_scheduled=true + +fused_computation { + param0 = f32[2,3] parameter(0) + constant.1 = f32[] constant(0) + broadcast = f32[2,1] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + ROOT dynamic-update-slice.5 = f32[2,3] dynamic-update-slice(param0, broadcast, constant.3, constant.3) +} + +ENTRY main { + param = f32[2,3] parameter(0) + negate = f32[2,3] negate(param) + fusion = f32[2,3] fusion(negate), kind=kLoop, calls=fused_computation + ROOT add = f32[2,3] add(fusion, fusion) +} + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + auto preset_assignments = AssignMemorySpace(module.get()); + HloInstruction* negate_instruction = + module->entry_computation()->GetInstructionWithName("negate"); + int64 negate_offset = + GetAlternateMemoryOffset(*preset_assignments, negate_instruction); + HloInstruction* fusion_instruction = + module->entry_computation()->GetInstructionWithName("fusion"); + int64 fusion_offset = + GetAlternateMemoryOffset(*preset_assignments, fusion_instruction); + // We expect negate and fusion to get the same offsets. + EXPECT_EQ(negate_offset, fusion_offset); + const bool allocate_across_sequential_calls = GetParam(); + if (allocate_across_sequential_calls) { + EXPECT_NE(negate_offset, -1); + } +} + INSTANTIATE_TEST_SUITE_P(MemorySpaceAssignmentInstantiation, MemorySpaceAssignmentTest, ::testing::Values(false, true)); @@ -4918,5 +5211,75 @@ TEST_F(CostAnalysisPrefetchIntervalPickerTest, NestedWhile) { 4); } +TEST_F(CostAnalysisPrefetchIntervalPickerTest, ConsecutiveConditionals) { + // This is a test for b/170668492, where prefetching for consecutive + // conditionals can cause the prefetch to start in the conditional's + // computation. 
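  // The numbers in the trailing comments below appear to be the logical
  // schedule times. The computation_nest_level_ data added to
  // CostAnalysisPrefetchIntervalPicker is presumably what rules such start
  // times out: times 5-10 lie inside conditional0's called computations, i.e.
  // at a deeper computation nest level than the use in conditional1, so
  // LatestPrefetchStartTime must skip past them, which the EXPECT_LT against
  // 5 at the end of this test checks.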
+ absl::string_view hlo_string = R"( + HloModule bug, is_scheduled=true + + true_computation.0 { + p0 = (f32[3]{0}) parameter(0) // 5 + gte = f32[3]{0} get-tuple-element(p0), index=0 // 6 + ROOT neg1 = f32[3]{0} negate(gte) // 7 + } + + false_computation.0 { + p0 = (f32[3]{0}) parameter(0) // 8 + gte = f32[3]{0} get-tuple-element(p0), index=0 // 9 + ROOT neg2 = f32[3]{0} negate(gte) // 10 + } + + true_computation.1 { + p0 = (f32[3]{0}) parameter(0) // 12 + gte = f32[3]{0} get-tuple-element(p0), index=0 // 13 + ROOT neg1 = f32[3]{0} negate(gte) // 14 + } + + false_computation.1 { + p0 = (f32[3]{0}) parameter(0) // 15 + gte = f32[3]{0} get-tuple-element(p0), index=0 // 16 + ROOT neg2 = f32[3]{0} negate(gte) // 17 + } + + ENTRY entry { + p0 = f32[3]{0} parameter(0) // 0 + p1 = f32[3]{0} parameter(1) // 1 + p2 = pred[] parameter(2) // 2 + tuple0 = (f32[3]{0}) tuple(p0) // 3 + tuple1 = (f32[3]{0}) tuple(p1) // 4 + conditional0 = f32[3]{0} conditional(p2, tuple0, tuple0), true_computation=true_computation.0, false_computation=false_computation.0 // 11 + conditional1 = f32[3]{0} conditional(p2, tuple1, tuple1), true_computation=true_computation.1, false_computation=false_computation.1 // 18 + ROOT tuple2 = (f32[3]{0}, f32[3]{0}) tuple(conditional0, conditional1) // 19 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + HloCostAnalysis hlo_cost_analysis(ShapeSize); + TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis, + FakeMemorySpaceAssignmentCostAnalysis::Create( + hlo_cost_analysis, *module)); + CostAnalysisPrefetchIntervalPicker interval_picker( + *cost_analysis, + /*min_async_copy_to_overlap_ratio=*/1.0, + /*max_async_copy_to_overlap_ratio=*/12.0, + /*preferred_async_copy_to_overlap_ratio=*/2.0); + + LOG(INFO) << module->ToString(); + + HloInstruction* conditional1 = + module->entry_computation()->GetInstructionWithName("conditional1"); + const HloUse use{conditional1, /*operand_number=*/1, /*operand_index=*/{0}}; + const Shape& shape = + module->entry_computation()->parameter_instruction(0)->shape(); + + // Expect that the prefetch to start before conditional0's called + // computations. + EXPECT_LT(interval_picker.LatestPrefetchStartTime(shape, /*start_time=*/0, + /*end_time=*/11, &use), + 5); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc b/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc index 0c44ae0d766..aad943aaad7 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc @@ -15,6 +15,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/memory_space_assignment_utils.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" + namespace xla { bool MemorySpaceAssignmentUtils::IsValueAllowedInAlternateMemory( @@ -87,6 +90,17 @@ bool MemorySpaceAssignmentUtils::IsValueAllowedInAlternateMemory( return false; } } + if (auto* custom_call = + DynCast(position.instruction)) { + for (const auto& pair : custom_call->output_to_operand_aliasing()) { + if (position.index == pair.first) { + VLOG(4) << "Keeping value " << value->ToShortString() + << " in default mem because it is a custom-call output that " + "aliases an operand buffer."; + return false; + } + } + } } return true; diff --git a/tensorflow/compiler/xla/service/mlir_gpu/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/BUILD index 68bcde4f7ee..4eaed3a12e6 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/BUILD @@ -1,6 +1,14 @@ # Description: # MLIR-GPU-specific components in XLA service implementation. +load("//third_party/mlir:tblgen.bzl", "gentbl") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "filegroup") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "get_compatible_with_cloud") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow/core/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", @@ -41,9 +49,12 @@ cc_library( srcs = ["emission_context.cc"], hdrs = ["emission_context.h"], deps = [ + "//tensorflow/compiler/mlir/hlo", + "//tensorflow/compiler/mlir/hlo:lhlo", "//tensorflow/compiler/xla/service:hlo", "@com_google_absl//absl/strings", "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardOps", ], ) @@ -65,7 +76,7 @@ cc_library( ":emission_context", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service/gpu:target_constants", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "@llvm-project//llvm:Core", "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMDialect", @@ -84,7 +95,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@llvm-project//llvm:Core", "@llvm-project//mlir:GPUDialect", - "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:LLVMTransforms", @@ -106,7 +117,7 @@ cc_library( "//tensorflow/compiler/xla/service/gpu:stream_executor_util", "//tensorflow/compiler/xla/service/gpu:target_constants", "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend", - "//tensorflow/core:cuda_libdevice_path", + "//tensorflow/core/platform:cuda_libdevice_path", "//tensorflow/core:lib", "//tensorflow/stream_executor/gpu:asm_compiler", ]), @@ -156,11 +167,21 @@ cc_library( ], ) +gentbl( + name = "passes_inc_gen", + compatible_with = get_compatible_with_cloud(), + tbl_outs = [("-gen-pass-decls -name XlaMlirGpu", "passes.h.inc")], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "passes.td", + td_srcs = ["@llvm-project//mlir:PassBaseTdFiles"], +) + cc_library( name = "passes", srcs = ["passes.cc"], hdrs = ["passes.h"], deps = [ + ":passes_inc_gen", "//tensorflow/compiler/mlir/hlo:lhlo", "@com_google_absl//absl/memory", "@llvm-project//llvm:Support", @@ -170,6 +191,7 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", 
"@llvm-project//mlir:SCFTransforms", + "@llvm-project//mlir:SideEffects", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Transforms", ], @@ -182,15 +204,14 @@ cc_library( deps = [ ":passes", "//tensorflow/compiler/mlir/hlo", - "//tensorflow/compiler/mlir/hlo:hlo_dialect_force_registration", "//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo", - "//tensorflow/compiler/mlir/hlo:legalize_tanh_to_approximation", "//tensorflow/compiler/mlir/hlo:legalize_to_linalg", + "//tensorflow/compiler/mlir/hlo:legalize_trigonometric_to_approximation", "//tensorflow/compiler/mlir/hlo:lhlo", - "//tensorflow/compiler/mlir/hlo:lhlo_copy_removal", "//tensorflow/compiler/mlir/hlo:lhlo_fuse_linalg", "//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_affine", "//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_gpu", + "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -199,6 +220,7 @@ cc_library( "@llvm-project//mlir:CFGTransforms", "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:GPUToNVVMTransforms", + "@llvm-project//mlir:GPUToROCDLTransforms", "@llvm-project//mlir:GPUTransforms", "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMDialect", @@ -207,6 +229,7 @@ cc_library( "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:NVVMDialect", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ROCDLDialect", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:SCFToGPUPass", "@llvm-project//mlir:SCFTransforms", @@ -255,6 +278,25 @@ tf_cc_binary( "//tensorflow/core:lib", "@llvm-project//llvm:Support", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SideEffects", + "@llvm-project//mlir:Support", + ], +) + +tf_cc_binary( + name = "xla-mlir-gpu-opt", + srcs = ["xla_mlir_gpu_opt.cc"], + visibility = ["//tensorflow/compiler/xla/service/mlir_gpu/tests:__subpackages__"], + deps = [ + ":passes", + "//tensorflow/compiler/mlir/hlo:all_passes", + "//tensorflow/compiler/mlir/hlo:hlo_dialect_registration", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SideEffects", "@llvm-project//mlir:Support", ], ) diff --git a/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc b/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc index cb5ea946c1b..06c7ebd1099 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc @@ -16,8 +16,11 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/mlir_gpu/emission_context.h" #include "absl/strings/substitute.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" namespace xla { @@ -25,7 +28,8 @@ namespace mlir_gpu { EmissionContext::EmissionContext(std::unique_ptr module) : module_(std::move(module)), context_() { - context_.loadAllGloballyRegisteredDialects(); + context_.loadDialect(); error_handler_ = [](const ErrorMap& instructions_with_error, HloModule* module) { std::set computations_with_error; diff --git a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/BUILD index 8f56548ce77..74eef71870e 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/BUILD @@ -1,6 +1,8 @@ # Description: # MLIR-GPU-specific convolution in XLA service implementation. +load("//tensorflow:tensorflow.bzl", "filegroup") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") package( @@ -72,12 +74,14 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core/platform:test", "@llvm-project//llvm:Support", + "@llvm-project//mlir:Affine", "@llvm-project//mlir:AffineToStandardTransforms", - "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", "@llvm-project//mlir:CFGTransforms", "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMTransforms", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Transforms", ], ) diff --git a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc index f7a7decff76..c868d205310 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc @@ -20,6 +20,8 @@ limitations under the License. 
#include "llvm/Support/raw_ostream.h" #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" // from @llvm-project #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project +#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project @@ -46,7 +48,7 @@ std::string CompileHloConvAndGetMlir(absl::string_view hlo_text) { hlo_module.entry_computation()->root_instruction(); mlir::MLIRContext context; - context.loadAllGloballyRegisteredDialects(); + context.loadDialect(); mlir::OwningModuleRef mlir_module( mlir::ModuleOp::create(mlir::UnknownLoc::get(&context))); diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc index 1b2edec7d61..a664a316e13 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // from @llvm-project +#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" // from @llvm-project #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" // from @llvm-project #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h" // from @llvm-project #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" // from @llvm-project @@ -26,6 +27,7 @@ limitations under the License. #include "mlir/Dialect/GPU/Passes.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/NVVMDialect.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" // from @llvm-project #include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project #include "mlir/Dialect/SCF/Passes.h" // from @llvm-project #include "mlir/Dialect/SCF/Transforms.h" // from @llvm-project @@ -33,11 +35,12 @@ limitations under the License. #include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Transforms/BufferPlacement.h" // from @llvm-project +#include "mlir/Transforms/Bufferize.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/xla/service/mlir_gpu/passes.h" #include "tensorflow/compiler/xla/util.h" @@ -46,7 +49,7 @@ namespace mlir_gpu { Status LowerLHLOToGPU(mlir::ModuleOp module, LowerLHLOToGPUOptions options) { mlir::PassManager pm(module.getContext()); - applyPassManagerCLOptions(pm); + tensorflow::applyTensorflowAndCLOptions(pm); // We have to anticipate later unrolling in tiling to make sure that we get // the requested tiling after unrolling. Compute the new tiling here if @@ -71,7 +74,7 @@ Status LowerLHLOToGPU(mlir::ModuleOp module, LowerLHLOToGPUOptions options) { // Next, we can strip the outer fusion operation. 
pm.addPass(createFusionOpRemoverPass()); // Remove unnecessary LHLO copies. - pm.addPass(::mlir::lmhlo::createLhloCopyRemovalPass()); + pm.addPass(::mlir::createCopyRemovalPass()); // Transform LHLO operations to LinAlg. pm.addPass(::mlir::lmhlo::createLegalizeLhloToLinalgPass()); // Fuse linalg operations. @@ -120,10 +123,8 @@ Status LowerLHLOToGPU(mlir::ModuleOp module, LowerLHLOToGPUOptions options) { // Approximate of requested. if (options.use_approximations) { pm.addNestedPass<::mlir::FuncOp>( - ::mlir::mhlo::createLegalizeTanhToApproximationPass()); + ::mlir::mhlo::createLegalizeTrigonometricToApproximationPass()); } - // Move scalar operations into the launch to ensure smaller signatures. - pm.addPass(createMoveScalarComputationsIntoGpuLaunchPass()); // Take launches to launches with kernels. pm.addPass(::mlir::createGpuKernelOutliningPass()); // Make sure the kernel signature resembled the original function's @@ -179,7 +180,7 @@ class LowerToNVVMPass Status LowerKernelBodiesToNVVM(mlir::ModuleOp module) { // We cannot verify as the signature of the kernel is rewritten. ::mlir::PassManager pm(module.getContext(), /*verifyPasses=*/false); - applyPassManagerCLOptions(pm); + tensorflow::applyTensorflowAndCLOptions(pm); // Rewrite kernel functions to LLVM IR. auto& kernelPm = pm.nest<::mlir::gpu::GPUModuleOp>(); @@ -197,6 +198,85 @@ Status LowerKernelBodiesToNVVM(mlir::ModuleOp module) { return Status::OK(); } +namespace { + +/// A pass that does the final lowering to ROCDL. It collects all the patterns +/// that are currently required, currently mixing std, linalg and gpu. +class LowerToROCDLPass + : public ::mlir::PassWrapper< + LowerToROCDLPass, ::mlir::OperationPass<::mlir::gpu::GPUModuleOp>> { + void getDependentDialects(mlir::DialectRegistry& registry) const override { + registry.insert(); + } + + public: + void runOnOperation() override { + ::mlir::gpu::GPUModuleOp m = getOperation(); + + ::mlir::OwningRewritePatternList patterns; + ::mlir::populateGpuRewritePatterns(m.getContext(), patterns); + ::mlir::applyPatternsAndFoldGreedily(m, patterns); + patterns.clear(); + + ::mlir::LLVMTypeConverter converter(m.getContext()); + ::mlir::populateStdToLLVMConversionPatterns(converter, patterns); + // TODO(b/145824979) Remove linalg once sliceop is in std. + ::mlir::populateLinalgToLLVMConversionPatterns(converter, patterns, + &getContext()); + ::mlir::populateGpuToROCDLConversionPatterns(converter, patterns); + ::mlir::populateAffineToStdConversionPatterns(patterns, m.getContext()); + + ::mlir::ConversionTarget target(getContext()); + target.addIllegalDialect<::mlir::gpu::GPUDialect>(); + target + .addIllegalOp(); + target.addIllegalOp(); + target.addLegalDialect<::mlir::LLVM::LLVMDialect>(); + target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>(); + // TODO(csigg): Remove once we support replacing non-root ops. + target.addLegalOp<::mlir::gpu::GPUModuleOp, ::mlir::gpu::ModuleEndOp, + ::mlir::gpu::YieldOp>(); + if (failed(mlir::applyFullConversion(m, target, patterns))) { + signalPassFailure(); + } + } +}; + +} // namespace + +Status LowerKernelBodiesToROCDL(mlir::ModuleOp module) { + // We cannot verify as the signature of the kernel is rewritten. 
+ ::mlir::PassManager pm(module.getContext(), /*verifyPasses=*/false); + tensorflow::applyTensorflowAndCLOptions(pm); + + auto enable_if_vlog_is_on = [](mlir::Pass*, mlir::Operation*) { + return VLOG_IS_ON(1); + }; + pm.enableIRPrinting(/*shouldPrintBeforePass=*/{}, + /*shouldPrintAfterPass=*/enable_if_vlog_is_on, + /*printModuleScope=*/false, + /*printAfterOnlyOnChange=*/false, + /*out=*/llvm::dbgs()); + + // Rewrite kernel functions to LLVM IR. + auto& kernelPm = pm.nest<::mlir::gpu::GPUModuleOp>(); + kernelPm.addPass(::mlir::createLowerToCFGPass()); + kernelPm.addPass(absl::make_unique()); + + // Some basic cleanup. + kernelPm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass()); + kernelPm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass()); + // Remove all location information to prevent a debug build. + kernelPm.addPass(::mlir::createStripDebugInfoPass()); + + if (failed(pm.run(module))) { + return InternalError("Lowering to ROCDL IR failed."); + } + return Status::OK(); +} + StatusOr ExtractKernelModule(mlir::ModuleOp module) { auto kernelModule = ::mlir::ModuleOp::create(module.getLoc()); // TODO(b/137624192): This also needs to resolve naming conflicts. diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h index bd633bb06cb..290550142ec 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h @@ -36,6 +36,8 @@ Status LowerLHLOToGPU(mlir::ModuleOp module, Status LowerKernelBodiesToNVVM(mlir::ModuleOp module); +Status LowerKernelBodiesToROCDL(mlir::ModuleOp module); + StatusOr ExtractKernelModule(mlir::ModuleOp module); } // namespace mlir_gpu diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc index c7977aa776a..f00f46b83c1 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc @@ -438,7 +438,6 @@ StatusOr> TransformKernelToXlaThunk( // Finally, create the thunk and set the launch dimensions. gpu::Thunk::ThunkInfo info; - info.hlo_instruction = instr; auto thunk = absl::make_unique(info, buffers, kernel.getName().str()); @@ -563,9 +562,20 @@ StatusOr> MlirCompilerImpl::RunBackend( auto ptx, xla::gpu::nvptx::CompileToPtx(llvmModule.get(), GetGpuVersion(stream_exec), config, GetLibdeviceDir(config))); - TF_ASSIGN_OR_RETURN( - auto cubin, se::CompileGpuAsm(stream_exec->device_ordinal(), ptx.c_str(), - gpu::PtxOptsFromConfig(config))); + // Allow to fallback to the driver compilation when ptxas isn't able to + // compile. 
+ StatusOr> maybe_cubin = + se::CompileGpuAsm(stream_exec->device_ordinal(), ptx.c_str(), + gpu::PtxOptsFromConfig(config)); + std::vector cubin; + if (maybe_cubin.ok()) { + cubin = std::move(maybe_cubin).ValueOrDie(); + } else if (maybe_cubin.status().code() == + tensorflow::error::Code::UNIMPLEMENTED) { + xla::gpu::WarnIfBadDriverJITVersion(); + } else { + return maybe_cubin.status(); + } auto thunk_schedule = absl::make_unique( std::make_unique(std::move(thunk_sequence)), @@ -580,7 +590,7 @@ StatusOr> MlirCompilerImpl::RunBackend( return {absl::make_unique( ptx, cubin, GetGpuVersion(stream_exec), std::move(thunk_schedule), emission_context.releaseHloModule(), std::move(buffer_assignment), - nullptr, nullptr)}; + nullptr, nullptr, std::vector())}; } StatusOr>> MlirCompilerImpl::Compile( diff --git a/tensorflow/compiler/xla/service/mlir_gpu/passes.cc b/tensorflow/compiler/xla/service/mlir_gpu/passes.cc index 887f14e90d9..84751bc0507 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/passes.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/passes.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Transforms/LoopUtils.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" @@ -32,8 +33,10 @@ namespace xla { namespace mlir_gpu { namespace { -struct FusionOpRemoverPass - : public mlir::PassWrapper { +#define GEN_PASS_CLASSES +#include "tensorflow/compiler/xla/service/mlir_gpu/passes.h.inc" + +struct FusionOpRemoverPass : FusionOpRemoverPassBase { void runOnFunction() override { getFunction().walk([&](mlir::lmhlo::FusionOp op) { mlir::OpBuilder builder(op); @@ -52,8 +55,22 @@ struct FusionOpRemoverPass } }; -struct StoreForwardingPass - : mlir::PassWrapper { +template +bool HasEffectsOnValue(mlir::Value value, mlir::Operation* op) { + auto mem_effects_interface = + mlir::dyn_cast_or_null(op); + if (!mem_effects_interface) { + return false; + } + llvm::SmallVector effects; + mem_effects_interface.getEffects(effects); + return llvm::any_of(effects, + [op](const mlir::MemoryEffects::EffectInstance& effect) { + return mlir::isa(effect.getEffect()); + }); +} + +struct StoreForwardingPass : StoreForwardingPassBase { mlir::StoreOp findStore(mlir::Operation* op, std::function matches) { // Search from op upwards in the current block. @@ -86,10 +103,9 @@ struct StoreForwardingPass while (auto subviewOp = mlir::dyn_cast_or_null(defOp)) { defOp = subviewOp.source().getDefiningOp(); } - if (auto allocOp = mlir::dyn_cast_or_null(defOp)) { - return allocOp.getOperation(); - } - return nullptr; + return HasEffectsOnValue(memref, defOp) + ? defOp + : nullptr; } // Retrieves AllocOp from the cache or actually looks for it. 
@@ -100,7 +116,7 @@ struct StoreForwardingPass if (allocOpIt != memrefToAllocOp->end()) { return allocOpIt->second; } - auto allocOp = SearchAllocOp(memref); + mlir::Operation* allocOp = SearchAllocOp(memref); memrefToAllocOp->insert({memref, allocOp}); return allocOp; } @@ -132,7 +148,7 @@ struct StoreForwardingPass }; struct DeadTempBufferRemovalPass - : mlir::PassWrapper { + : DeadTempBufferRemovalPassBase { bool operationConsideredDead(mlir::Operation* op) { for (auto result : op->getResults()) { if (!llvm::all_of(result.getUsers(), [&](mlir::Operation* op) { @@ -168,13 +184,18 @@ struct DeadTempBufferRemovalPass void runOnFunction() override { llvm::SmallVector dead_ops; - getFunction().walk([&](mlir::AllocOp allocOp) { - if (!operationConsideredDead(allocOp)) { + getFunction().walk([&](mlir::Operation* op) { + if (op->getNumResults() != 1 || + !HasEffectsOnValue(op->getResult(0), + op)) { + return; + } + if (!operationConsideredDead(op)) { return; } // TODO(herhut): There should be a generic helper for this. - recursiveErase(allocOp, &dead_ops); + recursiveErase(op, &dead_ops); }); for (auto op : dead_ops) { op->erase(); @@ -182,66 +203,8 @@ struct DeadTempBufferRemovalPass } }; -struct MoveScalarComputationsIntoGpuLaunchPass - : mlir::PassWrapper { - static bool isInliningBeneficiary(mlir::Operation* op) { - return llvm::isa(op); - } - - static bool extractBeneficiaryOps( - mlir::Operation* op, llvm::SmallVectorImpl* ops, - llvm::SetVector args) { - if (!isInliningBeneficiary(op)) { - return false; - } - - ops->push_back(op); - for (auto operand : op->getOperands()) { - // It is an existing arg, keep going. - if (args.count(operand)) { - continue; - } - mlir::Operation* definingOp = operand.getDefiningOp(); - if (!definingOp || !extractBeneficiaryOps(definingOp, ops, args)) { - return false; - } - } - return true; - } - - static void inlineOperationsIntoLaunch(mlir::gpu::LaunchOp launch) { - llvm::SetVector used_above; - mlir::getUsedValuesDefinedAbove(launch.body(), used_above); - mlir::BlockAndValueMapping inlined_map; - for (mlir::Value v : used_above) { - llvm::SmallVector ops_to_move; - mlir::Operation* definingOp = v.getDefiningOp(); - if (definingOp && - extractBeneficiaryOps(definingOp, &ops_to_move, used_above)) { - mlir::OpBuilder b(launch.body()); - for (mlir::Operation* op : llvm::reverse(ops_to_move)) { - auto result = b.clone(*op, inlined_map); - for (auto pair : llvm::zip(op->getResults(), result->getResults())) { - mlir::replaceAllUsesInRegionWith(std::get<0>(pair), - std::get<1>(pair), launch.body()); - } - inlined_map.map(op->getResults(), result->getResults()); - } - } - } - } - - void runOnFunction() override { - mlir::FuncOp fun = getFunction(); - fun.walk( - [](mlir::gpu::LaunchOp launch) { inlineOperationsIntoLaunch(launch); }); - } -}; - struct RewriteKernelSignaturePass - : mlir::PassWrapper { + : RewriteKernelSignaturePassBase { void runOnFunction() override { mlir::FuncOp func = getFunction(); mlir::ModuleOp module = func.getParentOfType(); @@ -349,15 +312,14 @@ struct RewriteKernelSignaturePass } }; -struct MapParallelLoopsPass - : public mlir::PassWrapper { +struct MapParallelLoopsPass : MapParallelLoopsPassBase { void runOnFunction() override { mlir::greedilyMapParallelSCFToGPU(getFunction().getBody()); } }; struct FuseInnerParallelLoopsPass - : public mlir::PassWrapper { + : FuseInnerParallelLoopsPassBase { void runOnFunction() override { getFunction().walk([](mlir::scf::ParallelOp op) { mlir::scf::naivelyFuseParallelOps(op.region()); @@ -366,12 
+328,10 @@ struct FuseInnerParallelLoopsPass }; struct ParallelLoopCollapsingToFirstDimPass - : public mlir::PassWrapper> { - void runOnOperation() override { - mlir::Operation* module = getOperation(); - - module->walk([&](mlir::scf::ParallelOp op) { + : ParallelLoopCollapsingToFirstDimPassBase< + ParallelLoopCollapsingToFirstDimPass> { + void runOnFunction() override { + getFunction().walk([&](mlir::scf::ParallelOp op) { unsigned num_loops = op.getNumLoops(); std::vector combinedLoops; combinedLoops.reserve(num_loops); @@ -397,11 +357,6 @@ std::unique_ptr createDeadTempBufferRemovalPass() { return absl::make_unique(); } -std::unique_ptr -createMoveScalarComputationsIntoGpuLaunchPass() { - return absl::make_unique(); -} - std::unique_ptr createRewriteKernelSignaturePass() { return absl::make_unique(); } @@ -414,7 +369,7 @@ std::unique_ptr createMapParallelLoopsPass() { return absl::make_unique(); } -std::unique_ptr> +std::unique_ptr createParallelLoopCollapsingToFirstDimPass() { return absl::make_unique(); } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/passes.h b/tensorflow/compiler/xla/service/mlir_gpu/passes.h index e3840628a2e..832321387c6 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/passes.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/passes.h @@ -37,10 +37,6 @@ std::unique_ptr createStoreForwardingPass(); /// that loads and stores are side-effect free (in bounds, no aliasing, etc.). std::unique_ptr createDeadTempBufferRemovalPass(); -/// Moves scalar computations to the GPULaunchOp body. -std::unique_ptr -createMoveScalarComputationsIntoGpuLaunchPass(); - /// Sorts the operands to the kernel for a deterministic order. First operands /// that are defined by function arguments, followed by operands that are /// returned from the function. This only works for simple functions without @@ -57,9 +53,12 @@ std::unique_ptr createFuseInnerParallelLoopsPass(); std::unique_ptr createMapParallelLoopsPass(); /// Collapses all loop dimension into the first one. -std::unique_ptr> +std::unique_ptr createParallelLoopCollapsingToFirstDimPass(); +#define GEN_PASS_REGISTRATION +#include "tensorflow/compiler/xla/service/mlir_gpu/passes.h.inc" + } // namespace mlir_gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/mlir_gpu/passes.td b/tensorflow/compiler/xla/service/mlir_gpu/passes.td new file mode 100644 index 00000000000..55fe15ad6ff --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/passes.td @@ -0,0 +1,91 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_PASSES_TD_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_PASSES_TD_
+
+include "mlir/Pass/PassBase.td"
+
+def FusionOpRemoverPass : FunctionPass<"mlir-gpu-fusion-op-remover"> {
+  let summary = "Removes lhlo fusion ops by inlining their regions.";
+  let constructor = "createFusionOpRemoverPass()";
+  let description = [{
+    Replaces a FusionOp by the operations contained in its region.
+  }];
+}
+
+def StoreForwardingPass : FunctionPass<"mlir-gpu-store-forwarding"> {
+  let summary = "Limited pass to forward stores to loads.";
+  let constructor = "createStoreForwardingPass()";
+  let description = [{
+    Replaces a load that immediately follows a store to the same address with
+    the stored value.
+  }];
+}
+
+def DeadTempBufferRemovalPass
+    : FunctionPass<"mlir-gpu-dead-temp-buffer-removal"> {
+  let summary = "Removal of dead temp buffers.";
+  let constructor = "createDeadTempBufferRemovalPass()";
+  let description = [{
+    Removes temporary buffers that are only written to but never read from or
+    that are read but the read value is not used. Needs an analysis that proves
+    that loads and stores are side-effect free (in bounds, no aliasing, etc.).
+  }];
+}
+
+def RewriteKernelSignaturePass
+    : FunctionPass<"mlir-gpu-rewrite-signatures"> {
+  let summary = "Rewrite kernel signatures to be deterministic.";
+  let constructor = "createRewriteKernelSignaturePass()";
+  let description = [{
+    Sorts the operands to the kernel for a deterministic order. First operands
+    that are defined by function arguments, followed by operands that are
+    returned from the function. This only works for simple functions without
+    control flow and can be used in cases where the kernel is extracted and
+    used independently of the host-side code.
+  }];
+}
+
+def MapParallelLoopsPass
+    : FunctionPass<"mlir-gpu-map-parallel-loops"> {
+  let summary = "Greedily maps loops to GPU hardware dimensions.";
+  let constructor = "createMapParallelLoopsPass()";
+  let description = [{
+    Greedily maps loops to GPU hardware dimensions.
+  }];
+}
+
+def FuseInnerParallelLoopsPass
+    : FunctionPass<"mlir-gpu-fuse-inner-parallel-loops"> {
+  let summary = "Fuses parallel loops nested within another parallel loop.";
+  let constructor = "createFuseInnerParallelLoopsPass()";
+  let description = [{
+    Directs parallel loop fusion to the inner loops. This cannot be done with
+    a passmanager alone ATM, as nested pass managers require operations to
+    be closed from above.
+  }];
+}
+
+def ParallelLoopCollapsingToFirstDimPass
+    : FunctionPass<"mlir-gpu-collapse-parallel-loops"> {
+  let summary = "Collapses n-dimensional loops into one-dimensional ones.";
+  let constructor = "createParallelLoopCollapsingToFirstDimPass()";
+  let description = [{
+    Collapses all loop dimensions of a parallel loop into the first one.
+ }]; +} + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_PASSES_TD_ diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD index 850d5f5a0cf..9bd5e3350fa 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow:tensorflow.bzl", "filegroup") load( "//tensorflow/core/platform:build_config_root.bzl", "tf_cuda_tests_tags", @@ -18,13 +19,16 @@ package_group( ) glob_lit_tests( - data = [":test_utilities"], + data = [ + ":test_utilities", + "@llvm-project//mlir:run_lit.sh", + ], default_tags = tf_cuda_tests_tags() + [ "no_pip", "config-cuda-only", "no_rocm", ], - driver = "@llvm-project//mlir:run_lit.sh", + driver = "//tensorflow/compiler/mlir:run_lit.sh", exclude = [ # TODO(b/137624192): Reenable once we can fuse reductions. "fused_reduce.hlo", diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/BUILD new file mode 100644 index 00000000000..b1b7de5c4e6 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/BUILD @@ -0,0 +1,24 @@ +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") + +package( + licenses = ["notice"], # Apache 2.0 +) + +glob_lit_tests( + data = [ + ":test_utilities", + "@llvm-project//mlir:run_lit.sh", + ], + driver = "//tensorflow/compiler/mlir:run_lit.sh", + test_file_exts = ["mlir"], +) + +# Bundle together all of the test utilities that are used by tests. +filegroup( + name = "test_utilities", + testonly = True, + data = [ + "//tensorflow/compiler/xla/service/mlir_gpu:xla-mlir-gpu-opt", + "@llvm-project//llvm:FileCheck", + ], +) diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/dead_temp_buffer_removal.mlir b/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/dead_temp_buffer_removal.mlir new file mode 100644 index 00000000000..58132f4ea45 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/dead_temp_buffer_removal.mlir @@ -0,0 +1,72 @@ +// RUN: xla-mlir-gpu-opt --mlir-gpu-dead-temp-buffer-removal %s | FileCheck %s + +// CHECK-LABEL: @dead +func @dead() { + // CHECK-NOT: alloc + %0 = alloc() : memref<42xi32> + %c0 = constant 0 : i32 + %c12 = constant 12 : index + // CHECK-NOT: store + store %c0, %0[%c12] : memref<42xi32> + return +} + +// CHECK-LABEL: @dead_alloca +func @dead_alloca() { + // CHECK-NOT: alloca + %0 = alloc() : memref<42xi32> + %c0 = constant 0 : i32 + %c12 = constant 12 : index + // CHECK-NOT: store + store %c0, %0[%c12] : memref<42xi32> + return +} + +// CHECK-LABEL: @dead_load +func @dead_load() { + // CHECK-NOT: alloc + %0 = alloc() : memref<42xi32> + %c0 = constant 0 : i32 + %c12 = constant 12 : index + store %c0, %0[%c12] : memref<42xi32> + %1 = load %0[%c12] : memref<42xi32> + return +} + +// CHECK-LABEL: @used_load +func @used_load() -> i32 { + // CHECK: alloc + %0 = alloc() : memref<42xi32> + %c0 = constant 0 : i32 + %c12 = constant 12 : index + store %c0, %0[%c12] : memref<42xi32> + %1 = load %0[%c12] : memref<42xi32> + return %1 : i32 +} + +// CHECK-LABEL: @dead_subview +func @dead_subview() { + // CHECK-NOT: alloc + %0 = alloc() : memref<42xi32> + %c0 = constant 0 : i32 + %c1 = constant 1 : index + %c4 = constant 4 : index + %c12 = constant 12 : index + store %c0, %0[%c12] : memref<42xi32> + %1 = subview %0[%c12][%c4][%c1] : memref<42xi32> to memref (d0 * s1 + s0)>> + return +} + +// CHECK-LABEL: @used_subview +func 
@used_subview() -> i32 { + // CHECK: alloc + %0 = alloc() : memref<42xi32> + %c0 = constant 0 : i32 + %c1 = constant 1 : index + %c4 = constant 4 : index + %c12 = constant 12 : index + store %c0, %0[%c12] : memref<42xi32> + %1 = subview %0[%c12][%c4][%c1] : memref<42xi32> to memref (d0 * s1 + s0)>> + %2 = load %1[%c1] : memref (d0 * s1 + s0)>> + return %2 : i32 +} diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/fusion_op_remover.mlir b/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/fusion_op_remover.mlir new file mode 100644 index 00000000000..69ebbbd5a72 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/fusion_op_remover.mlir @@ -0,0 +1,20 @@ +// RUN: xla-mlir-gpu-opt --mlir-gpu-fusion-op-remover %s | FileCheck %s + +// CHECK-LABEL: func @fusion_memref +func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, + %input3: memref<10xf32>, %out: memref<10xf32>) -> () { + // CHECK-NOT: lmhlo.fusion + "lmhlo.fusion"() ( { + %0 = tensor_load %input1 : memref<10xf32> + %1 = tensor_load %input2 : memref<10xf32> + %2 = "mhlo.add"(%0, %1) {name = "add"} + : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> + %3 = tensor_load %input3 : memref<10xf32> + %4 = "mhlo.multiply"(%2, %3) {name = "multiply"} + : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> + tensor_store %4, %out : memref<10xf32> + // CHECK-NOT: lmhlo.terminator + "lmhlo.terminator"() : () -> () + } ) : () -> () + return +} diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/rewrite_kernel_signatures.mlir b/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/rewrite_kernel_signatures.mlir new file mode 100644 index 00000000000..cff1989f05b --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/rewrite_kernel_signatures.mlir @@ -0,0 +1,138 @@ +// RUN: xla-mlir-gpu-opt --mlir-gpu-rewrite-signatures %s --split-input-file --verify-diagnostics | FileCheck %s + +module attributes {gpu.container_module} { + +// CHECK-LABEL: @kernel_module +gpu.module @kernel_module { + // CHECK-LABEL: gpu.func @kernel + // CHECK-SAME: %{{.*}}: memref<32xf32>, %{{.*}}: memref<16xf32>, + // CHECK-SAME: %{{.*}}: memref<8xf32> + gpu.func @kernel(%arg0: memref<8xf32>, %arg1: memref<16xf32>, + %arg2: memref<32xf32>) kernel { + gpu.return + } +} + + // CHECK-LABEL: @caller +func @caller(%arg0: memref<32xf32>, %arg1: memref<16xf32>) -> memref<8xf32> { + %cst = constant 8 : index + %res = alloc() : memref<8xf32> + + // CHECK: gpu.launch_func + // CHECK-SAME: index, memref<32xf32>, memref<16xf32>, memref<8xf32>) + "gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %res, %arg1, %arg0) + { kernel = @kernel_module::@kernel } + : (index, index, index, index, index, index, + memref<8xf32>, memref<16xf32>, memref<32xf32>) -> () + + return %res : memref<8xf32> +} + +} + +// ----- + +module attributes {gpu.container_module} { + +gpu.module @kernel_module { + // expected-error @+1 {{number of kernel arguments does not match numberof arguments and results of surrounding function}} + gpu.func @kernel(%arg0: memref<16xf32>, %arg1: memref<32xf32>) kernel { + gpu.return + } +} + +func @caller(%arg0: memref<32xf32>, %arg1: memref<16xf32>) -> memref<8xf32> { + %cst = constant 8 : index + %res = alloc() : memref<8xf32> + + "gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %arg1, %arg0) + { kernel = @kernel_module::@kernel } + : (index, index, index, index, index, index, + memref<16xf32>, memref<32xf32>) -> () + + return %res : memref<8xf32> +} + +} + +// ----- + +module attributes 
{gpu.container_module} { + +gpu.module @kernel_module { + // expected-error @+1 {{result 0 of containing function is not an argument to the kernel}} + gpu.func @kernel(%arg0: memref<16xf32>, %arg1: memref<32xf32>, + %arg2: memref<8xf32>) kernel { + gpu.return + } +} + +func @caller(%arg0: memref<32xf32>, %arg1: memref<16xf32>) -> memref<8xf32> { + %cst = constant 8 : index + %res = alloc() : memref<8xf32> + %fake = alloc() : memref<8xf32> + + "gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %arg1, %arg0, %fake) + { kernel = @kernel_module::@kernel } + : (index, index, index, index, index, index, + memref<16xf32>, memref<32xf32>, memref<8xf32>) -> () + + return %res : memref<8xf32> +} + +} + +// ----- + +module attributes {gpu.container_module} { + +gpu.module @kernel_module { + // expected-error @+1 {{argument 1 to containing function is not an argument to the kernel}} + gpu.func @kernel(%arg0: memref<16xf32>, %arg1: memref<32xf32>, + %arg2: memref<8xf32>) kernel { + gpu.return + } +} + +func @caller(%arg0: memref<32xf32>, %arg1: memref<16xf32>) -> memref<8xf32> { + %cst = constant 8 : index + %res = alloc() : memref<8xf32> + %fake = alloc() : memref<16xf32> + + "gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %fake, %arg0, %res) + { kernel = @kernel_module::@kernel } + : (index, index, index, index, index, index, + memref<16xf32>, memref<32xf32>, memref<8xf32>) -> () + + return %res : memref<8xf32> +} + +} + +// ----- + +module attributes {gpu.container_module} { + +gpu.module @kernel_module { + gpu.func @kernel(%arg0: memref<8xf32>, %arg1: memref<16xf32>, + %arg2: memref<32xf32>) kernel { + gpu.return + } +} + +// expected-error @+1 {{surrounding function has more than one block}} +func @caller(%arg0: memref<32xf32>, %arg1: memref<16xf32>) -> memref<8xf32> { + %cst = constant 8 : index + %res = alloc() : memref<8xf32> + br ^bb1 + + ^bb1: + "gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %res, %arg1, %arg0) + { kernel = @kernel_module::@kernel } + : (index, index, index, index, index, index, + memref<8xf32>, memref<16xf32>, memref<32xf32>) -> () + + return %res : memref<8xf32> +} + +} diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/store_forwarding_pass.mlir b/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/store_forwarding_pass.mlir new file mode 100644 index 00000000000..8b993bb56a5 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/store_forwarding_pass.mlir @@ -0,0 +1,72 @@ +// RUN: xla-mlir-gpu-opt --mlir-gpu-store-forwarding %s | FileCheck %s + +// CHECK-LABEL: @forward +func @forward() -> f32 { + %0 = alloc() : memref<1024xf32> + %c42 = constant 24 : index + // CHECK: %[[CST:.*]] = constant 1.0 + %c1 = constant 1.0 : f32 + store %c1, %0[%c42] : memref<1024xf32> + // CHECK-NOT: load + %1 = load %0[%c42] : memref<1024xf32> + // CHECK: return %[[CST]] + return %1 : f32 +} + +// CHECK-LABEL: @forward_alloca +func @forward_alloca() -> f32 { + %0 = alloca() : memref<1024xf32> + %c42 = constant 24 : index + // CHECK: %[[CST:.*]] = constant 1.0 + %c1 = constant 1.0 : f32 + store %c1, %0[%c42] : memref<1024xf32> + // CHECK-NOT: load + %1 = load %0[%c42] : memref<1024xf32> + // CHECK: return %[[CST]] + return %1 : f32 +} + +// CHECK-LABEL: @wrong_index +func @wrong_index() -> f32 { + %0 = alloc() : memref<1024xf32> + %c42 = constant 24 : index + %c12 = constant 12 : index + %c1 = constant 1.0 : f32 + store %c1, %0[%c42] : memref<1024xf32> + // CHECK: %[[RES:.*]] = load + %1 = load %0[%c12] : memref<1024xf32> + // CHECK: return 
%[[RES]] + return %1 : f32 +} + +// CHECK-LABEL: @wrong_memref +func @wrong_memref() -> f32 { + %0 = alloc() : memref<1024xf32> + %1 = alloc() : memref<1024xf32> + %c42 = constant 24 : index + %c1 = constant 1.0 : f32 + store %c1, %0[%c42] : memref<1024xf32> + // CHECK: %[[RES:.*]] = load + %2 = load %1[%c42] : memref<1024xf32> + // CHECK: return %[[RES]] + return %2 : f32 +} + +// CHECK-LABEL: @with_parallel_loop +func @with_parallel_loop() { + %0 = alloc() : memref<1024xf32> + %c0 = constant 0 : index + %c42 = constant 24 : index + %c1 = constant 1 : index + // CHECK: %[[CST:.*]] = constant 1.100000e+01 : f32 + %c11 = constant 1.100000e+01 : f32 + store %c11, %0[%c42] : memref<1024xf32> + // CHECK: scf.parallel + scf.parallel (%i0) = (%c0) to (%c42) step (%c1) { + // CHECK-NOT: load + %1 = load %0[%c42] : memref<1024xf32> + // CHECK-NEXT: store %[[CST]] + store %1, %0[%c0] : memref<1024xf32> + } + return +} diff --git a/tensorflow/compiler/xla/service/mlir_gpu/xla_mlir_gpu_opt.cc b/tensorflow/compiler/xla/service/mlir_gpu/xla_mlir_gpu_opt.cc new file mode 100644 index 00000000000..cbda9a30a07 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/xla_mlir_gpu_opt.cc @@ -0,0 +1,35 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/InitAllDialects.h" // from @llvm-project +#include "mlir/InitAllPasses.h" // from @llvm-project +#include "mlir/Support/MlirOptMain.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h" +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/register_passes.h" +#include "tensorflow/compiler/xla/service/mlir_gpu/passes.h" + +int main(int argc, char **argv) { + mlir::registerAllPasses(); + mlir::mhlo::registerAllMhloPasses(); + mlir::lmhlo::registerAllLmhloPasses(); + xla::mlir_gpu::registerXlaMlirGpuPasses(); + + mlir::DialectRegistry registry; + mlir::registerAllDialects(registry); + mlir::mhlo::registerAllMhloDialects(registry); + + return failed(mlir::MlirOptMain( + argc, argv, "XLA mlir gpu backend pass driver\n", registry)); +} diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc index a21cec538d1..c5c2d081686 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -338,6 +339,21 @@ bool MultiOutputFusion::LegalToFuseMainConstraints(HloInstruction* instr1, if (!ShapesCompatibleForFusion(instr1, instr2)) { return false; } + + // If both nodes are in-place operations and they use a common in-place + // operand, we can't fuse these two. + for (const auto& operand_and_output_index1 : + HloDataflowAnalysis::GetInPlaceInputOutputPairs(instr1)) { + const HloInstruction* operand = + instr1->operand(operand_and_output_index1.first.operand_number); + for (const auto& operand_and_output_index2 : + HloDataflowAnalysis::GetInPlaceInputOutputPairs(instr2)) { + if (operand == + instr2->operand(operand_and_output_index2.first.operand_number)) { + return false; + } + } + } return true; } diff --git a/tensorflow/compiler/xla/service/qr_expander.cc b/tensorflow/compiler/xla/service/qr_expander.cc new file mode 100644 index 00000000000..d1b1526ed30 --- /dev/null +++ b/tensorflow/compiler/xla/service/qr_expander.cc @@ -0,0 +1,466 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/qr_expander.h" + +#include +#include + +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { + +namespace { + +std::vector ConcatVectors(absl::Span xs, + absl::Span ys) { + std::vector output; + output.reserve(xs.size() + ys.size()); + std::copy(xs.begin(), xs.end(), std::back_inserter(output)); + std::copy(ys.begin(), ys.end(), std::back_inserter(output)); + return output; +} + +// Computes a Householder reflection of the form: +// H = I - tau v v.T. +// such that +// H . ( x1 ) = ( x1 ) +// ( x2 ) = ( x2 ) +// ( ... ) = ( ... ) +// ( xk ) = ( beta ) +// ( ... ) ( 0 ) +// ( ... ) ( 0 ) +// Unlike the usual formulation, we allow the caller to supply 'k' rather than +// only providing the relevant part of 'x' to maintain XLA's static shape +// invariant. In addition, the implementation supports batching. 
+// Pseudo-code, without batching: +// alpha = x[k] +// x_copy = np.copy(x) +// x_copy[:k+1] = 0 +// xnorm = norm2(x_copy) +// if xnorm == 0 and np.imag(alpha) == 0: +// beta = alpha +// tau = 0 +// v = np.zeros_like(x) +// else: +// beta = -np.sign(np.real(alpha)) * np.sqrt(alpha * np.conj(alpha) + xnorm) +// if np.issubdtype(x.dtype, np.complexfloating): +// tau = (beta - alpha) / beta +// else: +// tau = (beta - np.real(alpha) / beta) + (-np.imag(alpha) / beta) * 1j +// v = x / (alpha - beta) +// v[k] = 1 +// return (v, tau, beta) +// TODO(phawkins): LAPACK's xLARFG implementation has code for handling +// overflows in the norm/beta calculations. Perhaps do the same here. +Status House(XlaOp x, XlaOp k, absl::Span batch_dims, + const int64 m, XlaOp* v, XlaOp* tau, XlaOp* beta) { + XlaBuilder* const builder = x.builder(); + TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x)); + const PrimitiveType type = x_shape.element_type(); + + std::vector batch_dim_ids(batch_dims.size()); + std::iota(batch_dim_ids.begin(), batch_dim_ids.end(), 0); + const int64 minor_dim = batch_dims.size(); + + XlaOp zero = ScalarLike(x, 0.0); + + // alpha = x[k] + XlaOp alpha = Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims); + + // Compute x[k+1:] (padded with zeros in elements 0..k) + XlaOp iota = Iota(builder, S32, m); + XlaOp x_after_k = Mul(x, ConvertElementType(Gt(iota, k), type), + /*broadcast_dimensions=*/{minor_dim}); + + XlaOp sigma_is_zero; + if (primitive_util::IsComplexType(type)) { + // sigma = np.dot(x[k+1:], np.conj(x[k+1:])) + // TODO(phawkins): this calculation may be numerically unstable. + auto x_squared = Real(x_after_k * Conj(x_after_k)); + auto sigma = + Reduce(x_squared, ScalarLike(x_squared, 0.0), + CreateScalarAddComputation( + primitive_util::ComplexComponentType(type), builder), + {minor_dim}); + // mu = np.sqrt(x[k]*np.con(x[k]) + sigma) + auto mu = Sqrt(Real(alpha * Conj(alpha)) + sigma); + + sigma_is_zero = Eq(sigma, ScalarLike(sigma, 0)); + sigma_is_zero = And(sigma_is_zero, Eq(Imag(alpha), ScalarLike(sigma, 0))); + + *beta = Select(Lt(Real(alpha), ScalarLike(sigma, 0)), ScalarLike(mu, 1), + ScalarLike(mu, -1)) * + mu; + *beta = Select(sigma_is_zero, Real(alpha), *beta); + *tau = Complex((*beta - Real(alpha)) / *beta, -Imag(alpha) / *beta); + } else { + // sigma = np.dot(x[k+1:], x[k+1:]) + // TODO(phawkins): this calculation may be numerically unstable. + auto sigma = Reduce(x_after_k * x_after_k, zero, + CreateScalarAddComputation(type, builder), {minor_dim}); + // mu = np.sqrt(x[k]*x[k] + sigma) + auto mu = Sqrt(Square(alpha) + sigma); + sigma_is_zero = Eq(sigma, zero); + + XlaOp one = ScalarLike(x, 1.0); + *beta = Select(Lt(alpha, zero), one, -one) * mu; + *beta = Select(sigma_is_zero, alpha, *beta); + *tau = (*beta - alpha) / *beta; + } + *tau = Select(sigma_is_zero, ZerosLike(*tau), *tau); + + auto divisor = + Select(sigma_is_zero, Broadcast(ScalarLike(alpha, 1), batch_dims), + alpha - ConvertElementType(*beta, type)); + + auto e_k = Broadcast(ConvertElementType(Eq(iota, k), type), + std::vector(batch_dims.size(), 1)); + + // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor + // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor. + *v = e_k + Div(x_after_k, divisor, /*broadcast_dimensions=*/batch_dim_ids); + return Status::OK(); +} + +} // namespace + +// Householder QR decomposition. Algorithm 5.2.1 from Golub and Van +// Loan "Matrix Computations", 4th Edition. 
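For reference, before the unblocked QR routine that this comment introduces, here is a small runnable NumPy sketch of the reflector computed by House() above, restricted to the real-valued, unbatched case of the pseudo-code. The function name `house` and the explicit-matrix check are illustrative additions, not part of the XLA sources; the sign conventions (beta = -sign(alpha) * mu, tau = (beta - alpha) / beta, v = x / (alpha - beta), v[k] = 1) follow the implementation above.

import numpy as np

def house(x, k):
    # Householder reflector H = I - tau * outer(v, v) with v[:k] == 0 and
    # v[k] == 1, chosen so that (H @ x)[k] == beta and (H @ x)[k+1:] == 0.
    alpha = x[k]
    sigma = np.dot(x[k + 1:], x[k + 1:])  # squared norm of the tail
    if sigma == 0.0:
        # Nothing to annihilate: tau == 0 makes H the identity.
        return np.zeros_like(x), 0.0, alpha
    mu = np.sqrt(alpha * alpha + sigma)
    beta = mu if alpha < 0.0 else -mu
    tau = (beta - alpha) / beta
    v = np.zeros_like(x)
    v[k] = 1.0
    v[k + 1:] = x[k + 1:] / (alpha - beta)
    return v, tau, beta

# Quick check: the explicit reflector zeroes everything below position k and
# leaves the entries above it untouched.
x = np.array([3.0, 1.0, 4.0, 1.0, 5.0])
v, tau, beta = house(x, 1)
H = np.eye(5) - tau * np.outer(v, v)
print(np.round(H @ x, 6))  # [3., beta, 0., 0., 0.] with beta == -sqrt(43)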
This is an unblocked implementation +// used as an inner routine of the blocked implementation. +// Algorithm is adapted slightly so the shapes inside the loop are static, at +// the cost of some redundant computation. Since this is used as an inner block +// kernel, accumulates the Householder transformations (vs, taus) rather than +// the matrix q. +// Equivalent Python code, without batching: +// def qr(a): +// m = a.shape[0] +// n = a.shape[1] +// taus = np.zeros([n]) +// for j in xrange(min(m, n)): +// v, tau, beta = house(a[:, j], j) +// a[:, j+1:] -= np.conj(tau) * np.dot(v[:, np.newaxis], +// np.dot(np.conj(v[np.newaxis, :]), a[:, j+1:])) +// # Form column j explicitly rather than relying on the precision of the +// # Householder update. +// a[j, j] = beta +// a[j+1:, j] = v[j+1:] +// taus[j] = tau +// return (a, taus) +StatusOr QrExpander::QrBlock( + XlaOp a, PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + const int num_dims = a_shape.rank(); + if (num_dims < 2) { + return InvalidArgument("Argument to QR must have rank >= 2; got shape %s", + a_shape.ToString()); + } + PrimitiveType type = a_shape.element_type(); + + const int64 m = ShapeUtil::GetDimension(a_shape, -2); + const int64 n = ShapeUtil::GetDimension(a_shape, -1); + + const int64 num_batch_dims = num_dims - 2; + std::vector batch_dims(num_batch_dims); + for (int i = 0; i < num_batch_dims; ++i) { + batch_dims[i] = ShapeUtil::GetDimension(a_shape, i); + } + + std::vector batch_dim_indices(num_batch_dims); + std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); + + auto qr_body_fn = [&](XlaOp j, absl::Span values, + XlaBuilder* builder) -> StatusOr> { + auto a = values[0]; + auto taus = values[1]; + + // v, tau, beta = house(a[:, j], j) + auto x = DynamicSliceInMinorDims(a, {j}, {1}); + XlaOp v, tau, beta; + TF_RETURN_IF_ERROR(House(Collapse(x, {num_dims - 2, num_dims - 1}), j, + batch_dims, m, &v, &tau, &beta)); + + const int64 minor_dim = batch_dims.size(); + auto iota_mn = Iota( + builder, ShapeUtil::MakeShape(S32, ConcatVectors(batch_dims, {m, n})), + minor_dim + 1); + + std::vector shape = batch_dims; + shape.push_back(1); + shape.push_back(m); + auto v_broadcast = Reshape(v, shape); + // a[:, j+1:] -= np.conj(tau) * (v[:, np.newaxis] @ + // (np.conj(v[np.newaxis, :]) @ a[:, j+1:])) + // We use masking rather than a loop-variant shape to handle the j+1: + // indexing. 
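The loop body of QrBlock, continued below, mirrors the Python pseudo-code just shown. As a runnable, real-valued sketch that builds on the `house` helper from the previous snippet (`qr_unblocked` is an illustrative name, not an XLA symbol):

def qr_unblocked(a):
    # Householder QR of an m x n matrix, following the pseudo-code above.
    # Returns `a` overwritten with R on and above the diagonal and the
    # Householder vectors v[j+1:] below it, plus the scalar factors taus.
    a = np.array(a, dtype=float)
    m, n = a.shape
    p = min(m, n)
    taus = np.zeros(p)
    for j in range(p):
        v, tau, beta = house(a[:, j].copy(), j)
        # Rank-1 update of the trailing columns: a[:, j+1:] -= tau * v (v^T a).
        a[:, j + 1:] -= tau * np.outer(v, v @ a[:, j + 1:])
        # Form column j explicitly instead of relying on the update.
        a[j, j] = beta
        a[j + 1:, j] = v[j + 1:]
        taus[j] = tau
    return a, taus

# Check: rebuild Q from the stored reflectors and compare Q @ R with A.
A = np.random.default_rng(0).normal(size=(5, 3))
packed, taus = qr_unblocked(A)
m, n = A.shape
Q = np.eye(m)
for j in range(n):
    v = np.zeros(m)
    v[j] = 1.0
    v[j + 1:] = packed[j + 1:, j]
    Q = Q @ (np.eye(m) - taus[j] * np.outer(v, v))
print(np.allclose(Q @ np.triu(packed), A))  # True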
+ auto vva = BatchDot(MaybeConjugate(v_broadcast, true), + Select(Lt(j, iota_mn), a, ZerosLike(a)), precision); + vva = BatchDot(v_broadcast, true, vva, false, precision); + a = a - Mul(MaybeConjugate(tau, true), vva, + /*broadcast_dimensions=*/batch_dim_indices); + + // a[j, j] = beta + // a[j+1:,j] = v[j+1:] + auto iota = Reshape(Iota(a.builder(), S32, m), {m, 1}); + auto predecessor_mask = ConvertElementType(Lt(iota, j), type); + auto mask = Broadcast(ConvertElementType(Eq(iota, j), type), + std::vector(batch_dims.size(), 1)); + auto successor_mask = Gt(Iota(a.builder(), S32, m), j); + auto new_x = Mul(x, predecessor_mask, + /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) + + Mul(ConvertElementType(beta, type), mask, + /*broadcast_dimensions=*/batch_dim_indices); + new_x = Add( + new_x, Select(Broadcast(successor_mask, batch_dims), v, ZerosLike(v)), + /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {minor_dim})); + // Update a[:,j] + std::vector dim_ids(num_dims); + std::iota(dim_ids.begin(), dim_ids.end(), 0); + new_x = BroadcastInDim(new_x, ConcatVectors(batch_dims, {m, n}), + /*broadcast_dimensions=*/dim_ids); + a = Select(Eq(iota_mn, j), new_x, a); + + // taus[j] = tau + std::vector tau_broadcast_dims(batch_dims.size()); + std::iota(tau_broadcast_dims.begin(), tau_broadcast_dims.end(), 0); + + auto iota_n = + Iota(builder, ShapeUtil::MakeShape(S32, ConcatVectors(batch_dims, {n})), + minor_dim); + auto taus_zeros = ZerosLike(taus); + auto taus_update = Select( + Eq(iota_n, j), + Add(taus_zeros, tau, /*broadcast_dimensions=*/tau_broadcast_dims), + taus_zeros); + taus = taus + taus_update; + return std::vector{a, taus}; + }; + + auto taus = Zeros( + builder, + ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {std::min(m, n)}))); + + TF_ASSIGN_OR_RETURN(auto values, ForEachIndex(std::min(m, n), S32, qr_body_fn, + {a, taus}, "qr", builder)); + + QrResult result; + result.a = values[0]; + result.taus = values[1]; + return result; +} + +// Computes an upper triangular matrix T such that (I - Y @ T @ Y^t) is a +// product of the elementary Householder reflectors given by `vs` and `taus`. +// +// Schreiber, Robert, and Charles Van Loan. "A storage-efficient WY +// representation for products of Householder transformations." SIAM Journal on +// Scientific and Statistical Computing 10.1 (1989): 53-57. +// +// def compact_wy(vs, taus): +// m, n = vs.shape[-2:] +// t = np.eye(n) * -taus +// # We premultiply Y.T @ vs, since we would prefer to compute a single matrix +// # multiplication to many matrix-vector products. 
+// vtv = -taus[None, :] * np.triu(np.conj(vs.T) @ vs, 1) + np.eye(n) +// for i in range(1, n): +// t[:, i] = scipy.linalg.blas.strmm(t, vtv[:, i]) +// return t +StatusOr QrExpander::CompactWYRepresentation( + PrimitiveType type, absl::Span batch_dims, XlaOp vs, + XlaOp taus, int64 m, int64 n, PrecisionConfig::Precision precision) { + XlaBuilder* builder = vs.builder(); + + std::vector batch_dim_indices(batch_dims.size()); + std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); + int64 n_index = batch_dims.size() + 1; + + auto body_fn = [&](XlaOp j, absl::Span values, + XlaBuilder* builder) -> StatusOr> { + // w has shape [..., m, n] + auto t = values[0]; + const auto vtv = values[1]; + + // yv has shape [..., n, 1] + auto yv = DynamicSliceInMinorDims(vtv, {j}, {1}); + + // z has shape [..., n, 1] + auto z = BatchDot(t, yv, precision); + + t = DynamicUpdateSliceInMinorDims(t, z, {j}); + + return std::vector{t, vtv}; + }; + + auto tau_scale = BroadcastInDim(-taus, ConcatVectors(batch_dims, {1, n}), + ConcatVectors(batch_dim_indices, {n_index})); + + auto eye = Broadcast(IdentityMatrix(builder, type, n, n), batch_dims); + auto t = eye; + + auto vtv = BatchDot(MaybeConjugate(vs, true), /*transpose_x=*/true, vs, + /*transpose_y=*/false, precision); + vtv = Select(TriangleMask(vtv, 0), ZerosLike(vtv), vtv); + vtv = (vtv + eye) * tau_scale; + + TF_ASSIGN_OR_RETURN(auto values, + ForEachIndex(n, S32, body_fn, {t, vtv}, "wy", builder)); + return values[0]; +} + +// Block Householder QR Factorization. Algorithm 5.2.2 of Golub and van Loan. +// def qr_blocked(a, block_size): +// m = a.shape[0] +// n = a.shape[1] +// q = np.eye(m) +// for i in xrange(0, min(m, n), block_size): +// k = min(block_size, min(m, n) - s) +// (a, taus) = qr(a[i:, i:i+k]) +// y = np.eye(m, n) + np.tril(a, -1) +// t = CompactWYRepresentation(vs, taus, m-i, k) +// a[i:, i+k:] += (y @ np.conj(t.T)) @ (np.conj(y.T) @ a[i:, i+k:]) +// q[:, i:] += (q[:, i:] @ y) @ np.conj((y @ np.conj(t.T)).T) +// return (q, a) +StatusOr QrExpander::BuildQrDecomposition( + XlaOp a, int64 block_size, PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + const int num_dims = a_shape.rank(); + if (num_dims < 2) { + return InvalidArgument("Arguments to QR must have rank >= 2: got shape %s", + a_shape.ToString()); + } + PrimitiveType type = a_shape.element_type(); + + const int64 m = ShapeUtil::GetDimension(a_shape, -2); + const int64 n = ShapeUtil::GetDimension(a_shape, -1); + const int64 p = std::min(m, n); + + if (block_size < 1) { + return InvalidArgument("block_size argument to QR must be >= 1; got %d", + block_size); + } + + const int64 num_batch_dims = num_dims - 2; + std::vector batch_dims(num_batch_dims); + for (int i = 0; i < num_batch_dims; ++i) { + batch_dims[i] = ShapeUtil::GetDimension(a_shape, i); + } + + auto q = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims); + for (int64 i = 0; i < p; i += block_size) { + int64 k = std::min(block_size, p - i); + + auto a_block = SliceInMinorDims(a, {i, i}, {m, i + k}); + TF_ASSIGN_OR_RETURN(auto qr_block, QrBlock(a_block, precision)); + auto y = Add( + IdentityMatrix(builder, type, m - i, k), + Select(TriangleMask(qr_block.a, -1), qr_block.a, ZerosLike(qr_block.a)), + /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}); + + a = UpdateSliceInMinorDims(a, qr_block.a, {i, i}); + + // Compute the I + Y @ T @ Y^t block representation of a product of + // Householder matrices. 
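Before the block-reflector application that follows, the compact-WY construction and the blocked driver can be sketched in NumPy, continuing the previous snippets (real-valued only). One deliberate difference: the sketch uses the standard Schreiber-Van Loan sign convention H_0 ... H_{k-1} = I - Y @ T @ Y.T with +tau on the diagonal of T, whereas the code above folds the minus sign into T and therefore applies the trailing update with a `+=`. The names `compact_wy` and `qr_blocked` follow the pseudo-code and are not XLA symbols.

def compact_wy(Y, taus):
    # Forward ("columnwise") accumulation of the block reflector T so that
    # H_0 @ H_1 @ ... @ H_{k-1} == I - Y @ T @ Y.T (Schreiber & Van Loan).
    k = Y.shape[1]
    T = np.zeros((k, k))
    for j in range(k):
        T[:j, j] = -taus[j] * (T[:j, :j] @ (Y[:, :j].T @ Y[:, j]))
        T[j, j] = taus[j]
    return T

def qr_blocked(A, block_size=2):
    # Blocked Householder QR following the qr_blocked pseudo-code above,
    # with the sign convention noted in the lead-in. Returns (Q, R).
    A = np.array(A, dtype=float)
    m, n = A.shape
    Q = np.eye(m)
    for i in range(0, min(m, n), block_size):
        k = min(block_size, min(m, n) - i)
        panel, taus = qr_unblocked(A[i:, i:i + k])
        A[i:, i:i + k] = panel
        # Y holds the unit lower-trapezoidal Householder vectors of the panel.
        Y = np.tril(panel, -1) + np.eye(m - i, k)
        T = compact_wy(Y, taus)
        # Trailing update: A[i:, i+k:] = (I - Y T Y^T)^T @ A[i:, i+k:].
        A[i:, i + k:] -= Y @ (T.T @ (Y.T @ A[i:, i + k:]))
        # Accumulate Q: Q[:, i:] = Q[:, i:] @ (I - Y T Y^T).
        Q[:, i:] -= (Q[:, i:] @ Y) @ (T @ Y.T)
    return Q, np.triu(A)

# Check: Q is orthogonal and Q @ R reproduces the input.
A = np.random.default_rng(1).normal(size=(6, 4))
Q, R = qr_blocked(A, block_size=2)
print(np.allclose(Q @ R, A), np.allclose(Q.T @ Q, np.eye(6)))  # True True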
+ TF_ASSIGN_OR_RETURN( + auto t, CompactWYRepresentation(type, batch_dims, y, qr_block.taus, + m - i, k, precision)); + + // a[i:, i+k:] += (y @ np.conj(t.T)) @ (np.conj(y.T) @ a[i:, i+k:]) + auto yt = BatchDot(y, /*transpose_x=*/false, MaybeConjugate(t, true), + /*transpose_y=*/true, precision); + auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n}); + auto a_update = + BatchDot(MaybeConjugate(y, true), /*transpose_x=*/true, a_panel, + /*transpose_y=*/false, precision); + a_update = BatchDot(yt, a_update, precision); + a_panel = a_panel + a_update; + a = UpdateSliceInMinorDims(a, a_panel, {i, i + k}); + + // q[:, i:] += (q[:, i:] @ y) @ np.conj((y @ np.conj(t.T)).T) + auto q_panel = SliceInMinorDims(q, {0, i}, {m, m}); + auto q_update = BatchDot(q_panel, y, precision); + q_update = + BatchDot(q_update, /*transpose_x=*/false, MaybeConjugate(yt, true), + /*transpose_y=*/true, precision); + q_panel = q_panel + q_update; + q = UpdateSliceInMinorDims(q, q_panel, {0, i}); + } + + return Tuple(builder, {q, UpperTriangle(a)}); +} + +bool QrExpander::InstructionMatchesPattern(HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kCustomCall && + instruction->custom_call_target() == "QrDecomposition"; +} + +StatusOr QrExpander::ExpandInstruction( + HloInstruction* instruction) { + const string name = + absl::StrFormat("xla.qr_%s", instruction->operand(0)->shape().ToString()); + + HloModule* module = instruction->parent()->parent(); + + HloComputation*& computation = + computation_cache_.emplace(name, nullptr).first->second; + if (!computation) { + // Builds a new expansion. + // + // TODO(b/62327888): We do something unusual here: we build the computation + // using the XlaBuilder API, which is nominally an XLA client API. We do + // this because the external APIs for building complicated computations + // (XlaBuilder) are much more ergonomic than the internal ones. As it turns + // out, XlaBuilder isn't really a client API—what it does is build a + // HloModuleProto protocol buffer, that we can then deserialize and clone + // into our HloModule. Ideally we would avoid the protocol buffer step; + // that is left as an exercise for future work. + XlaBuilder builder(name); + XlaOp a = Parameter(&builder, 0, instruction->operand(0)->shape(), "a"); + TF_ASSIGN_OR_RETURN( + XlaOp l, BuildQrDecomposition(a, + /*block_size=*/128, + /*precision=*/PrecisionConfig::HIGHEST)); + + TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build(l)); + + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, + xla_computation.GetProgramShape()); + HloModuleConfig config(program_shape); + TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto( + xla_computation.proto(), config)); + HloCloneContext context(module); + computation = + module->DeepCloneComputation(new_module->entry_computation(), &context); + } + + return instruction->parent()->AddInstruction(HloInstruction::CreateCall( + instruction->shape(), instruction->operands(), computation)); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/qr_expander.h b/tensorflow/compiler/xla/service/qr_expander.h new file mode 100644 index 00000000000..669ace39efb --- /dev/null +++ b/tensorflow/compiler/xla/service/qr_expander.h @@ -0,0 +1,61 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_QR_EXPANDER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_QR_EXPANDER_H_ + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/service/op_expander_pass.h" + +namespace xla { + +class QrExpander : public OpExpanderPass { + public: + absl::string_view name() const override { return "qr_expander"; } + + protected: + bool InstructionMatchesPattern(HloInstruction* instruction) override; + + StatusOr ExpandInstruction( + HloInstruction* instruction) override; + + struct QrResult { + // The upper-triangular matrix R, packed together with the lower-triangular + // elementary Householder reflectors `vs` below the diagonal. + XlaOp a; + + // Representation of the Householder matrices I - beta v v.T + XlaOp taus; // Shape: [..., min(m, n)] + }; + + virtual StatusOr QrBlock(XlaOp a, + PrecisionConfig::Precision precision); + + virtual StatusOr CompactWYRepresentation( + PrimitiveType type, absl::Span batch_dims, XlaOp vs, + XlaOp taus, int64 m, int64 n, PrecisionConfig::Precision precision); + + private: + StatusOr BuildQrDecomposition(XlaOp a, int64 block_size, + PrecisionConfig::Precision precision); + + // Mapping from op signatures to existing computations. + absl::flat_hash_map computation_cache_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_QR_EXPANDER_H_ diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index 473a9ca7456..67c7896cebd 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -31,13 +31,18 @@ limitations under the License. 
namespace xla { -ShapedBuffer::ShapedBuffer(Shape on_host_shape, Shape on_device_shape, - const se::Platform* platform, int device_ordinal) - : on_host_shape_(std::move(on_host_shape)), - on_device_shape_(std::move(on_device_shape)), +ShapedBuffer::ShapedBuffer(Shape on_device_shape, const se::Platform* platform, + int device_ordinal) + : on_device_shape_(std::move(on_device_shape)), platform_(platform), device_ordinal_(device_ordinal), - buffers_(&on_device_shape_) {} + buffers_(&on_device_shape_) { + on_host_shape_ = ShapeUtil::DeviceShapeToHostShape(on_device_shape_); +} + +ShapedBuffer::ShapedBuffer(Shape on_host_shape, Shape on_device_shape, + const se::Platform* platform, int device_ordinal) + : ShapedBuffer(on_device_shape, platform, device_ordinal) {} ShapedBuffer::ShapedBuffer(ShapedBuffer&& s) : on_host_shape_(std::move(s.on_host_shape_)), @@ -52,8 +57,8 @@ ShapedBuffer::ShapedBuffer(ShapedBuffer&& s) } ShapedBuffer& ShapedBuffer::operator=(ShapedBuffer&& s) { - on_host_shape_ = std::move(s.on_host_shape_); on_device_shape_ = std::move(s.on_device_shape_); + on_host_shape_ = std::move(s.on_host_shape_); platform_ = s.platform_; device_ordinal_ = s.device_ordinal_; buffers_ = std::move(s.buffers_); @@ -68,12 +73,9 @@ ShapedBuffer::~ShapedBuffer() {} StatusOr ShapedBuffer::SubShapedBuffer( const ShapeIndex& index) const { - TF_ASSIGN_OR_RETURN(const Shape* host_sub_shape, - ShapeUtil::TryGetSubshape(on_host_shape(), index)); TF_ASSIGN_OR_RETURN(const Shape* device_sub_shape, ShapeUtil::TryGetSubshape(on_device_shape(), index)); - ShapedBuffer sub_shaped_buffer(*host_sub_shape, *device_sub_shape, platform_, - device_ordinal_); + ShapedBuffer sub_shaped_buffer(*device_sub_shape, platform_, device_ordinal_); TF_ASSIGN_OR_RETURN(ShapeTree sub_buffers, buffers_.SubShapeTree(index)); sub_shaped_buffer.set_buffers(std::move(sub_buffers)); @@ -88,12 +90,11 @@ void ShapedBuffer::clear() { } string ShapedBuffer::ToString() const { - string s = absl::StrCat( - "ShapedBuffer(", platform_->Name(), ":", device_ordinal(), - "), on-host shape=" + ShapeUtil::HumanStringWithLayout(on_host_shape()), - ", on-device shape=" + - ShapeUtil::HumanStringWithLayout(on_device_shape()), - ":\n"); + string s = + absl::StrCat("ShapedBuffer(", platform_->Name(), ":", device_ordinal(), + "), on-device shape=" + + ShapeUtil::HumanStringWithLayout(on_device_shape()), + ":\n"); ShapeUtil::ForEachSubshape( on_device_shape(), [this, &s](const Shape& subshape, const ShapeIndex& index) { @@ -116,13 +117,19 @@ std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer) { return out; } +ScopedShapedBuffer::ScopedShapedBuffer(Shape on_device_shape, + se::DeviceMemoryAllocator* allocator, + int device_ordinal) + : ShapedBuffer(std::move(on_device_shape), allocator->platform(), + device_ordinal), + allocator_(allocator) {} + ScopedShapedBuffer::ScopedShapedBuffer(Shape on_host_shape, Shape on_device_shape, se::DeviceMemoryAllocator* allocator, int device_ordinal) - : ShapedBuffer(std::move(on_host_shape), std::move(on_device_shape), - allocator->platform(), device_ordinal), - allocator_(allocator) {} + : ScopedShapedBuffer(std::move(on_device_shape), allocator, + device_ordinal) {} ScopedShapedBuffer::ScopedShapedBuffer(ShapedBuffer shaped_buffer, se::DeviceMemoryAllocator* allocator) @@ -171,13 +178,11 @@ void ScopedShapedBuffer::Deallocate() { } ScopedShapedBuffer ScopedShapedBuffer::TakeSubTree(ShapeIndexView index) { - const xla::Shape& sub_on_host_shape = - xla::ShapeUtil::GetSubshape(on_host_shape(), 
{index}); const xla::Shape& sub_on_device_shape = xla::ShapeUtil::GetSubshape(on_device_shape(), {index}); - ScopedShapedBuffer output(sub_on_host_shape, sub_on_device_shape, - memory_allocator(), device_ordinal()); + ScopedShapedBuffer output(sub_on_device_shape, memory_allocator(), + device_ordinal()); auto src_it = buffers().find(index); auto dst_it = output.buffers().begin(); while (dst_it != output.buffers().end()) { diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h index 995b0ece7cd..7f1248998a6 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.h +++ b/tensorflow/compiler/xla/service/shaped_buffer.h @@ -43,6 +43,10 @@ class ShapedBuffer { // both the on-host and on-device shape are required. The on-device shape // determines the number of device allocations (DeviceMemoryBase) held by the // ShapedBuffer. + ShapedBuffer(Shape on_device_shape, const se::Platform* platform, + int device_ordinal); + + // TODO(b/170310047): remove this overload. ShapedBuffer(Shape on_host_shape, Shape on_device_shape, const se::Platform* platform, int device_ordinal); @@ -97,14 +101,18 @@ class ShapedBuffer { // Reset the shape of this shaped buffer and underlying buffer structure. // // Precondition: EqualStructure(this->on_device_shape_, on_device_shape). - void set_shapes(const Shape& on_host_shape, const Shape& on_device_shape) { + void set_shapes(const Shape& on_device_shape) { CHECK(ShapeUtil::EqualStructure(on_device_shape, on_device_shape_)) << "Structures are not the same. new: " << on_device_shape << ", old: " << on_device_shape_; - on_host_shape_ = on_host_shape; + on_host_shape_ = ShapeUtil::DeviceShapeToHostShape(on_device_shape); on_device_shape_ = on_device_shape; buffers_.replace_shape_ptr(&on_device_shape_); } + // TODO(b/170310047): remove this overload. + void set_shapes(const Shape& on_host_shape, const Shape& on_device_shape) { + set_shapes(on_device_shape); + } // Returns the underlying ShapeTree containing all the device addresses in the // ShapedBuffer. @@ -119,7 +127,6 @@ class ShapedBuffer { string ToString() const; protected: - // The shape of the data when represented on the host. Shape on_host_shape_; // The shape of the data on the device. @@ -148,6 +155,10 @@ std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer); class ScopedShapedBuffer : public ShapedBuffer { public: // Creates a ScopedShapedBuffer with null DeviceMemoryBases at each index. + explicit ScopedShapedBuffer(Shape on_device_shape, + se::DeviceMemoryAllocator* allocator, + int device_ordinal); + // TODO(b/170310047): remove this overload. 
explicit ScopedShapedBuffer(Shape on_host_shape, Shape on_device_shape, se::DeviceMemoryAllocator* allocator, int device_ordinal); diff --git a/tensorflow/compiler/xla/service/shaped_buffer_test.cc b/tensorflow/compiler/xla/service/shaped_buffer_test.cc index a2c208d62e4..49751d10c5a 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer_test.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer_test.cc @@ -97,12 +97,12 @@ class TestAllocator : public se::DeviceMemoryAllocator { TEST(ScopedShapedBufferTest, TestMoveAssignmentOperator) { Shape s = ShapeUtil::MakeShape(F32, {1}); TestAllocator allocator; - ScopedShapedBuffer sb1(s, s, &allocator, /*device_ordinal=*/0); + ScopedShapedBuffer sb1(s, &allocator, /*device_ordinal=*/0); sb1.set_buffer( allocator.Allocate(/*device_ordinal=*/0, /*size=*/42).ValueOrDie(), /*index=*/{}); - ScopedShapedBuffer sb2(s, s, &allocator, /*device_ordinal=*/1); + ScopedShapedBuffer sb2(s, &allocator, /*device_ordinal=*/1); sb2.set_buffer( allocator.Allocate(/*device_ordinal=*/1, /*size=*/10).ValueOrDie(), /*index=*/{}); @@ -119,7 +119,7 @@ TEST(ScopedShapedBufferTest, TestTakeSubTree) { s = xla::ShapeUtil::MakeTupleShape(std::vector(2, s)); s = xla::ShapeUtil::MakeTupleShape(std::vector(3, s)); - ScopedShapedBuffer sb(s, s, &allocator, /*device_ordinal=*/0); + ScopedShapedBuffer sb(s, &allocator, /*device_ordinal=*/0); sb.buffers().ForEachMutableElement( [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) { TF_ASSERT_OK_AND_ASSIGN( @@ -156,8 +156,7 @@ TEST(ScopedShapedBufferTest, TestSubShapeTree) { Shape tuple_shape = xla::ShapeUtil::MakeTupleShape({array_shape, array_shape}); TestAllocator allocator; - ScopedShapedBuffer sb(tuple_shape, tuple_shape, &allocator, - /*device_ordinal=*/0); + ScopedShapedBuffer sb(tuple_shape, &allocator, /*device_ordinal=*/0); sb.buffers().ForEachMutableElement( [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) { TF_ASSERT_OK_AND_ASSIGN( @@ -182,7 +181,7 @@ void BM_TakeSubTree(int iters, int depth, int fan_out) { std::vector shapes(fan_out, shape); shape = xla::ShapeUtil::MakeTupleShape(shapes); } - xla::ScopedShapedBuffer shaped_buffer(shape, shape, /*allocator=*/&allocator, + xla::ScopedShapedBuffer shaped_buffer(shape, /*allocator=*/&allocator, /*device_ordinal=*/0); tensorflow::testing::StartTiming(); for (int i = 0; i < iters; ++i) { diff --git a/tensorflow/compiler/xla/service/sharding_propagation.cc b/tensorflow/compiler/xla/service/sharding_propagation.cc index 7136ce82e25..6524973a08e 100644 --- a/tensorflow/compiler/xla/service/sharding_propagation.cc +++ b/tensorflow/compiler/xla/service/sharding_propagation.cc @@ -246,9 +246,31 @@ bool MaybeImproveInstructionSharding(HloSharding sharding, instruction->set_sharding(std::move(sharding)); return true; } - auto merged = MergeSharding(instruction->sharding(), &sharding, - may_combine_partial_sharding); - if (merged) { + int64 sharding_tiles = sharding.NumTiles(); + if (MergeSharding(instruction->sharding(), &sharding, + may_combine_partial_sharding)) { + // Override existing tiled sharding only when the new sharding is compatible + // with the existing one. This avoids unexpected resharding when `sharding` + // just has more tiles than existing sharding but they are not mergeable. 
+ if (instruction->shape().IsArray() && + !instruction->sharding().IsTileMaximal() && + sharding.NumTiles() == sharding_tiles) { + std::vector diff_dims; + for (int64 i = 0; i < instruction->shape().rank(); ++i) { + if (instruction->sharding().tile_assignment().dim(i) == + sharding.tile_assignment().dim(i)) { + continue; + } + if (instruction->sharding().tile_assignment().dim(i) != 1) { + return false; + } + diff_dims.push_back(i); + } + if (hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + sharding, diff_dims) != instruction->sharding()) { + return false; + } + } instruction->set_sharding(std::move(sharding)); return true; } @@ -476,7 +498,7 @@ bool SupportSpatialPartitioning(const HloInstruction* instruction, bool InferDotShardingFromOperands( HloInstruction* instruction, - const dot_as_convolution_util::DotGeneralAsConvolutionDimsInfo& dnums, + const dot_as_convolution_util::DotConvolutionDimsInfo& dnums, bool may_combine_partial_sharding) { auto from_operand = [&](int64 operand_index) { auto operand = instruction->operand(operand_index); @@ -543,9 +565,41 @@ bool InferDotShardingFromOperands( bool InferConvolutionShardingFromOperands(HloInstruction* instruction, int64 aggressiveness, bool may_combine_partial_sharding) { - if (auto dot_dims = dot_as_convolution_util::ParseDotGeneralFromConvolution( - instruction)) { - return InferDotShardingFromOperands(instruction, *dot_dims, + auto get_partitions_for_dims = + [&](const HloInstruction* inst, + absl::Span< + const dot_as_convolution_util::DotConvolutionDimsInfo::DimNums> + dims, + int lhs_or_rhs) { + int64 partitions = 1; + if (!inst->has_sharding()) { + return partitions; + } + const auto& sharding = inst->sharding(); + if (sharding.IsTileMaximal()) { + return partitions; + } + for (const auto& dim : dims) { + if (lhs_or_rhs == 0) { + partitions *= sharding.tile_assignment().dim(dim.lhs); + } else { + CHECK_EQ(lhs_or_rhs, 1); + partitions *= sharding.tile_assignment().dim(dim.rhs); + } + } + return partitions; + }; + auto dot_dims = + dot_as_convolution_util::ParseConvolutionDimsInfo(instruction); + const int64 lhs_conv_spatial_partitions = get_partitions_for_dims( + instruction->operand(0), dot_dims.conv_spatial_dims, 0); + const int64 rhs_conv_spatial_partitions = get_partitions_for_dims( + instruction->operand(1), dot_dims.conv_spatial_dims, 1); + if (dot_dims.conv_spatial_dims.empty() || + (lhs_conv_spatial_partitions == 1 && rhs_conv_spatial_partitions == 1 && + instruction->batch_group_count() == 1 && + instruction->feature_group_count() == 1)) { + return InferDotShardingFromOperands(instruction, dot_dims, may_combine_partial_sharding); } const auto& dnums = instruction->convolution_dimension_numbers(); @@ -597,6 +651,10 @@ bool CanPropagateThroughAtAgressiveLevel(const HloInstruction& inst, inst.opcode() != HloOpcode::kReshape) { return false; } + // Broadcast propagation should have at least aggressiveness 2. + if (aggressiveness < 2 && inst.opcode() == HloOpcode::kBroadcast) { + return false; + } return true; } @@ -743,14 +801,18 @@ bool InferShardingFromOperands(HloInstruction* instruction, return changed; } case HloOpcode::kBroadcast: { - const HloInstruction* op = instruction->operand(0); - if (!IsSpatiallyPartitioned(op) || op->sharding().IsReplicated()) { + // Make forward propagation through broadcast low priority to avoid + // resharding after broadcast. + if (aggressiveness < 3) { return false; } - // Heuristic: If an operand is more than 8 times fewer elements than its - // output, do not propagate sharding. 
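A standalone sketch of the guard added to MaybeImproveInstructionSharding above: a merged sharding is accepted only if every dimension on which it differs from the existing sharding is one the existing sharding leaves unsplit (tile count 1). The sketch works on plain vectors of per-dimension tile counts and omits the PartiallyReplicateTiledShardingOnDims equivalence check, so it approximates rather than reproduces the HloSharding logic.

#include <cstdint>
#include <iostream>
#include <vector>

// Returns true if `proposed` only splits dimensions that `existing` leaves
// unsplit (tile count 1); on any dimension where both are split, the counts
// must already agree. Mirrors the per-dimension guard in the hunk above,
// minus the partial-replication equivalence check.
bool ProposedRefinesExisting(const std::vector<int64_t>& existing,
                             const std::vector<int64_t>& proposed) {
  if (existing.size() != proposed.size()) return false;
  for (size_t i = 0; i < existing.size(); ++i) {
    if (existing[i] == proposed[i]) continue;
    if (existing[i] != 1) return false;  // would force a reshard
  }
  return true;
}

int main() {
  // Existing sharding splits dim 0 by 2; the proposal additionally splits dim 1.
  std::cout << ProposedRefinesExisting({2, 1}, {2, 2}) << "\n";  // 1 (compatible)
  // Proposal disagrees on a dimension that is already split: rejected.
  std::cout << ProposedRefinesExisting({2, 1}, {4, 1}) << "\n";  // 0
}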
- if (ShapeUtil::ElementsIn(instruction->shape()) > - 8 * ShapeUtil::ElementsIn(op->shape())) { + // Do not override existing tile sharding. This is likely from users. + if (IsSpatiallyPartitioned(instruction) && + !instruction->sharding().IsTileMaximal()) { + return false; + } + const HloInstruction* op = instruction->operand(0); + if (!IsSpatiallyPartitioned(op) || op->sharding().IsReplicated()) { return false; } // The output will be tiled along the broadcasted dimension the same way @@ -1031,7 +1093,7 @@ bool InferShardingFromOperands(HloInstruction* instruction, HloSharding InferDotOperandSharding( const HloInstruction* instruction, - const dot_as_convolution_util::DotGeneralAsConvolutionDimsInfo& dnums, + const dot_as_convolution_util::DotConvolutionDimsInfo& dnums, int64 operand_index, bool may_combine_partial_sharding) { auto operand = instruction->operand(operand_index); auto other = instruction->operand(1 - operand_index); @@ -1185,10 +1247,10 @@ absl::optional GetShardingFromUser( return HloSharding::Tile(new_tile_assignment); } case HloOpcode::kConvolution: { - if (auto dot_dims = - dot_as_convolution_util::ParseDotGeneralFromConvolution(&user)) { + auto dot_dims = dot_as_convolution_util::ParseConvolutionDimsInfo(&user); + if (dot_dims.conv_spatial_dims.empty()) { int64 op_idx = user.operand_index(&instruction); - return InferDotOperandSharding(&user, *dot_dims, op_idx, + return InferDotOperandSharding(&user, dot_dims, op_idx, may_combine_partial_sharding); } return absl::nullopt; @@ -1376,6 +1438,9 @@ absl::optional GetShardingFromUser( bool InferShardingFromUsers(HloInstruction* instruction, const ComputationMap& computation_map, int64 aggressiveness, bool is_spmd) { + if (aggressiveness < 2 && instruction->opcode() == HloOpcode::kBroadcast) { + return false; + } if (!SupportSpatialPartitioning(instruction, computation_map, is_spmd)) { return false; } @@ -1657,17 +1722,11 @@ StatusOr ShardingPropagation::Run(HloModule* module) { // indefinitely. int64 iterations = 0; auto run_to_fix_point = [&](int64 aggressiveness) { - absl::flat_hash_set workset; - for (const HloComputation* computation : module->computations()) { - for (const HloInstruction* instruction : computation->instructions()) { - // Remove the instructions where the sharding was provided from the - // outside so we don't modify them. - if (!provided_shardings.contains(instruction)) { - workset.insert(instruction); - } - } - } - while (!workset.empty()) { + absl::flat_hash_set already_inferred_from_operands; + absl::flat_hash_set already_inferred_from_users; + bool changed_last_iter = true; + while (changed_last_iter) { + changed_last_iter = false; int64 inferred_from_operand_counter = 0; int64 inferred_from_user_counter = 0; int64 instruction_counter = 0; @@ -1680,17 +1739,14 @@ StatusOr ShardingPropagation::Run(HloModule* module) { for (const HloInstruction* instruction : instructions) { already_sharded_counter += (instruction->has_sharding() ? 1 : 0); } - - instructions.erase( - std::remove_if(instructions.begin(), instructions.end(), - [&](HloInstruction* instruction) { - return !workset.contains(instruction); - }), - instructions.end()); - // First iterate the HLO graph in post order taking shardings from // operands. 
for (HloInstruction* instruction : instructions) { + if (already_inferred_from_operands.contains(instruction) || + provided_shardings.contains(instruction)) { + continue; + } + already_inferred_from_operands.insert(instruction); if (InferShardingFromOperands(instruction, computation_map, is_spmd_, aggressiveness)) { ++inferred_from_operand_counter; @@ -1698,31 +1754,37 @@ StatusOr ShardingPropagation::Run(HloModule* module) { VLOG(2) << "Add sharding (forward-pass): " << instruction->ToString(); maybe_computation_propagation(instruction); - for (auto user : instruction->users()) { - if (!provided_shardings.contains(user)) { - workset.insert(user); - } + for (auto operand : instruction->operands()) { + already_inferred_from_users.erase(operand); } - } else { - workset.erase(instruction); + for (auto user : instruction->users()) { + already_inferred_from_operands.erase(user); + } + changed_last_iter = true; } } // Then iterate the HLO graph in reverse post order taking shardings // from users. for (auto it = instructions.rbegin(); it != instructions.rend(); ++it) { + if (already_inferred_from_users.contains(*it) || + provided_shardings.contains(*it)) { + continue; + } + already_inferred_from_users.insert(*it); if (InferShardingFromUsers(*it, computation_map, aggressiveness, is_spmd_)) { ++inferred_from_user_counter; any_changed = true; VLOG(2) << "Add sharding (backward-pass): " << (*it)->ToString(); maybe_computation_propagation(*it); - workset.insert(*it); for (auto operand : (*it)->operands()) { - if (!provided_shardings.contains(operand)) { - workset.insert(operand); - } + already_inferred_from_users.erase(operand); } + for (auto user : (*it)->users()) { + already_inferred_from_operands.erase(user); + } + changed_last_iter = true; } } } @@ -1733,11 +1795,13 @@ StatusOr ShardingPropagation::Run(HloModule* module) { << inferred_from_operand_counter; VLOG(1) << " shardings inferred from users: " << inferred_from_user_counter; + VLOG(1) << " aggressiveness: " << aggressiveness; ++iterations; } }; - run_to_fix_point(0); - run_to_fix_point(1); + for (int64 aggressiveness = 0; aggressiveness < 4; ++aggressiveness) { + run_to_fix_point(aggressiveness); + } VLOG(1) << "Sharding propagation completed after " << iterations << " iterations"; diff --git a/tensorflow/compiler/xla/service/sharding_propagation_test.cc b/tensorflow/compiler/xla/service/sharding_propagation_test.cc index 03c77c2038c..8c4d8fc24ff 100644 --- a/tensorflow/compiler/xla/service/sharding_propagation_test.cc +++ b/tensorflow/compiler/xla/service/sharding_propagation_test.cc @@ -65,22 +65,6 @@ ENTRY %elementwise { op::Sharding("{devices=[1,2,2,1]0,1,2,3}")); } -TEST_F(ShardingPropagationTest, BroadcastForwardPassNoSharding) { - const char* const hlo_string = R"( -HloModule module -ENTRY %broadcast { - %param0 = f32[7,11]{1,0} parameter(0), - sharding={devices=[2,2]0,1,2,3} - %broadcast = f32[5,7,11,13]{3,2,1,0} broadcast(%param0), dimensions={1,2} - ROOT %copy = f32[5,7,11,13]{3,2,1,0} copy(%broadcast) -})"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, - ShardingPropagation().Run(module.get())); - EXPECT_FALSE(changed); -} - // Regression Test for b/129569657. 
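A self-contained sketch of the fixed-point bookkeeping introduced in run_to_fix_point above: one already-inferred set per direction, invalidated for an instruction's neighbours whenever its sharding changes, with the outer loop repeating until an iteration makes no change. The Node/graph types and the take-any-neighbour's-label rule are invented stand-ins for HloInstruction and the real inference functions.

#include <iostream>
#include <optional>
#include <unordered_set>
#include <vector>

// Toy stand-in for an instruction: a node that may carry a "sharding" label.
struct Node {
  std::optional<int> label;
  std::vector<int> operands;  // edges to producers
  std::vector<int> users;     // edges to consumers
};

// Propagate labels to a fixed point, mirroring the bookkeeping in the hunk
// above: each direction keeps an "already inferred" set, and a change at a
// node re-enables inference for its neighbours in both directions.
void RunToFixPoint(std::vector<Node>& nodes) {
  std::unordered_set<int> done_from_operands, done_from_users;
  bool changed_last_iter = true;
  while (changed_last_iter) {
    changed_last_iter = false;
    // Forward sweep: take a label from any labeled operand.
    for (int i = 0; i < static_cast<int>(nodes.size()); ++i) {
      if (done_from_operands.count(i)) continue;
      done_from_operands.insert(i);
      if (nodes[i].label) continue;
      for (int op : nodes[i].operands) {
        if (nodes[op].label) {
          nodes[i].label = nodes[op].label;
          for (int u : nodes[i].users) done_from_operands.erase(u);
          for (int op2 : nodes[i].operands) done_from_users.erase(op2);
          changed_last_iter = true;
          break;
        }
      }
    }
    // Backward sweep: take a label from any labeled user.
    for (int i = static_cast<int>(nodes.size()) - 1; i >= 0; --i) {
      if (done_from_users.count(i)) continue;
      done_from_users.insert(i);
      if (nodes[i].label) continue;
      for (int u : nodes[i].users) {
        if (nodes[u].label) {
          nodes[i].label = nodes[u].label;
          for (int op : nodes[i].operands) done_from_users.erase(op);
          for (int u2 : nodes[i].users) done_from_operands.erase(u2);
          changed_last_iter = true;
          break;
        }
      }
    }
  }
}

int main() {
  // Chain 0 -> 1 -> 2 with a label only on node 2; backward passes fill 1, then 0.
  std::vector<Node> g(3);
  g[0].users = {1}; g[1].operands = {0}; g[1].users = {2}; g[2].operands = {1};
  g[2].label = 7;
  RunToFixPoint(g);
  std::cout << *g[0].label << " " << *g[1].label << " " << *g[2].label << "\n";  // 7 7 7
}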
TEST_F(ShardingPropagationTest, BroadcastForwardPass) { const char* const hlo_string = R"( @@ -530,6 +514,26 @@ ENTRY %pad { op::Sharding("{devices=[2,2]0,1,2,3}")); } +TEST_F(ShardingPropagationTest, PartialReplicatedPadForwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %pad { + %input = f32[11,17]{1,0} parameter(0), + sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + %pad_value = f32[] parameter(1) + %pad = f32[27,51]{1,0} pad(%input, %pad_value), padding=2_4_1x1_1_2 + ROOT %copy = f32[27,51]{1,0} copy(%pad) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT( + FindInstruction(module.get(), "pad"), + op::Sharding("{devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, ShardedPreferredOverReplicated) { const char* const hlo_string = R"( HloModule module @@ -653,6 +657,25 @@ ENTRY %slice { op::Sharding("{devices=[2,1]0,1}")); } +TEST_F(ShardingPropagationTest, PartialReplicatedStridedSlice) { + const char* const hlo_string = R"( +HloModule module + +ENTRY %slice { + %param = f32[17,13]{1,0} parameter(0), + sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate} + %slice = f32[7,5]{1,0} slice(%param), slice={[1:15:2], [5:10:1]} + ROOT %tuple = (f32[7,5]{1,0}) tuple(%slice) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "slice"), + op::Sharding("{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, ReduceWindowBackwardPass) { const char* const hlo_string = R"( HloModule module @@ -1263,6 +1286,38 @@ ENTRY %conv { op::Sharding("{replicated}")); } +TEST_F(ShardingPropagationTest, + ConvolutionFilterIFOFPartitionedInputPartialReplicate) { + const char* const hlo_string = R"( + HloModule module + +ENTRY entry { + %lhs = f32[128,112,112,12] parameter(0) + %lhs.copy = f32[128,112,112,12] copy(f32[128,112,112,12] %lhs), + sharding={devices=[1,1,1,2,2]0,1,2,3 last_tile_dim_replicate} + %rhs = f32[7,7,12,64] parameter(1) + %rhs.copy = f32[7,7,12,64] copy(f32[7,7,12,64] %rhs), + sharding={devices=[1,1,2,2]0,1,2,3} + %conv = f32[128,56,56,64] convolution( + f32[128,112,112,12] %lhs.copy, + f32[7,7,12,64] %rhs.copy), + window={size=7x7 stride=2x2 pad=3_3x3_3}, + dim_labels=b01f_01io->b01f + ROOT %copy = f32[128,56,56,64] copy(conv) +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + VLOG(1) << module->ToString(); + + EXPECT_TRUE(changed); + EXPECT_THAT( + FindInstruction(module.get(), "conv"), + op::Sharding("{devices=[1,1,1,2,2]0,2,1,3 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, ConcatFromUserUnshardedDim) { const char* const hlo_string = R"( HloModule module @@ -1408,11 +1463,11 @@ ENTRY entry { ShardingPropagation().Run(module.get())); EXPECT_TRUE(changed); EXPECT_THAT(FindInstruction(module.get(), "tp"), - op::Sharding("{{devices=[3,1]0,1,2}}")); + op::Sharding("{{devices=[1,2]0,1}}")); EXPECT_THAT(FindInstruction(module.get(), "tgte"), - op::Sharding("{devices=[3,1]0,1,2}")); + op::Sharding("{devices=[1,2]0,1}")); EXPECT_THAT(FindInstruction(module.get(), "ttr"), - 
op::Sharding("{devices=[1,3]0,1,2}")); + op::Sharding("{devices=[2,1]0,1}")); EXPECT_THAT(FindInstruction(module.get(), "tr"), op::Sharding("{{devices=[1,3]0,1,2}}")); EXPECT_THAT(FindInstruction(module.get(), "fp"), @@ -1774,6 +1829,28 @@ ENTRY entry { op::Sharding("{devices=[2,1]0,1}")); } +TEST_F(ShardingPropagationTest, GatherFromIndex_PartialReplicate) { + const char* hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[2,9] parameter(0), sharding={replicated} + %indices = s32[3] parameter(1), + sharding={devices=[2,2]0,1,2,3 last_tile_dim_replicate} + %gather = f32[3,9] gather(%input, %indices), offset_dims={1}, + collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, + slice_sizes={1,9} + ROOT %copy = f32[3,9] copy(%gather) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "gather"), + op::Sharding("{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, GatherFromDataOperand) { const char* hlo_string = R"( HloModule module @@ -1795,6 +1872,28 @@ ENTRY entry { op::Sharding("{devices=[1,2]0,1}")); } +TEST_F(ShardingPropagationTest, GatherFromDataOperand_PartialReplicate) { + const char* hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[2,9] parameter(0), + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} + %indices = s32[3] parameter(1), sharding={replicated} + %gather = f32[3,9] gather(%input, %indices), offset_dims={1}, + collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, + slice_sizes={1,9} + ROOT %copy = f32[3,9] copy(%gather) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "gather"), + op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, GatherToIndex) { const char* hlo_string = R"( HloModule module @@ -1816,6 +1915,28 @@ ENTRY entry { op::Sharding("{devices=[2]0,1}")); } +TEST_F(ShardingPropagationTest, GatherToIndex_PartialReplicate) { + const char* hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[2,9] parameter(0), sharding={replicated} + %p1 = s32[3] parameter(1) + %indices = s32[3] copy(%p1) + ROOT %gather = f32[3,9] gather(%input, %indices), offset_dims={1}, + collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, + slice_sizes={1,9}, + sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "indices"), + op::Sharding("{devices=[2,2]0,1,2,3 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, GatherToIndex2) { const char* hlo_string = R"( HloModule module @@ -1839,6 +1960,30 @@ ENTRY entry { op::Sharding("{devices=[1,2,1]0,1}")); } +TEST_F(ShardingPropagationTest, GatherToIndex2_PartialReplicate) { + const char* hlo_string = R"( +HloModule module + +ENTRY entry { + %input = bf16[2,4819,4] parameter(0), sharding={replicated} + %p1 = s32[2,1000,2] parameter(1) + %indices = s32[2,1000,2] copy(%p1) + ROOT %gather = bf16[2,1000,4] + 
gather(bf16[2,4819,4] %input, s32[2,1000,2] %indices), + offset_dims={2}, collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=2, slice_sizes={1,1,4}, + sharding={devices=[1,2,1,2]0,1,2,3 last_tile_dim_replicate} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT( + FindInstruction(module.get(), "indices"), + op::Sharding("{devices=[1,2,1,2]0,1,2,3 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, GatherToIndex3) { const char* hlo_string = R"( HloModule module @@ -1883,6 +2028,27 @@ ENTRY entry { op::Sharding("{devices=[1,2]0,1}")); } +TEST_F(ShardingPropagationTest, GatherToDataOperand_PartialReplicate) { + const char* hlo_string = R"( +HloModule module + +ENTRY entry { + %p0 = f32[2,9] parameter(0) + %input = f32[2,9] copy(%p0) + %indices = s32[3] parameter(1), sharding={replicated} + ROOT %gather = f32[3,9] gather(%input, %indices), offset_dims={1}, + collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, + slice_sizes={1,9}, sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "input"), + op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, DataOperandToScatter) { const char* const hlo_string = R"( HloModule module @@ -1914,6 +2080,38 @@ ENTRY entry { op::Sharding("{devices=[1,2]0,1}")); } +TEST_F(ShardingPropagationTest, DataOperandToScatter_PartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[2,9] parameter(0), + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} + %indices = s32[3] parameter(1), sharding={replicated} + %updates = f32[3,9] parameter(2), sharding={replicated} + %scatter = f32[2,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 + ROOT %copy = f32[2,9] copy(%scatter) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "scatter"), + op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, UpdateOperandToScatter) { const char* const hlo_string = R"( HloModule module @@ -1945,6 +2143,70 @@ ENTRY entry { op::Sharding("{devices=[1,2]0,1}")); } +TEST_F(ShardingPropagationTest, UpdateOperandToScatter_PartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[2,9] parameter(0), sharding={replicated} + %indices = s32[3] parameter(1), sharding={replicated} + %updates = f32[3,9] parameter(2), + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} + %scatter = f32[2,9] scatter(%input, %indices, %updates), + to_apply=add, + 
update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 + ROOT %copy = f32[2,9] copy(%scatter) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "scatter"), + op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}")); +} + +TEST_F(ShardingPropagationTest, ScatterToDataOperand_PartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %p0 = f32[2,9] parameter(0) + %input = f32[2,9] copy(%p0) + %indices = s32[3] parameter(1), sharding={replicated} + %updates = f32[3,9] parameter(2), sharding={replicated} + ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "input"), + op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, ScatterToDataOperand) { const char* const hlo_string = R"( HloModule module @@ -1976,6 +2238,38 @@ ENTRY entry { op::Sharding("{devices=[1,2]0,1}")); } +TEST_F(ShardingPropagationTest, ScatterToUpdateOperand_PartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[2,9] parameter(0) + %indices = s32[3] parameter(1), sharding={replicated} + %p2 = f32[3,9] parameter(2) + %updates = f32[3,9] copy(%p2) + ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "updates"), + op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, ScatterToUpdateOperand) { const char* const hlo_string = R"( HloModule module @@ -2038,6 +2332,38 @@ ENTRY entry { op::Sharding("{devices=[2]0,1}")); } +TEST_F(ShardingPropagationTest, ScatterUpdateToIndex_PartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[2,9] parameter(0), sharding={replicated} + %p1 = s32[3] parameter(1), sharding={replicated} + %indices = s32[3] copy(%p1) + %updates = f32[3,9] parameter(2), + sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate} + ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={1}, + 
inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, sharding={replicated} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "indices"), + op::Sharding("{devices=[2,2]0,1,2,3 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, ScatterIndexToUpdate) { const char* const hlo_string = R"( HloModule module @@ -2069,6 +2395,38 @@ ENTRY entry { op::Sharding("{devices=[2,1]0,1}")); } +TEST_F(ShardingPropagationTest, ScatterIndexToUpdate_PartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[2,9] parameter(0), sharding={replicated} + %indices = s32[3] parameter(1), + sharding={devices=[2,2]0,1,2,3 last_tile_dim_replicate} + %p2 = f32[3,9] parameter(2), sharding={replicated} + %updates = f32[3,9] copy(%p2) + ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, sharding={replicated} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "updates"), + op::Sharding("{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, PartialShardingOnElementwise) { const char* const hlo_string = R"( HloModule module diff --git a/tensorflow/compiler/xla/service/space_to_batch_converter.cc b/tensorflow/compiler/xla/service/space_to_batch_converter.cc new file mode 100644 index 00000000000..47aee8ed5a8 --- /dev/null +++ b/tensorflow/compiler/xla/service/space_to_batch_converter.cc @@ -0,0 +1,478 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/xla/service/space_to_batch_converter.h" + +#include + +#include "absl/memory/memory.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/bitmap.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +namespace { + +// ConvolutionVisitor traverses the HLO computation and rewrites Convolution +// operations with small batch counts into convolutions with larger batch +// counts by moving space to batch. +class ConvolutionVisitor : public DfsHloVisitorWithDefault { + public: + // Default visitor action is to do nothing and return OK. + Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { + return Status::OK(); + } + + Status HandleConvolution(HloInstruction* convolution) override; + + // Runs the visitor on a computation. + static bool Run(int64 limit_on_batch_size, HloComputation* computation); + + // Returns whether any convolution ops were rewritten. + const bool changed() const { return changed_; } + + ~ConvolutionVisitor() override = default; + + private: + explicit ConvolutionVisitor(int64 limit_on_batch_size, + HloComputation* computation) + : computation_(computation), limit_on_batch_size_(limit_on_batch_size) {} + + // Current HloComputation instance the ConvolutionVisitor is traversing. + HloComputation* computation_; + + // Whether rewrite has occurred. + bool changed_ = false; + + // Limit on batch size to apply this technique on. + int64 limit_on_batch_size_; +}; + +bool ConvolutionVisitor::Run(int64 limit_on_batch_size, + HloComputation* computation) { + ConvolutionVisitor visitor(limit_on_batch_size, computation); + TF_CHECK_OK(computation->Accept(&visitor)); + return visitor.changed_; +} + +Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { + VLOG(1) << "Handling conv " << convolution->ToString(); + changed_ = false; + + ConvolutionDimensionNumbers dim_numbers = + convolution->convolution_dimension_numbers(); + + // If there are no spatial dims, we return. + if (dim_numbers.input_spatial_dimensions_size() < 1) { + return Status::OK(); + } + + // This is the spatial dimension we choose to spilt. + constexpr int64 kChosenSpatialDim = 0; + constexpr int64 kLowLimitForSplitCount = 4; + constexpr int64 kHighLimitForSplitCount = 24; + + // Batch in batch_group_count has different semantics (it isn't true batch). + // Consider supporting this case in future if needed. + if (convolution->batch_group_count() != 1) { + return Status::OK(); + } + + if (convolution->window().dimensions(kChosenSpatialDim).window_dilation() != + 1) { + return Status::OK(); + } + + // TODO(b/168316428): Support base dilations. 
+ if (convolution->window().dimensions(kChosenSpatialDim).base_dilation() != + 1) { + return Status::OK(); + } + + int64 activations_batch_dim = dim_numbers.input_batch_dimension(); + + const int64 old_batch_size = + convolution->operand(0)->shape().dimensions(activations_batch_dim); + + if (old_batch_size > limit_on_batch_size_) { + return Status::OK(); + } + + auto kernel = convolution->mutable_operand(1); + const auto& kernel_shape = kernel->shape(); + const int64 kernel_spatial_dim_size = kernel_shape.dimensions( + dim_numbers.kernel_spatial_dimensions(kChosenSpatialDim)); + + auto activations = convolution->mutable_operand(0); + + int64 spatial_dimension_to_split = + dim_numbers.input_spatial_dimensions(kChosenSpatialDim); + + const int64 input_dim_size = activations->shape().dimensions( + dim_numbers.input_spatial_dimensions(kChosenSpatialDim)); + + const int64 inherent_low_padding = + convolution->window().dimensions(kChosenSpatialDim).padding_low(); + const int64 inherent_high_padding = + convolution->window().dimensions(kChosenSpatialDim).padding_high(); + const bool inherent_padding_needed = + inherent_low_padding != 0 || inherent_high_padding != 0; + + const int64 stride = + convolution->window().dimensions(kChosenSpatialDim).stride(); + + const int64 spatial_size = + input_dim_size + inherent_low_padding + inherent_high_padding; + VLOG(1) << "spatial size " << spatial_size; + + int64 min_pad_size = INT64_MAX; + int64 num_splits; + // Explore several splitting points; choose one that requires least padding. + // This padding is done so that we can evenly reshape. + for (int64 j = kHighLimitForSplitCount; j >= kLowLimitForSplitCount; j--) { + if (input_dim_size / j < kernel_spatial_dim_size) { + continue; + } + + if (spatial_size < j) { + continue; + } + + const int64 output_offsets = convolution->shape().dimensions( + dim_numbers.output_spatial_dimensions(kChosenSpatialDim)); + const int64 output_offsets_per_split = CeilOfRatio(output_offsets, j); + + const int64 spatial_split_size = output_offsets_per_split * stride; + + // Pad spatial dim + const int64 pad_size = spatial_split_size * j - spatial_size; + if (pad_size >= 0 && pad_size < min_pad_size) { + min_pad_size = pad_size; + num_splits = j; + } + } + + // No suitable split found. + if (min_pad_size == INT64_MAX) { + return Status::OK(); + } + + // By now, we are certain that the space-to-batch transormation is going to + // take place. + + // Create the new convolution dim numbers. + auto new_dim_numbers = dim_numbers; + + // We'd need transposition of activations here such that batch and space dim + // that is being split are adjacent (in that order). 
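A standalone sketch of the split-count search above: candidate counts are scanned from the high limit down to the low limit, and the one needing the least padding to make the padded spatial size divide evenly is kept. The numeric inputs in main are made-up examples, not values taken from the pass.

#include <cstdint>
#include <iostream>
#include <limits>

int64_t CeilOfRatio(int64_t a, int64_t b) { return (a + b - 1) / b; }

// Picks the split count in [lo, hi] that minimizes the padding required so
// that spatial_split_size * splits covers the padded spatial size, mirroring
// the search in HandleConvolution above. Returns -1 if no candidate fits.
int64_t ChooseSplitCount(int64_t input_dim_size, int64_t spatial_size,
                         int64_t kernel_spatial_dim_size, int64_t stride,
                         int64_t output_offsets, int64_t lo, int64_t hi) {
  int64_t min_pad_size = std::numeric_limits<int64_t>::max();
  int64_t num_splits = -1;
  for (int64_t j = hi; j >= lo; --j) {
    if (input_dim_size / j < kernel_spatial_dim_size) continue;
    if (spatial_size < j) continue;
    const int64_t output_offsets_per_split = CeilOfRatio(output_offsets, j);
    const int64_t spatial_split_size = output_offsets_per_split * stride;
    const int64_t pad_size = spatial_split_size * j - spatial_size;
    if (pad_size >= 0 && pad_size < min_pad_size) {
      min_pad_size = pad_size;
      num_splits = j;
    }
  }
  return num_splits;
}

int main() {
  // Example: 256 output offsets, stride 1, 3-wide kernel on a 258-wide input
  // with no inherent padding (spatial_size == input_dim_size == 258).
  std::cout << ChooseSplitCount(/*input_dim_size=*/258, /*spatial_size=*/258,
                                /*kernel_spatial_dim_size=*/3, /*stride=*/1,
                                /*output_offsets=*/256, /*lo=*/4, /*hi=*/24)
            << "\n";
}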
+ if (spatial_dimension_to_split != activations_batch_dim + 1) { + int64 pushed_counter = 0; + std::vector transpose_dims; + int64 new_batch_dim, new_spatial_dim; + for (int i = 0; i < activations->shape().rank(); ++i) { + if (i == activations_batch_dim) { + continue; + } + if (i == spatial_dimension_to_split) { + new_dim_numbers.set_input_batch_dimension(pushed_counter); + transpose_dims.push_back(activations_batch_dim); + new_batch_dim = pushed_counter; + pushed_counter++; + new_spatial_dim = pushed_counter; + } + + if (i == dim_numbers.input_feature_dimension()) { + new_dim_numbers.set_input_feature_dimension(pushed_counter); + } else { + for (int j = 0; j < dim_numbers.input_spatial_dimensions_size(); ++j) { + if (i == dim_numbers.input_spatial_dimensions(j)) { + new_dim_numbers.set_input_spatial_dimensions(j, pushed_counter); + break; + } + } + } + transpose_dims.push_back(i); + pushed_counter++; + } + + activations_batch_dim = new_batch_dim; + spatial_dimension_to_split = new_spatial_dim; + TF_ASSIGN_OR_RETURN(activations, + MakeTransposeHlo(activations, transpose_dims)); + } + + const int64 output_offsets = convolution->shape().dimensions( + dim_numbers.output_spatial_dimensions(kChosenSpatialDim)); + const int64 output_offsets_per_split = + CeilOfRatio(output_offsets, num_splits); + + const int64 spatial_split_size = output_offsets_per_split * stride; + const int64 slice_size = + (output_offsets_per_split - 1) * stride + kernel_spatial_dim_size; + + VLOG(1) << "spatial_split_size " << spatial_split_size << " stride " + << stride; + + // Pad spatial dim. + const int64 pad_size = spatial_split_size * num_splits - spatial_size; + + VLOG(1) << "spatial_dimension_to_split " << spatial_dimension_to_split + << " num_splits " << num_splits << " kernel_spatial_dim_size " + << kernel_spatial_dim_size; + + // Because we are splitting the spatial dimension, if convolution needed + // padding in the spatial dimension, we materialize it. + if (pad_size != 0 || inherent_padding_needed) { + PaddingConfig padding_config = + MakeNoPaddingConfig(activations->shape().dimensions_size()); + padding_config.mutable_dimensions(spatial_dimension_to_split) + ->set_edge_padding_high(inherent_high_padding + pad_size); + padding_config.mutable_dimensions(spatial_dimension_to_split) + ->set_edge_padding_low(inherent_low_padding); + HloInstruction* padding = + computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(activations->shape().element_type()))); + TF_ASSIGN_OR_RETURN(activations, + MakePadHlo(activations, padding, padding_config)); + } + VLOG(1) << "Initial padded activations shape " + << activations->shape().ToString(); + + // Now we reorganize the activations. E.g. if the shape [B, SPACE] was [1, 16] + // and 4 splits were needed, we first create [4, 4]. Next, to deal with halo + // in the spatial dimension, we first pad that dimension. E.g. if halo size + // was 2, we'd create a shape of [4, 6]. We then flatten the shape such that + // A = [1, 24]. Now, we rotate the flattened 24 dimension left by 2 (with + // -2 low padding and +2 high padding) to create shape B. Then, we select + // between A and B such that halo regions are placed into A at the right + // locations. + + // The benefit of the above mentioned scheme is that it allows for batch + // growth. Here are some examples of the size increases it causes for a 3x3 + // kernel. + // with batch=1, [1,16] -> [4,4] -> [4,6] -> [1,24] growth of 8. + // with batch=2, [2,16] -> [8,4] -> [8,6] -> [1,48] growth of 16. 
+ // with batch=3, [3,16] -> [12,4] -> [12,6] -> [1,72] growth of 24. + + std::vector reshape_dimensions( + activations->shape().dimensions().begin(), + activations->shape().dimensions().end()); + + reshape_dimensions[spatial_dimension_to_split] = spatial_split_size; + reshape_dimensions[activations_batch_dim] = num_splits * old_batch_size; + + TF_ASSIGN_OR_RETURN(HloInstruction * batch_increased_reshape, + MakeReshapeHlo(reshape_dimensions, activations)); + convolution->SetupDerivedInstruction(batch_increased_reshape); + + VLOG(1) << "First reshape done " << batch_increased_reshape->ToString(); + + PaddingConfig padding_config = + MakeNoPaddingConfig(batch_increased_reshape->shape().dimensions_size()); + padding_config.mutable_dimensions(spatial_dimension_to_split) + ->set_edge_padding_high(slice_size - spatial_split_size); + HloInstruction* padding = + computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(batch_increased_reshape->shape().element_type()))); + TF_ASSIGN_OR_RETURN( + HloInstruction * pad_applied, + MakePadHlo(batch_increased_reshape, padding, padding_config)); + + VLOG(1) << "Padding done " << pad_applied->ToString(); + + auto straightened_activations_dims = reshape_dimensions; + straightened_activations_dims[spatial_dimension_to_split] = + num_splits * slice_size; + straightened_activations_dims[activations_batch_dim] = old_batch_size; + + VLOG(1) << "slice_size " << slice_size; + TF_ASSIGN_OR_RETURN( + HloInstruction * straightened_activations, + MakeReshapeHlo(straightened_activations_dims, pad_applied)); + + VLOG(1) << "Straightening done"; + + PaddingConfig rotation_padding_config = + MakeNoPaddingConfig(straightened_activations->shape().dimensions_size()); + rotation_padding_config.mutable_dimensions(spatial_dimension_to_split) + ->set_edge_padding_high(slice_size - spatial_split_size); + rotation_padding_config.mutable_dimensions(spatial_dimension_to_split) + ->set_edge_padding_low(spatial_split_size - slice_size); + HloInstruction* rotation_padding = + computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(straightened_activations->shape().element_type()))); + TF_ASSIGN_OR_RETURN(HloInstruction * rotated_activations, + MakePadHlo(straightened_activations, rotation_padding, + rotation_padding_config)); + convolution->SetupDerivedInstruction(rotated_activations); + + // Build a constant PRED to decide which elements in the split dimension + // are from halo. + tensorflow::core::Bitmap b(num_splits * slice_size); + for (int k = 0; k < num_splits * slice_size; ++k) { + if (k % slice_size < spatial_split_size) { + b.set(k); + } else { + b.clear(k); + } + } + + auto arg_literal = LiteralUtil::CreateR1(b); + HloInstruction* slice_mask = computation_->AddInstruction( + HloInstruction::CreateConstant(std::move(arg_literal))); + + // Broadcast the mask in all dimensions of the activations. + HloInstruction* shape_mask = + MakeBroadcastHlo(slice_mask, {spatial_dimension_to_split}, + straightened_activations->shape().dimensions()); + + VLOG(1) << "Shape mask made " << shape_mask->ToString(); + + TF_ASSIGN_OR_RETURN(HloInstruction * select, + MakeSelectHlo(shape_mask, straightened_activations, + rotated_activations, convolution)); + VLOG(1) << "Select generated" << select->ToString(); + + // Increase batch size for one last time. 
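The comments above describe the halo materialization: reshape [B, SPACE] into splits, pad each split up to slice_size, flatten, build a rotated copy, and select between the two so each split ends with the halo borrowed from its right neighbour. The sketch below replays that sequence on a plain 1-D vector (16 values, 4 splits, halo 2, so slice_size = 6); it is illustrative arithmetic only, not the HLO the pass emits.

#include <iostream>
#include <vector>

int main() {
  const int num_splits = 4, split_size = 4, slice_size = 6;
  std::vector<int> space(16);
  for (int i = 0; i < 16; ++i) space[i] = i;

  // Reshape to [4, 4], pad each split to slice_size with zeros, flatten: tensor A.
  std::vector<int> a(num_splits * slice_size, 0);
  for (int s = 0; s < num_splits; ++s)
    for (int i = 0; i < split_size; ++i)
      a[s * slice_size + i] = space[s * split_size + i];

  // Rotated copy B: shift A left by the halo size (slice_size - split_size),
  // i.e. negative low padding and positive high padding of the same amount.
  const int halo = slice_size - split_size;
  std::vector<int> b(a.size(), 0);
  for (int k = 0; k + halo < static_cast<int>(a.size()); ++k) b[k] = a[k + halo];

  // Select: keep A inside each split, take B (the next split's prefix, i.e.
  // the halo) in the padded tail of each split.
  std::vector<int> out(a.size());
  for (int k = 0; k < static_cast<int>(a.size()); ++k)
    out[k] = (k % slice_size < split_size) ? a[k] : b[k];

  // Prints: 0 1 2 3 4 5 | 4 5 6 7 8 9 | 8 9 10 11 12 13 | 12 13 14 15 0 0
  for (int s = 0; s < num_splits; ++s) {
    for (int i = 0; i < slice_size; ++i) std::cout << out[s * slice_size + i] << ' ';
    std::cout << (s + 1 < num_splits ? "| " : "\n");
  }
}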
+ std::vector combined_batch_dimensions( + pad_applied->shape().dimensions().begin(), + pad_applied->shape().dimensions().end()); + + combined_batch_dimensions[activations_batch_dim] = + old_batch_size * num_splits; + TF_ASSIGN_OR_RETURN(activations, + MakeReshapeHlo(combined_batch_dimensions, select)); + + VLOG(1) << "Batch merge done " << activations->ToString(); + + // Now, we rewrite the convolution with a larger batch. + const auto& activations_shape = activations->shape(); + const int64 rank = activations_shape.dimensions_size(); + + // We will generate output such that batch is followed by the split spatial + // dimension. + std::vector transpose_dims(convolution->shape().rank()); + int dim_count = 0; + std::map dim_map; + + for (int j = 0; j < dim_numbers.output_spatial_dimensions_size(); ++j) { + if (j == kChosenSpatialDim) { + dim_map[dim_numbers.output_batch_dimension()] = dim_count; + new_dim_numbers.set_output_batch_dimension(dim_count++); + } + dim_map[dim_numbers.output_spatial_dimensions(j)] = dim_count; + new_dim_numbers.set_output_spatial_dimensions(j, dim_count); + dim_count++; + } + + dim_map[dim_numbers.output_feature_dimension()] = dim_count; + new_dim_numbers.set_output_feature_dimension(dim_count); + + int p = 0; + for (const auto& entry : dim_map) { + transpose_dims[p] = entry.second; + p++; + } + + auto new_window = convolution->window(); + new_window.mutable_dimensions(kChosenSpatialDim)->set_padding_high(0); + new_window.mutable_dimensions(kChosenSpatialDim)->set_padding_low(0); + TF_ASSIGN_OR_RETURN( + HloInstruction * new_conv, + MakeConvolveHlo(activations, /*rhs=*/convolution->mutable_operand(1), + convolution->feature_group_count(), + convolution->batch_group_count(), new_window, + new_dim_numbers, convolution->precision_config())); + convolution->SetupDerivedInstruction(new_conv); + + VLOG(1) << "new_conv " << new_conv->ToString(); + + const int64 output_split_spatial_dim = + new_dim_numbers.output_spatial_dimensions(kChosenSpatialDim); + const int64 output_batch_dim = new_dim_numbers.output_batch_dimension(); + + Shape new_shape = new_conv->shape(); + const int64 new_batch_size = new_shape.dimensions(output_batch_dim); + const int64 new_spatial_dim_size = + new_shape.dimensions(output_split_spatial_dim); + + CHECK_EQ(new_batch_size % old_batch_size, 0); + + const int64 output_split_batch_size = new_batch_size / old_batch_size; + + std::vector new_dimensions(new_conv->shape().dimensions().begin(), + new_conv->shape().dimensions().end()); + new_dimensions[output_split_spatial_dim] = + output_split_batch_size * new_spatial_dim_size; + new_dimensions[new_dim_numbers.output_batch_dimension()] = old_batch_size; + + // Reshape the output of the new conv into the old convolutions shape. + TF_ASSIGN_OR_RETURN(HloInstruction * reshape, + MakeReshapeHlo(new_dimensions, new_conv)); + convolution->SetupDerivedInstruction(reshape); + + std::vector start_indices(rank, 0), + end_indices(new_dimensions.begin(), new_dimensions.end()), + strides(rank, 1); + end_indices[output_split_spatial_dim] = convolution->shape().dimensions( + dim_numbers.output_spatial_dimensions(kChosenSpatialDim)); + + // This slicing is getting rid of the padding we added to evenly divide space. 
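A small sketch of just the shape arithmetic used above to undo the batch growth: fold the extra batch factor back into the split spatial dimension, then slice that dimension down to the original number of output offsets. The sizes are made-up examples.

#include <cstdint>
#include <iostream>

int main() {
  // Made-up example: old batch 1 split 4 ways; the enlarged-batch convolution
  // produces 4 output offsets per split, while the original convolution only
  // had 14 output offsets on that spatial dimension.
  const int64_t old_batch_size = 1, num_splits = 4;
  const int64_t new_batch_size = old_batch_size * num_splits;   // conv batch dim
  const int64_t new_spatial_dim_size = 4;                       // offsets per split
  const int64_t original_output_offsets = 14;

  // Reshape: fold the split factor back into the spatial dimension.
  const int64_t output_split_batch_size = new_batch_size / old_batch_size;
  const int64_t merged_spatial = output_split_batch_size * new_spatial_dim_size;

  // Slice: drop the offsets that only existed to make the space divide evenly.
  std::cout << "reshaped spatial size: " << merged_spatial       // 16
            << ", sliced back to: " << original_output_offsets   // 14
            << ", batch restored to: " << old_batch_size << "\n";
}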
+ TF_ASSIGN_OR_RETURN( + HloInstruction * output_slice, + MakeSliceHlo(reshape, start_indices, end_indices, strides)); + convolution->SetupDerivedInstruction(output_slice); + + TF_ASSIGN_OR_RETURN(HloInstruction * output_transpose, + MakeTransposeHlo(output_slice, transpose_dims)); + convolution->SetupDerivedInstruction(output_transpose); + + VLOG(1) << "output_transpose " << output_transpose->ToString(); + + changed_ = true; + return computation_->ReplaceInstruction(convolution, output_transpose); +} + +} // namespace + +StatusOr ConvolutionSpaceToBatchConverter::Run(HloModule* module) { + XLA_VLOG_LINES(2, "ConvolutionSpaceToBatchConverter::Run(), before:\n" + + module->ToString()); + bool changed = false; + for (auto* comp : module->MakeNonfusionComputations()) { + if (ConvolutionVisitor::Run(limit_on_batch_size_, comp)) { + changed = true; + } + } + XLA_VLOG_LINES(2, "ConvolutionSpaceToBatchConverter::Run(), after:\n" + + module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/space_to_batch_converter.h b/tensorflow/compiler/xla/service/space_to_batch_converter.h new file mode 100644 index 00000000000..a92abda0337 --- /dev/null +++ b/tensorflow/compiler/xla/service/space_to_batch_converter.h @@ -0,0 +1,45 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SPACE_TO_BATCH_CONVERTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_SPACE_TO_BATCH_CONVERTER_H_ + +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace xla { + +// A pass which rewrites convolutions such that space dimension is turned into +// batch. +class ConvolutionSpaceToBatchConverter : public HloModulePass { + public: + explicit ConvolutionSpaceToBatchConverter(int64 limit_on_batch_size = 1) + : limit_on_batch_size_(limit_on_batch_size) {} + + absl::string_view name() const override { + return "convolution-space-to-batch-converter"; + } + + // Run convolution rewriting on the given computation. Returns whether the + // computation was changed. + StatusOr Run(HloModule* module) override; + + int64 limit_on_batch_size_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SPACE_TO_BATCH_CONVERTER_H_ diff --git a/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc b/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc new file mode 100644 index 00000000000..bbc3882cde9 --- /dev/null +++ b/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc @@ -0,0 +1,147 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
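A hedged usage sketch for the new pass, following the pattern in the tests below: construct ConvolutionSpaceToBatchConverter with a batch-size limit and call Run on an HloModule. The RunSpaceToBatch wrapper name is invented for illustration.

#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/space_to_batch_converter.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"

namespace xla {

// Runs the pass directly on an already-built module; returns true if any
// convolution was rewritten.
StatusOr<bool> RunSpaceToBatch(HloModule* module, int64 batch_size_limit) {
  // Convolutions whose batch is at most `batch_size_limit` are candidates.
  ConvolutionSpaceToBatchConverter converter(
      /*limit_on_batch_size=*/batch_size_limit);
  return converter.Run(module);
}

}  // namespace xla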
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/space_to_batch_converter.h" + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { +namespace { + +using ConvolutionSpaceToBatchConverterTest = HloTestBase; +namespace op = testing::opcode_matchers; + +TEST_F(ConvolutionSpaceToBatchConverterTest, SimpleBatch1) { + string hlo_string = R"( + + HloModule module +ENTRY computation { + %p0 = bf16[1,258,258,32] parameter(0) + %p1 = bf16[3,3,32,32] parameter(1) + ROOT %convolution = bf16[1,256,256,32] convolution(%p0, %p1), window={size=3x3}, + dim_labels=b01f_01io->b01f +} + + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + auto computation = module->entry_computation(); + ConvolutionSpaceToBatchConverter converter; + ASSERT_TRUE(converter.Run(module.get()).ValueOrDie()); + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Transpose()); + EXPECT_THAT(root->operand(0), op::Slice()); + auto reshape = root->operand(0)->operand(0); + EXPECT_THAT(reshape, op::Reshape()); + EXPECT_THAT(reshape->operand(0), op::Convolution()); + const int64 batch_dim = reshape->operand(0) + ->convolution_dimension_numbers() + .output_batch_dimension(); + // Verify that the transform has increased the batch size. + EXPECT_GT(reshape->operand(0)->shape().dimensions(batch_dim), 1); +} + +TEST_F(ConvolutionSpaceToBatchConverterTest, SimpleBatch2) { + string hlo_string = R"( + HloModule module + ENTRY computation { + %p0 = bf16[2,258,258,32] parameter(0) + %p1 = bf16[3,3,32,32] parameter(1) + ROOT %convolution = bf16[2,256,256,32] convolution(%p0, %p1), window={size=3x3}, + dim_labels=b01f_01io->b01f + } + + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + ConvolutionSpaceToBatchConverter converter(/*limit_on_batch_size=*/2); + ASSERT_TRUE(converter.Run(module.get()).ValueOrDie()); + auto computation = module->entry_computation(); + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Transpose()); + EXPECT_THAT(root->operand(0), op::Slice()); + auto reshape = root->operand(0)->operand(0); + EXPECT_THAT(reshape, op::Reshape()); + EXPECT_THAT(reshape->operand(0), op::Convolution()); + const int64 batch_dim = reshape->operand(0) + ->convolution_dimension_numbers() + .output_batch_dimension(); + // Verify that the transform has increased the batch size. 
+ EXPECT_GT(reshape->operand(0)->shape().dimensions(batch_dim), 1); +} + +TEST_F(ConvolutionSpaceToBatchConverterTest, Batch4WithStrideAndPad) { + string hlo_string = R"( + HloModule module + ENTRY computation { + %p0 = bf16[4,224,224,3]{3,2,1,0} parameter(0) + %p1 = bf16[7,7,3,64]{3,2,1,0} parameter(1) + + ROOT %convolution.3 = bf16[4,112,112,64]{3,2,1,0} convolution(%p0, %p1), + window={size=7x7 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + auto computation = module->entry_computation(); + ConvolutionSpaceToBatchConverter converter(/*limit_on_batch_size=*/4); + ASSERT_TRUE(converter.Run(module.get()).ValueOrDie()); + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Transpose()); + EXPECT_THAT(root->operand(0), op::Slice()); + auto reshape = root->operand(0)->operand(0); + EXPECT_THAT(reshape, op::Reshape()); + EXPECT_THAT(reshape->operand(0), op::Convolution()); + const int64 batch_dim = reshape->operand(0) + ->convolution_dimension_numbers() + .output_batch_dimension(); + + EXPECT_GT(reshape->operand(0)->shape().dimensions(batch_dim), 4); +} + +TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithKernelDilation) { + string hlo_string = R"( + + HloModule module +ENTRY computation { + %p2 = bf16[1,7,7,128]{3,0,2,1} parameter(0) + %p3 = bf16[1,1,512,128]{3,2,1,0} parameter(1) + ROOT %c = bf16[1,14,14,512]{3,0,2,1} convolution(%p2, %p3), + window={size=1x1 pad=0_1x0_1 lhs_dilate=2x2 rhs_reversal=1x1}, + dim_labels=b01f_01oi->b01f +} + + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + ConvolutionSpaceToBatchConverter converter; + ASSERT_FALSE(converter.Run(module.get()).ValueOrDie()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/BUILD b/tensorflow/compiler/xla/service/spmd/BUILD index d2243d30adf..9ebaaa8242f 100644 --- a/tensorflow/compiler/xla/service/spmd/BUILD +++ b/tensorflow/compiler/xla/service/spmd/BUILD @@ -1,5 +1,6 @@ # Description: SPMD partitioning pass. +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") package( @@ -23,6 +24,7 @@ cc_library( "spmd_partitioner_util.cc", ], hdrs = [ + "convolution_handler.h", "spmd_partitioner.h", "spmd_partitioner_util.h", ], diff --git a/tensorflow/compiler/xla/service/spmd/convolution_handler.cc b/tensorflow/compiler/xla/service/spmd/convolution_handler.cc index 01d7ea2ff14..0d34c5b62e9 100644 --- a/tensorflow/compiler/xla/service/spmd/convolution_handler.cc +++ b/tensorflow/compiler/xla/service/spmd/convolution_handler.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/xla/service/spmd/convolution_handler.h" + #include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dot_as_convolution_util.h" @@ -32,24 +34,36 @@ limitations under the License. namespace xla { namespace spmd { + namespace { -// Partition convolution. -StatusOr PartitionConvolution( +// Partition convolution with batch group count. 
+StatusOr PartitionConvolutionWithBatchGroupCount( PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, - const HloSharding& output_sharding, const Window& conv_window, - HloInstruction* original_hlo, int64 num_partitions, - const SpmdPartitionerOptions& options, HloInstruction* partition_id, - HloModule* module, SpmdBuilder* b); - -// Partition convolution with only paralell dims are tiled -StatusOr PartitionConvolutionWithParallelDimension( - PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, - const HloSharding& output_sharding, const Window& conv_window, - HloInstruction* original_hlo, int64 num_partitions, SpmdBuilder* b) { + const HloSharding& output_sharding, + const std::function( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)>& create_sharded_conv, + const Window& conv_window, HloInstruction* original_hlo, + int64 num_partitions, SpmdBuilder* b) { TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution); + if (original_hlo->batch_group_count() == 1 || + original_hlo->batch_group_count() < num_partitions) { + return nullptr; + } const auto& dnums = original_hlo->convolution_dimension_numbers(); + // Only supports batch_group_size equals input_batch_size case. + const int64 input_batch_size = + lhs.base_shape().dimensions(dnums.input_batch_dimension()); + const int64 kernel_output_feature_size = + rhs.base_shape().dimensions(dnums.kernel_output_feature_dimension()); + if (input_batch_size != kernel_output_feature_size || + original_hlo->batch_group_count() != input_batch_size) { + return nullptr; + } + + // Map RHS indices to LHS indices. std::vector rhs_to_lhs_indices(output_base_shape.rank()); rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] = dnums.input_batch_dimension(); @@ -59,73 +73,149 @@ StatusOr PartitionConvolutionWithParallelDimension( rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] = dnums.input_spatial_dimensions(i); } + + // Map LHS indices to RHS indices. std::vector lhs_to_rhs_indices(output_base_shape.rank()); for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) { lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i; } + + // Map LHS indices to output indices. + std::vector lhs_to_output_indices(lhs.base_shape().rank(), -1); + lhs_to_output_indices[dnums.input_batch_dimension()] = + dnums.output_feature_dimension(); + lhs_to_output_indices[dnums.input_feature_dimension()] = + dnums.output_batch_dimension(); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + lhs_to_output_indices[dnums.input_spatial_dimensions(i)] = + dnums.output_spatial_dimensions(i); + } + + // Align LHS or RHS to other operand if input batch dim or kernel output + // feature dim is partitioned. auto aligned_rhs_sharding = hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices); auto aligned_lhs_sharding = hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices); - // Handling cases where all the partitioned dimensions are parallel - // dimensions. 
- int64 lhs_parallel_dim_partitions = 1; - int64 rhs_parallel_dim_partitions = 1; - std::vector parallel_spatial_dims; - for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { - int64 lhs_dim = dnums.input_spatial_dimensions(i); - int64 lhs_size = lhs.base_shape().dimensions(lhs_dim); - const auto& wd = conv_window.dimensions(i); - int64 rhs_dim = dnums.kernel_spatial_dimensions(i); - if (dot_as_convolution_util::ConvSpatialDimensionIsParallel(wd, lhs_size)) { - parallel_spatial_dims.emplace_back(i); - lhs_parallel_dim_partitions *= ShardCountAtDim(lhs.sharding(), lhs_dim); - rhs_parallel_dim_partitions *= ShardCountAtDim(rhs.sharding(), rhs_dim); - } - } - bool lhs_partition_dims_are_parallel = - (lhs_parallel_dim_partitions == num_partitions); - bool rhs_partition_dims_are_parallel = - (rhs_parallel_dim_partitions == num_partitions); - - // If there is a parallel dim and all the partitioned dimensions are parallel - // dimensions in either LHS or RHS, simply create partitioned convolutions. - if (parallel_spatial_dims.empty() || ((!lhs_partition_dims_are_parallel) && - (!rhs_partition_dims_are_parallel))) { + bool lhs_batch_dim_is_partitioned = + (ShardCountAtDim(lhs.sharding(), dnums.input_batch_dimension()) == + num_partitions); + bool rhs_output_feature_dim_is_partitioned = + (ShardCountAtDim(rhs.sharding(), + dnums.kernel_output_feature_dimension()) == + num_partitions); + if (!lhs_batch_dim_is_partitioned && !rhs_output_feature_dim_is_partitioned) { return nullptr; } - // Reshard LHS or RHS to partition at parallel dimensions as the other - // operand. - if (lhs_partition_dims_are_parallel) { + // Reshard LHS or RHS to partition at batch dimension or output feature + // dimension as the other operand. + if (lhs_batch_dim_is_partitioned) { + rhs = rhs.Reshard(aligned_rhs_sharding); + } else { + lhs = lhs.Reshard(aligned_lhs_sharding); + } + // Align output sharding after LHS and RHS sharding are consistent. + auto aligned_output_sharding = hlo_sharding_util::TransposeSharding( + lhs.sharding(), lhs_to_output_indices); + + // Create partitioned convolution. + TF_ASSIGN_OR_RETURN( + auto sharded_conv, + create_sharded_conv(lhs.hlo(), rhs.hlo(), b, conv_window)); + sharded_conv->set_sharding(aligned_output_sharding); + return PartitionedHlo(sharded_conv, output_base_shape, lhs.state()) + .Reshard(output_sharding) + .hlo(); +} + +// Partition convolution with feature group count. +StatusOr PartitionConvolutionWithFeatureGroupCount( + PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, + const HloSharding& output_sharding, + const std::function( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)>& create_sharded_conv, + const Window& conv_window, HloInstruction* original_hlo, + int64 num_partitions, SpmdBuilder* b) { + TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution); + if (original_hlo->feature_group_count() == 1 || + original_hlo->feature_group_count() < num_partitions) { + return nullptr; + } + + const auto& dnums = original_hlo->convolution_dimension_numbers(); + const int64 input_feature_size = + lhs.base_shape().dimensions(dnums.input_feature_dimension()); + const int64 kernel_output_feature_size = + rhs.base_shape().dimensions(dnums.kernel_output_feature_dimension()); + if (input_feature_size != kernel_output_feature_size || + input_feature_size % original_hlo->feature_group_count() != 0) { + return nullptr; + } + + // Align RHS indices to LHS. 
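A standalone sketch of the operand-alignment bookkeeping shared by the partitioning helpers above: build a dimension map from one operand's layout to the other's (as in rhs_to_lhs_indices), invert it (lhs_to_rhs_indices), and permute per-dimension shard counts along it. It stands in for hlo_sharding_util::TransposeSharding, which operates on full tile assignments; the 4-D layout in main is made up.

#include <cstdint>
#include <iostream>
#include <vector>

// Permutes per-dimension tile counts so that result[d] = src[perm[d]]
// (tile assignments reduced to per-dimension shard counts for illustration).
std::vector<int64_t> PermuteDims(const std::vector<int64_t>& src,
                                 const std::vector<int64_t>& perm) {
  std::vector<int64_t> out(perm.size());
  for (int d = 0; d < static_cast<int>(perm.size()); ++d) out[d] = src[perm[d]];
  return out;
}

// Inverts a dimension map: if perm[a] == b then inverse[b] == a, matching the
// lhs_to_rhs_indices construction above.
std::vector<int64_t> InvertMap(const std::vector<int64_t>& perm) {
  std::vector<int64_t> inv(perm.size());
  for (int i = 0; i < static_cast<int>(perm.size()); ++i) inv[perm[i]] = i;
  return inv;
}

int main() {
  // Toy 4-D example: LHS is tiled 2 ways on its dim 0.
  std::vector<int64_t> lhs_tiles = {2, 1, 1, 1};
  // rhs_to_lhs: RHS dim -> corresponding LHS dim (made-up layout).
  std::vector<int64_t> rhs_to_lhs = {3, 2, 1, 0};
  std::vector<int64_t> aligned_rhs = PermuteDims(lhs_tiles, rhs_to_lhs);
  std::vector<int64_t> lhs_to_rhs = InvertMap(rhs_to_lhs);
  for (int64_t t : aligned_rhs) std::cout << t << ' ';  // 1 1 1 2
  std::cout << "| ";
  for (int64_t t : lhs_to_rhs) std::cout << t << ' ';   // 3 2 1 0
  std::cout << "\n";
}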
+ std::vector rhs_to_lhs_indices(output_base_shape.rank()); + rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] = + dnums.input_feature_dimension(); + rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] = + dnums.input_batch_dimension(); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] = + dnums.input_spatial_dimensions(i); + } + + // Align LHS indices to RHS. + std::vector lhs_to_rhs_indices(output_base_shape.rank()); + for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) { + lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i; + } + + // Align LHS indices to output. + std::vector lhs_to_output_indices(output_base_shape.rank()); + lhs_to_output_indices[dnums.input_feature_dimension()] = + dnums.output_feature_dimension(); + lhs_to_output_indices[dnums.input_batch_dimension()] = + dnums.output_batch_dimension(); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + lhs_to_output_indices[dnums.input_spatial_dimensions(i)] = + dnums.output_spatial_dimensions(i); + } + + // Align LHS or RHS if input_feature_dim or kernel_output_feature_dim is + // partitioned. + auto aligned_rhs_sharding = + hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices); + auto aligned_lhs_sharding = + hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices); + + bool lhs_feature_dim_is_partitioned = + (ShardCountAtDim(lhs.sharding(), dnums.input_feature_dimension()) == + num_partitions); + bool rhs_output_feature_dim_is_partitioned = + (ShardCountAtDim(rhs.sharding(), + dnums.kernel_output_feature_dimension()) == + num_partitions); + if (!lhs_feature_dim_is_partitioned && + !rhs_output_feature_dim_is_partitioned) { + return nullptr; + } + // Reshard LHS or RHS to partition at input feature dimension or output + // feature dimension as the other operand. + if (lhs_feature_dim_is_partitioned) { rhs = rhs.Reshard(aligned_rhs_sharding); } else { lhs = lhs.Reshard(aligned_lhs_sharding); } - // Get LHS and RHS sharded shape. - auto lhs_shard_shape = MakePartitionedShape(lhs.base_shape(), lhs.sharding()); - auto rhs_shard_shape = MakePartitionedShape(rhs.base_shape(), rhs.sharding()); + // Align output sharding after LHS and RHS sharding are consistent. + auto aligned_output_sharding = hlo_sharding_util::TransposeSharding( + lhs.sharding(), lhs_to_output_indices); - // Update convolution window. 
- auto new_window = conv_window; - for (const auto& spatial_dim : parallel_spatial_dims) { - auto wd = new_window.mutable_dimensions(spatial_dim); - wd->set_size(lhs_shard_shape.dimensions( - dnums.input_spatial_dimensions(spatial_dim))); - wd->set_stride(std::max(1, wd->size() - 1)); - wd->set_base_dilation(wd->size()); - } TF_ASSIGN_OR_RETURN( - Shape sharded_conv_shape, - ShapeInference::InferConvolveShape( - lhs_shard_shape, rhs_shard_shape, original_hlo->feature_group_count(), - original_hlo->batch_group_count(), new_window, dnums)); - auto sharded_conv = b->AddInstruction(HloInstruction::CreateConvolve( - sharded_conv_shape, lhs.hlo(), rhs.hlo(), - original_hlo->feature_group_count(), original_hlo->batch_group_count(), - new_window, dnums, original_hlo->precision_config())); - sharded_conv->set_sharding(original_hlo->sharding()); + auto sharded_conv, + create_sharded_conv(lhs.hlo(), rhs.hlo(), b, conv_window)); + sharded_conv->set_sharding(aligned_output_sharding); return PartitionedHlo(sharded_conv, output_base_shape, lhs.state()) .Reshard(output_sharding) .hlo(); @@ -136,9 +226,12 @@ StatusOr PartitionConvolutionWithParallelDimension( StatusOr PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS( PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, - const HloSharding& output_sharding, const Window& conv_window, - HloInstruction* original_hlo, HloInstruction* partition_id, - HloModule* module, SpmdBuilder* b) { + const HloSharding& output_sharding, + const std::function( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)>& create_sharded_conv, + const Window& conv_window, HloInstruction* original_hlo, + HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) { TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution); TF_RET_CHECK(!lhs.sharding().IsTileMaximal() && !rhs.sharding().IsTileMaximal()); @@ -188,6 +281,22 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS( rhs = rhs.Reshard(aligned_rhs_sharding).PadWithValue(zero); } + if (original_hlo->feature_group_count() > 1 && + (lhs.sharding().tile_assignment().dim(dnums.input_feature_dimension()) > + 1 || + rhs.sharding().tile_assignment().dim( + dnums.kernel_output_feature_dimension()) > 1)) { + return nullptr; + } + + if (original_hlo->batch_group_count() > 1 && + (lhs.sharding().tile_assignment().dim(dnums.input_batch_dimension()) > + 1 || + rhs.sharding().tile_assignment().dim( + dnums.kernel_output_feature_dimension()) > 1)) { + return nullptr; + } + // Reshard RHS so that each shard computes the partial sum of the full // shape result, and add AllReduce. See HandleConvolutionTiledLhsAndRhs() // that reshards LHS. @@ -214,7 +323,7 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS( int64 lhs_dimension = dnums.input_spatial_dimensions(i); int64 rhs_dimension = dnums.kernel_spatial_dimensions(i); int64 shard_count = rhs.sharding().tile_assignment().dim(rhs_dimension); - auto wd = conv_window.dimensions(i); + const auto& wd = conv_window.dimensions(i); if (wd.base_dilation() != 1 || wd.window_reversal()) { return nullptr; } @@ -260,7 +369,7 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS( // Calculate the left and right halo sizes as described in the comments // above. It calculcates the halo sizes with dilation, so we apply // CeilOfRatio({left,right}_halo_size, window_dilation). 
- auto wd = conv_window.dimensions(i); + const auto& wd = conv_window.dimensions(i); int64 padding_low = wd.padding_low(); int64 padding_high = wd.padding_high(); int64 base = lhs.base_shape().dimensions(lhs_dimension); @@ -387,10 +496,9 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS( rhs_with_halo = *concat; } - auto conv = b->AddInstruction(HloInstruction::CreateConvolve( - output_base_shape, conv_lhs, rhs_with_halo, - original_hlo->feature_group_count(), original_hlo->batch_group_count(), - new_window, dnums, original_hlo->precision_config())); + TF_ASSIGN_OR_RETURN( + auto conv, create_sharded_conv(conv_lhs, rhs_with_halo, b, new_window)); + auto ar = collective_ops_creator.create_cross_partition_all_reduce( b, conv, MakeBinaryAdd(original_hlo->shape().element_type(), module), {}, (*lhs.state().next_channel_id)++); @@ -405,9 +513,12 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS( StatusOr PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS( PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, - const HloSharding& output_sharding, const Window& conv_window, - HloInstruction* original_hlo, HloInstruction* partition_id, - HloModule* module, SpmdBuilder* b) { + const HloSharding& output_sharding, + const std::function( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)>& create_sharded_conv, + const Window& conv_window, HloInstruction* original_hlo, + HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) { TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution); TF_RET_CHECK(!lhs.sharding().IsTileMaximal() && !rhs.sharding().IsTileMaximal()); @@ -430,7 +541,7 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS( lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i; } - Window window = conv_window; + const Window& window = conv_window; std::vector reversed_rhs_dims; for (int64 i = 0; i < window.dimensions_size(); ++i) { if (window.dimensions(i).window_reversal()) { @@ -480,6 +591,21 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS( rhs.Reshard(aligned_rhs_sharding).PadWithValue(zero, reversed_rhs_dims); } + if (original_hlo->feature_group_count() > 1 && + (lhs.sharding().tile_assignment().dim(dnums.input_feature_dimension()) > + 1 || + rhs.sharding().tile_assignment().dim( + dnums.kernel_output_feature_dimension()) > 1)) { + return nullptr; + } + + if (original_hlo->batch_group_count() > 1 && + (lhs.sharding().tile_assignment().dim(dnums.input_batch_dimension()) > + 1 || + rhs.sharding().tile_assignment().dim( + dnums.kernel_output_feature_dimension()) > 1)) { + return nullptr; + } // Reshard LHS by exchanging halo such that each shard computes the partial // sum of the full shape result, and add AllReduce. // @@ -505,7 +631,7 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS( int64 lhs_dimension = dnums.input_spatial_dimensions(i); int64 rhs_dimension = dnums.kernel_spatial_dimensions(i); int64 shard_count = lhs.sharding().tile_assignment().dim(lhs_dimension); - auto wd = window.dimensions(i); + const auto& wd = window.dimensions(i); if (wd.base_dilation() != 1) { // TODO(wangtao): support parallel dim if it is replicate here. return nullptr; @@ -540,7 +666,7 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS( // Calculate the left and right halo sizes as described in the comments // above. 
- auto wd = window.dimensions(i); + const auto& wd = window.dimensions(i); int64 padding_low = wd.padding_low(); int64 padding_high = wd.padding_high(); int64 base = lhs.base_shape().dimensions(lhs_dimension); @@ -597,11 +723,8 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS( lhs_with_halo = *concat; } - auto conv = b->AddInstruction(HloInstruction::CreateConvolve( - output_base_shape, lhs_with_halo, rhs.hlo(), - original_hlo->feature_group_count(), original_hlo->batch_group_count(), - new_window, original_hlo->convolution_dimension_numbers(), - original_hlo->precision_config())); + TF_ASSIGN_OR_RETURN( + auto conv, create_sharded_conv(lhs_with_halo, rhs.hlo(), b, new_window)); auto ar = lhs.state().collective_ops_creator.create_cross_partition_all_reduce( b, conv, MakeBinaryAdd(output_base_shape.element_type(), module), {}, @@ -616,8 +739,11 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS( // RHS. StatusOr PartitionConvolutionTiledOutput( PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, - const HloSharding& output_sharding, const Window& conv_window, - HloInstruction* original_hlo, SpmdBuilder* b) { + const HloSharding& output_sharding, + const std::function( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)>& create_sharded_conv, + const Window& conv_window, HloInstruction* original_hlo, SpmdBuilder* b) { TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution); const auto& dnums = original_hlo->convolution_dimension_numbers(); TF_RET_CHECK(!output_sharding.IsTileMaximal()); @@ -668,19 +794,13 @@ StatusOr PartitionConvolutionTiledOutput( resharded_operand_and_window->shard_window.dimensions( dnums.input_spatial_dimensions(i)); } + TF_ASSIGN_OR_RETURN( - Shape sharded_conv_shape, - ShapeInference::InferConvolveShape( - resharded_operand_and_window->sharded_input->shape(), - rhs.hlo()->shape(), original_hlo->feature_group_count(), - original_hlo->batch_group_count(), new_window, dnums)); + auto sharded_conv, + create_sharded_conv(resharded_operand_and_window->sharded_input, + rhs.hlo(), b, new_window)); + auto shard_shape = MakePartitionedShape(output_base_shape, output_sharding); - *sharded_conv_shape.mutable_layout() = shard_shape.layout(); - auto sharded_conv = b->AddInstruction(HloInstruction::CreateConvolve( - sharded_conv_shape, resharded_operand_and_window->sharded_input, - rhs.hlo(), original_hlo->feature_group_count(), - original_hlo->batch_group_count(), new_window, dnums, - original_hlo->precision_config())); if (!resharded_operand_and_window->dynamic_slice_index_on_output .has_value()) { CHECK(ShapeUtil::Compatible(shard_shape, sharded_conv->shape())); @@ -692,132 +812,40 @@ StatusOr PartitionConvolutionTiledOutput( shard_shape.dimensions())); } -StatusOr PartitionConvolutionGroupOnParallelDim( - PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, - const HloSharding& output_sharding, const Window& conv_window, - HloInstruction* original_hlo, const ConvolutionDimsMapping& dims_mapping, - int64 num_partitions, const SpmdPartitionerOptions& options, - HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) { - std::vector lhs_dims; - std::vector rhs_dims; - std::vector output_dims; - auto lhs_sharding_dims_adjusted_to_output = - lhs.sharding().IsReplicated() - ? std::vector(lhs.base_shape().rank(), 1) - : lhs.sharding().tile_assignment().dimensions(); - auto rhs_sharding_dims_adjusted_to_output = - rhs.sharding().IsReplicated() - ? 
std::vector(rhs.base_shape().rank(), 1) - : rhs.sharding().tile_assignment().dimensions(); - auto output_sharding_dims_adjusted_to_lhs = - output_sharding.tile_assignment().dimensions(); - bool lhs_rhs_dims_matching = true; - for (const auto& dim : dims_mapping.parallel_spatial_dims) { - lhs_dims.push_back(dim.lhs); - rhs_dims.push_back(dim.rhs); - output_dims.push_back(dim.output); - if (lhs_sharding_dims_adjusted_to_output[dim.lhs] != - rhs_sharding_dims_adjusted_to_output[dim.rhs]) { - lhs_rhs_dims_matching = false; - } - lhs_sharding_dims_adjusted_to_output[dim.lhs] = - output_sharding.tile_assignment().dim(dim.output); - rhs_sharding_dims_adjusted_to_output[dim.rhs] = - output_sharding.tile_assignment().dim(dim.output); - output_sharding_dims_adjusted_to_lhs[dim.output] = - lhs.sharding().tile_assignment().dim(dim.lhs); - } - auto lhs_grouped = GroupShardingOnDims(lhs.sharding(), lhs_dims); - auto rhs_grouped = GroupShardingOnDims(rhs.sharding(), rhs_dims); - auto output_grouped = GroupShardingOnDims(output_sharding, output_dims); - if (lhs_rhs_dims_matching) { - if (ShapeUtil::ByteSizeOf(lhs.base_shape()) > - ShapeUtil::ByteSizeOf(rhs.base_shape())) { - rhs_grouped = AlignGroupsWith(std::move(rhs_grouped), lhs_grouped); - rhs = rhs.Reshard(UngroupSharding(rhs_grouped)); - } else { - lhs_grouped = AlignGroupsWith(std::move(lhs_grouped), rhs_grouped); - lhs = lhs.Reshard(UngroupSharding(lhs_grouped)); - } - auto reshaped_output_tiling = output_sharding.tile_assignment(); - reshaped_output_tiling.Reshape(output_sharding_dims_adjusted_to_lhs); - output_grouped = AlignGroupsWith( - GroupShardingOnDims(HloSharding::Tile(reshaped_output_tiling), - output_dims), - lhs_grouped); - } else { - auto reshaped_lhs_tiling = lhs.sharding().tile_assignment(); - reshaped_lhs_tiling.Reshape(lhs_sharding_dims_adjusted_to_output); - lhs_grouped = AlignGroupsWith( - GroupShardingOnDims(HloSharding::Tile(reshaped_lhs_tiling), lhs_dims), - output_grouped); - lhs = lhs.Reshard(UngroupSharding(lhs_grouped)); - auto reshaped_rhs_tiling = rhs.sharding().tile_assignment(); - reshaped_rhs_tiling.Reshape(rhs_sharding_dims_adjusted_to_output); - rhs_grouped = AlignGroupsWith( - GroupShardingOnDims(HloSharding::Tile(reshaped_rhs_tiling), rhs_dims), - output_grouped); - rhs = rhs.Reshard(UngroupSharding(rhs_grouped)); - } - - // Update LHS and RHS sharding and shape. 
- lhs.hlo()->set_sharding(lhs_grouped.sharding); - rhs.hlo()->set_sharding(rhs_grouped.sharding); - CHECK(lhs.hlo() != rhs.hlo() || lhs_grouped.sharding == rhs_grouped.sharding); - auto per_group_partitioner_state = CreatePerGroupPartitioningState( - lhs.state(), lhs_grouped.device_groups, b); - auto grouped_lhs_base_shape = - GetPerGroupBaseShape(lhs_grouped, lhs.base_shape()); - auto grouped_lhs_shard_shape = - MakePartitionedShape(grouped_lhs_base_shape, lhs.sharding()); - // Update convolution window with the new shape - auto new_window = conv_window; - for (const auto& dim : dims_mapping.parallel_spatial_dims) { - auto wd = new_window.mutable_dimensions(dim.spatial); - wd->set_size(grouped_lhs_shard_shape.dimensions(dim.lhs)); - wd->set_stride(std::max(1, wd->size() - 1)); - wd->set_base_dilation(wd->size()); - } - - auto new_partition_id = - lhs.state().collective_ops_creator.create_partition_id(b); - TF_ASSIGN_OR_RETURN( - auto conv, - PartitionConvolution( - PartitionedHlo(lhs.hlo(), grouped_lhs_base_shape, - per_group_partitioner_state), - PartitionedHlo(rhs.hlo(), - GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()), - per_group_partitioner_state), - GetPerGroupBaseShape(output_grouped, output_base_shape), - output_grouped.sharding, new_window, original_hlo, - num_partitions / output_grouped.device_groups.size(), options, - new_partition_id, module, b)); - // Reset the LHS sharding to the ungrouped one. - lhs.hlo()->set_sharding(UngroupSharding(lhs_grouped)); - rhs.hlo()->set_sharding(UngroupSharding(rhs_grouped)); - conv->set_sharding(UngroupSharding(output_grouped)); - return PartitionedHlo(conv, output_base_shape, lhs.state()) - .Reshard(output_sharding) - .hlo(); -} - // Partition convolution with only one kind of dims partitioned. StatusOr PartitionConvolutionBaseCase( PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, - const HloSharding& output_sharding, const Window& conv_window, - HloInstruction* original_hlo, int64 num_partitions, - const SpmdPartitionerOptions& options, HloInstruction* partition_id, - HloModule* module, SpmdBuilder* b) { + const HloSharding& output_sharding, + const std::function( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)>& create_sharded_conv, + const Window& conv_window, HloInstruction* original_hlo, + int64 num_partitions, const SpmdPartitionerOptions& options, + HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) { TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution); - // Case 1: Either RHS or LHS is only partitioned at parallel dimensions. - TF_ASSIGN_OR_RETURN(auto parallel_partitioned_conv, - PartitionConvolutionWithParallelDimension( - lhs, rhs, output_base_shape, output_sharding, - conv_window, original_hlo, num_partitions, b)); - if (parallel_partitioned_conv) { - return parallel_partitioned_conv; + // Case 1: Handle depthwise convolution with batch group count or + // feature group count. 
+ if (original_hlo->batch_group_count() > 1) { + TF_ASSIGN_OR_RETURN( + auto parallel_partitioned_conv, + PartitionConvolutionWithBatchGroupCount( + lhs, rhs, output_base_shape, output_sharding, create_sharded_conv, + conv_window, original_hlo, num_partitions, b)); + if (parallel_partitioned_conv) { + return parallel_partitioned_conv; + } + } + + if (original_hlo->feature_group_count() > 1) { + TF_ASSIGN_OR_RETURN( + auto parallel_partitioned_conv, + PartitionConvolutionWithFeatureGroupCount( + lhs, rhs, output_base_shape, output_sharding, create_sharded_conv, + conv_window, original_hlo, num_partitions, b)); + if (parallel_partitioned_conv) { + return parallel_partitioned_conv; + } } // Case 2: both RHS and LHS are tiled. @@ -830,8 +858,8 @@ StatusOr PartitionConvolutionBaseCase( TF_ASSIGN_OR_RETURN( auto partitioned_conv, PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS( - lhs, rhs, output_base_shape, output_sharding, conv_window, - original_hlo, partition_id, module, b)); + lhs, rhs, output_base_shape, output_sharding, create_sharded_conv, + conv_window, original_hlo, partition_id, module, b)); if (partitioned_conv) { return partitioned_conv; } @@ -839,8 +867,8 @@ StatusOr PartitionConvolutionBaseCase( TF_ASSIGN_OR_RETURN( auto partitioned_conv, PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS( - lhs, rhs, output_base_shape, output_sharding, conv_window, - original_hlo, partition_id, module, b)); + lhs, rhs, output_base_shape, output_sharding, create_sharded_conv, + conv_window, original_hlo, partition_id, module, b)); if (partitioned_conv) { return partitioned_conv; @@ -853,7 +881,7 @@ StatusOr PartitionConvolutionBaseCase( TF_ASSIGN_OR_RETURN(auto partitioned_conv, PartitionConvolutionTiledOutput( lhs, rhs, output_base_shape, output_sharding, - conv_window, original_hlo, b)); + create_sharded_conv, conv_window, original_hlo, b)); if (partitioned_conv) { return partitioned_conv; @@ -862,151 +890,156 @@ StatusOr PartitionConvolutionBaseCase( return nullptr; } +StatusOr> CreateShardedConvConvolution( + const HloInstruction& conv, + const dot_as_convolution_util::DotConvolutionDimsInfo& dot_dnums, + HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo, + const Window& conv_window) { + CHECK_EQ(conv.opcode(), HloOpcode::kConvolution); + const auto& conv_dnums = conv.convolution_dimension_numbers(); + auto window = conv.window(); + for (const auto& dim : dot_dnums.batch_dims) { + auto wd = window.mutable_dimensions(dim.spatial_dim); + wd->set_size(sharded_lhs_hlo->shape().dimensions( + conv_dnums.input_spatial_dimensions(dim.spatial_dim))); + wd->set_stride(std::max(1, wd->size() - 1)); + wd->set_base_dilation(wd->size()); + } + for (const auto& dim : dot_dnums.contracting_dims) { + if (dim.spatial_dim < 0) { + continue; + } + auto wd = window.mutable_dimensions(dim.spatial_dim); + wd->set_size(sharded_lhs_hlo->shape().dimensions( + conv_dnums.input_spatial_dimensions(dim.spatial_dim))); + } + for (const auto& dim : dot_dnums.rhs_non_contracting_dims) { + if (dim.spatial_dim < 0) { + continue; + } + auto wd = window.mutable_dimensions(dim.spatial_dim); + wd->set_size(sharded_rhs_hlo->shape().dimensions( + conv_dnums.kernel_spatial_dimensions(dim.spatial_dim))); + wd->set_padding_high(wd->size() - 1); + wd->set_padding_low(wd->size() - 1); + } + + for (const auto& dim : dot_dnums.conv_spatial_dims) { + auto wd = window.mutable_dimensions(dim.spatial_dim); + const auto& new_window_dimension = conv_window.dimensions(dim.spatial_dim); + 
wd->set_size(new_window_dimension.size()); + wd->set_padding_high(new_window_dimension.padding_high()); + wd->set_padding_low(new_window_dimension.padding_low()); + } + + int64 feature_group_count = conv.feature_group_count(); + if (feature_group_count > 1) { + feature_group_count = sharded_lhs_hlo->shape().dimensions( + conv_dnums.input_feature_dimension()) / + sharded_rhs_hlo->shape().dimensions( + conv_dnums.kernel_input_feature_dimension()); + } + + int64 batch_group_count = conv.batch_group_count(); + if (batch_group_count > 1) { + batch_group_count = + sharded_lhs_hlo->shape().dimensions(conv_dnums.input_batch_dimension()); + } + + TF_ASSIGN_OR_RETURN( + Shape sharded_conv_shape, + ShapeInference::InferConvolveShape( + sharded_lhs_hlo->shape(), sharded_rhs_hlo->shape(), + feature_group_count, batch_group_count, window, conv_dnums)); + *sharded_conv_shape.mutable_layout() = conv.shape().layout(); + return HloInstruction::CreateConvolve( + sharded_conv_shape, sharded_lhs_hlo, sharded_rhs_hlo, feature_group_count, + batch_group_count, window, conv_dnums, conv.precision_config()); +} + +} // namespace + // Partition convolution. StatusOr PartitionConvolution( PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, - const HloSharding& output_sharding, const Window& conv_window, - HloInstruction* original_hlo, int64 num_partitions, - const SpmdPartitionerOptions& options, HloInstruction* partition_id, - HloModule* module, SpmdBuilder* b) { + const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping, + const std::function( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)>& create_sharded_conv, + const Window& conv_window, HloInstruction* original_hlo, + int64 num_partitions, const SpmdPartitionerOptions& options, + HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) { TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution); - TF_ASSIGN_OR_RETURN( - auto try_partitioned_conv, - PartitionConvolutionBaseCase(lhs, rhs, output_base_shape, output_sharding, - conv_window, original_hlo, num_partitions, - options, partition_id, module, b)); + TF_ASSIGN_OR_RETURN(auto try_partitioned_conv, + PartitionConvolutionBaseCase( + lhs, rhs, output_base_shape, output_sharding, + create_sharded_conv, conv_window, original_hlo, + num_partitions, options, partition_id, module, b)); if (try_partitioned_conv) { return try_partitioned_conv; } - const auto& dnums = original_hlo->convolution_dimension_numbers(); - spmd::ConvolutionDimsMapping mapping; - for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { - int64 lhs_dim = dnums.input_spatial_dimensions(i); - int64 lhs_size = lhs.base_shape().dimensions(lhs_dim); - const auto& wd = original_hlo->window().dimensions(i); - int64 rhs_dim = dnums.kernel_spatial_dimensions(i); - int64 output_dim = dnums.output_spatial_dimensions(i); - if (dot_as_convolution_util::ConvSpatialDimensionIsParallel(wd, lhs_size)) { - mapping.parallel_spatial_dims.emplace_back(); - mapping.parallel_spatial_dims.back().lhs = lhs_dim; - mapping.parallel_spatial_dims.back().rhs = rhs_dim; - mapping.parallel_spatial_dims.back().output = output_dim; - mapping.parallel_spatial_dims.back().spatial = i; - } else { - mapping.non_parallel_spatial_dims.emplace_back(); - mapping.non_parallel_spatial_dims.back().lhs = lhs_dim; - mapping.non_parallel_spatial_dims.back().rhs = rhs_dim; - mapping.non_parallel_spatial_dims.back().output = output_dim; - mapping.non_parallel_spatial_dims.back().spatial = i; - } - } 
- - // lhs_rhs_or_output: 0 lhs, 1 rhs, 2 output. - auto get_partitions_for_dims = - [&](const HloSharding& sharding, - absl::Span dims, - int lhs_rhs_or_output) { - int64 partitions = 1; - if (sharding.IsTileMaximal()) { - return partitions; - } - for (const auto& dim : dims) { - if (lhs_rhs_or_output == 0) { - partitions *= sharding.tile_assignment().dim(dim.lhs); - } else if (lhs_rhs_or_output == 1) { - partitions *= sharding.tile_assignment().dim(dim.rhs); - } else { - CHECK_EQ(lhs_rhs_or_output, 2); - partitions *= sharding.tile_assignment().dim(dim.output); - } - } - return partitions; - }; - - const int64 lhs_parallel_spatial_partitions = - get_partitions_for_dims(lhs.sharding(), mapping.parallel_spatial_dims, 0); - const int64 rhs_parallel_spatial_partitions = - get_partitions_for_dims(rhs.sharding(), mapping.parallel_spatial_dims, 1); - const int64 output_parallel_spatial_partitions = get_partitions_for_dims( - original_hlo->sharding(), mapping.parallel_spatial_dims, 2); - - // Recursively partition on different types of dimensions. - // - // Case 1: Group partitions by parallel spatial dims. - if (lhs_parallel_spatial_partitions == rhs_parallel_spatial_partitions && - lhs_parallel_spatial_partitions == output_parallel_spatial_partitions && - lhs_parallel_spatial_partitions > 1) { - TF_ASSIGN_OR_RETURN(auto try_partitioned_conv, - PartitionConvolutionGroupOnParallelDim( - lhs, rhs, output_base_shape, output_sharding, - conv_window, original_hlo, mapping, num_partitions, - options, partition_id, module, b)); - if (try_partitioned_conv) { - return try_partitioned_conv; - } - } - return nullptr; } -} // namespace - Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) { - auto dot_dnums = dot_as_convolution_util::ParseDotGeneralFromConvolution(hlo); - if (dot_dnums) { - // Use HandleDotHelper() for convs that are actually einsums. 
- spmd::DotGeneralDimsMapping mapping; - for (const auto& dims : dot_dnums->batch_dims) { - mapping.batch_dims.emplace_back(); - mapping.batch_dims.back().lhs = dims.lhs; - mapping.batch_dims.back().rhs = dims.rhs; - mapping.batch_dims.back().output = dims.output; - } - for (const auto& dims : dot_dnums->contracting_dims) { - mapping.contracting_dims.emplace_back(); - mapping.contracting_dims.back().lhs = dims.lhs; - mapping.contracting_dims.back().rhs = dims.rhs; - mapping.contracting_dims.back().output = dims.output; - } - for (const auto& dims : dot_dnums->lhs_non_contracting_dims) { - mapping.lhs_non_contracting_dims.emplace_back(); - mapping.lhs_non_contracting_dims.back().lhs = dims.lhs; - mapping.lhs_non_contracting_dims.back().rhs = dims.rhs; - mapping.lhs_non_contracting_dims.back().output = dims.output; - } - for (const auto& dims : dot_dnums->rhs_non_contracting_dims) { - mapping.rhs_non_contracting_dims.emplace_back(); - mapping.rhs_non_contracting_dims.back().lhs = dims.lhs; - mapping.rhs_non_contracting_dims.back().rhs = dims.rhs; - mapping.rhs_non_contracting_dims.back().output = dims.output; - } - auto create_sharded_conv = - [&](HloInstruction* lhs_hlo, HloInstruction* rhs_hlo, - spmd::SpmdBuilder* b) -> StatusOr { + auto dims_info = dot_as_convolution_util::ParseConvolutionDimsInfo(hlo); + spmd::DotConvDimsMapping mapping; + for (const auto& dims : dims_info.batch_dims) { + mapping.batch_dims.emplace_back(); + mapping.batch_dims.back().lhs = dims.lhs; + mapping.batch_dims.back().rhs = dims.rhs; + mapping.batch_dims.back().output = dims.output; + mapping.batch_dims.back().spatial = dims.spatial_dim; + } + for (const auto& dims : dims_info.contracting_dims) { + mapping.contracting_dims.emplace_back(); + mapping.contracting_dims.back().lhs = dims.lhs; + mapping.contracting_dims.back().rhs = dims.rhs; + mapping.contracting_dims.back().output = dims.output; + mapping.contracting_dims.back().spatial = dims.spatial_dim; + } + for (const auto& dims : dims_info.lhs_non_contracting_dims) { + mapping.lhs_non_contracting_dims.emplace_back(); + mapping.lhs_non_contracting_dims.back().lhs = dims.lhs; + mapping.lhs_non_contracting_dims.back().rhs = dims.rhs; + mapping.lhs_non_contracting_dims.back().output = dims.output; + mapping.lhs_non_contracting_dims.back().spatial = dims.spatial_dim; + } + for (const auto& dims : dims_info.rhs_non_contracting_dims) { + mapping.rhs_non_contracting_dims.emplace_back(); + mapping.rhs_non_contracting_dims.back().lhs = dims.lhs; + mapping.rhs_non_contracting_dims.back().rhs = dims.rhs; + mapping.rhs_non_contracting_dims.back().output = dims.output; + mapping.rhs_non_contracting_dims.back().spatial = dims.spatial_dim; + } + for (const auto& dims : dims_info.conv_spatial_dims) { + mapping.conv_spatial_dims.emplace_back(); + mapping.conv_spatial_dims.back().lhs = dims.lhs; + mapping.conv_spatial_dims.back().rhs = dims.rhs; + mapping.conv_spatial_dims.back().output = dims.output; + mapping.conv_spatial_dims.back().spatial = dims.spatial_dim; + } + auto create_sharded_conv = + [&](HloInstruction* lhs_hlo, HloInstruction* rhs_hlo, + spmd::SpmdBuilder* b, + const Window& conv_window) -> StatusOr { + if (dims_info.conv_spatial_dims.empty()) { TF_ASSIGN_OR_RETURN( auto sharded_conv, dot_as_convolution_util::CreateShardedConvForDotGeneralConvolution( - *hlo, *dot_dnums, lhs_hlo, rhs_hlo)); + *hlo, dims_info, lhs_hlo, rhs_hlo)); return b->AddInstruction(std::move(sharded_conv)); - }; - return HandleDotHelper(hlo, mapping, create_sharded_conv); - } + } else { + 
TF_ASSIGN_OR_RETURN(auto sharded_conv, + CreateShardedConvConvolution(*hlo, dims_info, lhs_hlo, + rhs_hlo, conv_window)); + return b->AddInstruction(std::move(sharded_conv)); + } + }; - auto lhs = GetPartitionedHlo(hlo->operand(0)); - auto rhs = GetPartitionedHlo(hlo->operand(1)); - TF_ASSIGN_OR_RETURN( - auto partitioned_conv, - PartitionConvolution(lhs, rhs, hlo->shape(), hlo->sharding(), - hlo->window(), hlo, num_partitions_, options_, - partition_id_, module_, &b_)); - - if (partitioned_conv) { - SetPartitionedHlo(hlo, [&] { return partitioned_conv; }); - return Status::OK(); - } - return DefaultAction(hlo); + return HandleDotHelper(hlo, mapping, create_sharded_conv); } } // namespace spmd diff --git a/tensorflow/compiler/xla/service/spmd/convolution_handler.h b/tensorflow/compiler/xla/service/spmd/convolution_handler.h new file mode 100644 index 00000000000..2d929da54e7 --- /dev/null +++ b/tensorflow/compiler/xla/service/spmd/convolution_handler.h @@ -0,0 +1,42 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_CONVOLUTION_HANDLER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_CONVOLUTION_HANDLER_H_ + +#include "tensorflow/compiler/xla/service/dot_as_convolution_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h" + +namespace xla { +namespace spmd { + +// Partition convolution. +StatusOr PartitionConvolution( + PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, + const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping, + const std::function( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)>& create_sharded_conv, + const Window& conv_window, HloInstruction* original_hlo, + int64 num_partitions, const SpmdPartitionerOptions& options, + HloInstruction* partition_id, HloModule* module, SpmdBuilder* b); + +} // namespace spmd +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_CONVOLUTION_HANDLER_H_ diff --git a/tensorflow/compiler/xla/service/spmd/dot_handler.cc b/tensorflow/compiler/xla/service/spmd/dot_handler.cc index da432965497..45bd79bfc75 100644 --- a/tensorflow/compiler/xla/service/spmd/dot_handler.cc +++ b/tensorflow/compiler/xla/service/spmd/dot_handler.cc @@ -19,15 +19,19 @@ limitations under the License. 
#include "absl/types/optional.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_sharding.h" #include "tensorflow/compiler/xla/service/hlo_sharding_util.h" #include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/service/spmd/convolution_handler.h" #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h" #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/numbers.h" @@ -36,7 +40,7 @@ namespace xla { namespace spmd { Status SpmdPartitioningVisitor::HandleDot(HloInstruction* hlo) { - DotGeneralDimsMapping mapping; + DotConvDimsMapping mapping; const auto& dnums = hlo->dot_dimension_numbers(); int64 next_output_dim = 0; for (int64 i = 0; i < dnums.lhs_batch_dimensions_size(); ++i) { @@ -71,8 +75,9 @@ Status SpmdPartitioningVisitor::HandleDot(HloInstruction* hlo) { mapping.rhs_non_contracting_dims.back().rhs = i; mapping.rhs_non_contracting_dims.back().output = next_output_dim++; } - auto create_sharded_dot = [&](HloInstruction* l, HloInstruction* r, - SpmdBuilder* b) -> StatusOr { + auto create_sharded_dot = + [&](HloInstruction* l, HloInstruction* r, SpmdBuilder* b, + const Window& conv_window) -> StatusOr { TF_ASSIGN_OR_RETURN( auto sharded_dot_shape, ShapeInference::InferDotOpShape(l->shape(), r->shape(), @@ -86,19 +91,32 @@ Status SpmdPartitioningVisitor::HandleDot(HloInstruction* hlo) { namespace { +std::vector GetAllDevicesInOrder(const HloSharding& sharding) { + CHECK(!sharding.IsTileMaximal()); + std::vector results; + results.reserve(sharding.tile_assignment().num_elements()); + sharding.tile_assignment().Each( + [&](absl::Span /* indices */, int64 device) { + results.push_back(device); + }); + return results; +} + StatusOr PartitionBaseCase( PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, - const HloSharding& output_sharding, - const DotGeneralDimsMapping& dims_mapping, int64 num_partitions, + const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping, + int64 num_partitions, const std::function( - HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot, - HloModule* module, HloInstruction* original_hlo, int64 lhs_batch_partitions, - int64 rhs_batch_partitions, int64 output_batch_partitions, - int64 lhs_contracting_partitions, int64 rhs_contracting_partitions, - int64 lhs_non_contracting_partitions, int64 rhs_non_contracting_partitions, + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)>& create_sharded_dot, + const Window& conv_window, HloModule* module, HloInstruction* original_hlo, + int64 lhs_batch_partitions, int64 rhs_batch_partitions, + int64 output_batch_partitions, int64 lhs_contracting_partitions, + int64 rhs_contracting_partitions, int64 lhs_non_contracting_partitions, + int64 rhs_non_contracting_partitions, int64 output_lhs_non_contracting_partitions, int64 output_rhs_non_contracting_partitions, - 
int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b, + const SpmdPartitionerOptions& options, SpmdBuilder* b, std::vector* windowed_dot_general_loops, bool may_reshard_without_detecting_match) { @@ -116,7 +134,7 @@ StatusOr PartitionBaseCase( std::vector output_to_lhs_indices(output_base_shape.rank(), -1); std::vector output_to_rhs_indices(output_base_shape.rank(), -1); auto populate_indices_mapping = - [&](const DotGeneralDimsMapping::DimsMapping& mapping) { + [&](const DotConvDimsMapping::DimsMapping& mapping) { if (mapping.lhs >= 0) { lhs_to_rhs_indices[mapping.lhs] = mapping.rhs; lhs_to_output_indices[mapping.lhs] = mapping.output; @@ -142,6 +160,9 @@ StatusOr PartitionBaseCase( for (const auto& mapping : dims_mapping.rhs_non_contracting_dims) { populate_indices_mapping(mapping); } + for (const auto& mapping : dims_mapping.conv_spatial_dims) { + populate_indices_mapping(mapping); + } auto lhs_sharding_transposed_to_match_rhs = hlo_sharding_util::TransposeShardingWithCollapsedDims( lhs_sharding, lhs_to_rhs_indices, rhs_to_lhs_indices); @@ -166,7 +187,8 @@ StatusOr PartitionBaseCase( if (lhs_batch_partitions == rhs_batch_partitions && rhs_batch_partitions == num_partitions && lhs_sharding_transposed_to_match_rhs == rhs_sharding) { - TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b)); + TF_ASSIGN_OR_RETURN( + auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window)); dot->set_sharding(*lhs_sharding_transposed_to_match_output); return PartitionedHlo(dot, output_base_shape, lhs.state()) .Reshard(output_sharding) @@ -192,7 +214,8 @@ StatusOr PartitionBaseCase( } auto resharded_rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs); TF_ASSIGN_OR_RETURN( - auto dot, create_sharded_dot(lhs.hlo(), resharded_rhs.hlo(), b)); + auto dot, + create_sharded_dot(lhs.hlo(), resharded_rhs.hlo(), b, conv_window)); return dot; } // RHS and output are batch partitioned in the same way. @@ -208,7 +231,8 @@ StatusOr PartitionBaseCase( } auto resharded_lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs); TF_ASSIGN_OR_RETURN( - auto dot, create_sharded_dot(resharded_lhs.hlo(), rhs.hlo(), b)); + auto dot, + create_sharded_dot(resharded_lhs.hlo(), rhs.hlo(), b, conv_window)); return dot; } return nullptr; @@ -306,8 +330,8 @@ StatusOr PartitionBaseCase( dot_rhs = slice; } } - TF_ASSIGN_OR_RETURN(auto dot, - create_sharded_dot(dot_lhs, dot_rhs, &body_b)); + TF_ASSIGN_OR_RETURN( + auto dot, create_sharded_dot(dot_lhs, dot_rhs, &body_b, conv_window)); if (windowed_at_contracting_dims) { // Accumulate the partial output to the result buffer. 
o = body_b.AddInstruction( @@ -408,7 +432,7 @@ StatusOr PartitionBaseCase( if (output_lhs_non_contracting_partitions == num_partitions && output_sharding_transposed_to_match_lhs == lhs_sharding && ShapeSizeInBytes(rhs.base_shape()) >= - threshold_for_windowed_einsum_mib * 1024 * 1024) { + options.threshold_for_windowed_einsum_mib * 1024 * 1024) { if (rhs_contracting_partitions == num_partitions) { return emit_windowed_dot_general(0, 1, true, false); } @@ -422,7 +446,7 @@ StatusOr PartitionBaseCase( if (output_rhs_non_contracting_partitions == num_partitions && output_sharding_transposed_to_match_rhs == rhs_sharding && ShapeSizeInBytes(lhs.base_shape()) >= - threshold_for_windowed_einsum_mib * 1024 * 1024) { + options.threshold_for_windowed_einsum_mib * 1024 * 1024) { if (lhs_contracting_partitions == num_partitions) { return emit_windowed_dot_general(1, 0, true, false); } @@ -461,10 +485,12 @@ StatusOr PartitionBaseCase( rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero); } - TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b)); + TF_ASSIGN_OR_RETURN( + auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window)); auto ar = lhs.state().collective_ops_creator.create_cross_partition_all_reduce( - b, dot, MakeBinaryAdd(output_base_shape.element_type(), module), {}, + b, dot, MakeBinaryAdd(output_base_shape.element_type(), module), + {GetAllDevicesInOrder(lhs.sharding())}, (*lhs.state().next_channel_id)++); ar->set_sharding(HloSharding::Replicate()); return PartitionedHlo(ar, output_base_shape, lhs.state()) @@ -477,8 +503,8 @@ StatusOr PartitionBaseCase( output_lhs_non_contracting_partitions == num_partitions && lhs_sharding_transposed_to_match_output == output_sharding) { auto rhs_replicated = rhs.Reshard(HloSharding::Replicate()).hlo(); - TF_ASSIGN_OR_RETURN(auto dot, - create_sharded_dot(lhs.hlo(), rhs_replicated, b)); + TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs_replicated, + b, conv_window)); return dot; } @@ -487,8 +513,8 @@ StatusOr PartitionBaseCase( output_rhs_non_contracting_partitions == num_partitions && rhs_sharding_transposed_to_match_output == output_sharding) { auto lhs_replicated = lhs.Reshard(HloSharding::Replicate()).hlo(); - TF_ASSIGN_OR_RETURN(auto dot, - create_sharded_dot(lhs_replicated, rhs.hlo(), b)); + TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs_replicated, rhs.hlo(), + b, conv_window)); return dot; } @@ -499,8 +525,9 @@ StatusOr PartitionBaseCase( lhs.Reshard(*output_sharding_transposed_to_match_lhs); auto resharded_rhs = rhs.Reshard(*output_sharding_transposed_to_match_rhs); - TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(resharded_lhs.hlo(), - resharded_rhs.hlo(), b)); + TF_ASSIGN_OR_RETURN( + auto dot, create_sharded_dot(resharded_lhs.hlo(), resharded_rhs.hlo(), + b, conv_window)); return dot; } // Output is partitioned along LHS non-contracting dimensions. @@ -509,8 +536,8 @@ StatusOr PartitionBaseCase( lhs.Reshard(*output_sharding_transposed_to_match_lhs); auto replicated_rhs = rhs.Reshard(HloSharding::Replicate()); TF_ASSIGN_OR_RETURN( - auto dot, - create_sharded_dot(resharded_lhs.hlo(), replicated_rhs.hlo(), b)); + auto dot, create_sharded_dot(resharded_lhs.hlo(), + replicated_rhs.hlo(), b, conv_window)); return dot; } // Output is partitioned along RHS non-contracting dimensions. 
@@ -518,8 +545,9 @@ StatusOr PartitionBaseCase( auto replicated_lhs = lhs.Reshard(HloSharding::Replicate()); auto resharded_rhs = rhs.Reshard(*output_sharding_transposed_to_match_rhs); - TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(replicated_lhs.hlo(), - resharded_rhs.hlo(), b)); + TF_ASSIGN_OR_RETURN( + auto dot, create_sharded_dot(replicated_lhs.hlo(), + resharded_rhs.hlo(), b, conv_window)); return dot; } } @@ -562,9 +590,11 @@ StatusOr PartitionBaseCase( rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero); } - TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b)); + TF_ASSIGN_OR_RETURN( + auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window)); return lhs.state().collective_ops_creator.create_cross_partition_all_reduce( - b, dot, MakeBinaryAdd(output_base_shape.element_type(), module), {}, + b, dot, MakeBinaryAdd(output_base_shape.element_type(), module), + {GetAllDevicesInOrder(lhs.sharding())}, (*lhs.state().next_channel_id)++); } return nullptr; @@ -572,26 +602,28 @@ StatusOr PartitionBaseCase( StatusOr PartitionDot( PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, - const HloSharding& output_sharding, - const DotGeneralDimsMapping& dims_mapping, int64 num_partitions, + const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping, + int64 num_partitions, const std::function( - HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot, - HloModule* module, HloInstruction* original_hlo, - int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b, + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)>& create_sharded_dot, + const Window& conv_window, HloModule* module, HloInstruction* original_hlo, + const SpmdPartitionerOptions& options, SpmdBuilder* b, std::vector* windowed_dot_general_loops); StatusOr PartitionDotGroupOnBatch( PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, - const HloSharding& output_sharding, - const DotGeneralDimsMapping& dims_mapping, int64 num_partitions, - int64 lhs_contracting_partitions, int64 rhs_contracting_partitions, - int64 lhs_non_contracting_partitions, int64 rhs_non_contracting_partitions, + const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping, + int64 num_partitions, int64 lhs_contracting_partitions, + int64 rhs_contracting_partitions, int64 lhs_non_contracting_partitions, + int64 rhs_non_contracting_partitions, const std::function( - HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot, - HloModule* module, HloInstruction* original_hlo, + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)>& create_sharded_dot, + const Window& conv_window, HloModule* module, HloInstruction* original_hlo, bool require_matching_devices_to_group, - int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b, + const SpmdPartitionerOptions& options, SpmdBuilder* b, std::vector* windowed_dot_general_loops) { std::vector> @@ -804,9 +836,8 @@ StatusOr PartitionDotGroupOnBatch( GetPerGroupBaseShape(output_grouped, output_base_shape), output_grouped.sharding, dims_mapping, num_partitions / output_grouped.device_groups.size(), - create_sharded_dot, module, original_hlo, - threshold_for_windowed_einsum_mib, b, - windowed_dot_general_loops)); + create_sharded_dot, conv_window, module, original_hlo, + options, b, windowed_dot_general_loops)); dot->set_sharding(UngroupSharding(output_grouped)); return PartitionedHlo(dot, output_base_shape, lhs.state()) 
.Reshard(output_sharding) @@ -816,17 +847,18 @@ StatusOr PartitionDotGroupOnBatch( StatusOr PartitionDotGroupOnNonContracting( bool lhs_matching, PartitionedHlo matching, PartitionedHlo other, int64 matching_contracting_partitions, int64 other_contracting_partitions, - absl::Span + absl::Span partitioned_non_contractin_dims, int64 other_non_contracting_partitions, int64 output_other_non_contracting_partitions, const Shape& output_base_shape, const HloSharding& output_sharding, - const DotGeneralDimsMapping& dims_mapping, int64 num_partitions, + const DotConvDimsMapping& dims_mapping, int64 num_partitions, const std::function( - HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot, - HloModule* module, HloInstruction* original_hlo, + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)>& create_sharded_dot, + const Window& conv_window, HloModule* module, HloInstruction* original_hlo, bool require_matching_devices_to_group, - int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b, + const SpmdPartitionerOptions& options, SpmdBuilder* b, std::vector* windowed_dot_general_loops) { std::vector> @@ -921,7 +953,7 @@ StatusOr PartitionDotGroupOnNonContracting( other.sharding(), {other_group_dims[0]}, {other.sharding().tile_assignment().dimensions().back() / group_count}), - output_grouped); + output_grouped, /*ignore_group_order=*/true); other = other.Reshard(UngroupSharding(grouped)); partially_replicated_other = other.hlo(); top_level_sharding_to_reset.emplace_back(other.hlo(), other.sharding()); @@ -949,25 +981,25 @@ StatusOr PartitionDotGroupOnNonContracting( GetPerGroupBaseShape(output_grouped, output_base_shape), output_grouped.sharding, dims_mapping, num_partitions / matching_grouped.device_groups.size(), - create_sharded_dot, module, original_hlo, - threshold_for_windowed_einsum_mib, b, - windowed_dot_general_loops)); + create_sharded_dot, conv_window, module, original_hlo, + options, b, windowed_dot_general_loops)); return dot; } StatusOr PartitionDotGroupOnContracting( PartitionedHlo lhs, PartitionedHlo rhs, - absl::Span + absl::Span partitioned_contractin_dims, int64 output_batch_partitions, int64 output_lhs_non_contracting_partitions, int64 output_rhs_non_contracting_partitions, const Shape& output_base_shape, - const HloSharding& output_sharding, - const DotGeneralDimsMapping& dims_mapping, int64 num_partitions, + const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping, + int64 num_partitions, const std::function( - HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot, - HloModule* module, HloInstruction* original_hlo, + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)>& create_sharded_dot, + const Window& conv_window, HloModule* module, HloInstruction* original_hlo, bool require_matching_devices_to_group, - int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b, + const SpmdPartitionerOptions& options, SpmdBuilder* b, std::vector* windowed_dot_general_loops) { std::vector> @@ -1043,7 +1075,8 @@ StatusOr PartitionDotGroupOnContracting( {output_sharding.tile_assignment().num_dimensions() - 1}, {output_sharding.tile_assignment().dimensions().back() / group_count}), - lhs_grouped); + lhs_grouped, + /*ignore_group_order=*/true); outer_output_tmp_sharding = UngroupSharding(grouped); inner_output_sharding = std::move(grouped.sharding); } else { @@ -1088,10 +1121,9 @@ StatusOr PartitionDotGroupOnContracting( PartitionedHlo(rhs.hlo(), GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()), 
inner_state), - MakePartitionedShape(output_base_shape, outer_output_tmp_sharding), - inner_output_sharding, dims_mapping, num_partitions / group_count, - create_sharded_dot, module, original_hlo, - threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops)); + output_base_shape, inner_output_sharding, dims_mapping, + num_partitions / group_count, create_sharded_dot, conv_window, module, + original_hlo, options, b, windowed_dot_general_loops)); if (!dot) { return nullptr; } @@ -1107,31 +1139,73 @@ StatusOr PartitionDotGroupOnContracting( inverse_grouped.device_groups, b) .collective_ops_creator.create_cross_partition_all_reduce( b, dot, MakeBinaryAdd(output_base_shape.element_type(), module), - {}, (*lhs.state().next_channel_id)++); + {GetAllDevicesInOrder(inverse_grouped.sharding)}, + (*lhs.state().next_channel_id)++); ar->set_sharding(outer_output_tmp_sharding); return PartitionedHlo(ar, output_base_shape, lhs.state()) .Reshard(output_sharding) .hlo(); } +DotConvDimsMapping ConvertDimsMappingWithFeatureGroupCount( + const DotConvDimsMapping& dims_mapping, HloInstruction* original_hlo) { + const auto& dnums = original_hlo->convolution_dimension_numbers(); + DotConvDimsMapping new_dims_mapping; + new_dims_mapping.batch_dims = dims_mapping.batch_dims; + new_dims_mapping.conv_spatial_dims = dims_mapping.conv_spatial_dims; + // Append batch dims. + new_dims_mapping.batch_dims.emplace_back(); + new_dims_mapping.batch_dims.back().lhs = dnums.input_feature_dimension(); + new_dims_mapping.batch_dims.back().rhs = + dnums.kernel_output_feature_dimension(); + new_dims_mapping.batch_dims.back().output = dnums.output_feature_dimension(); + new_dims_mapping.batch_dims.back().spatial = -1; + // Setup non contracting dims. + new_dims_mapping.lhs_non_contracting_dims.emplace_back(); + new_dims_mapping.lhs_non_contracting_dims.back().lhs = + dnums.input_batch_dimension(); + new_dims_mapping.rhs_non_contracting_dims.emplace_back(); + new_dims_mapping.rhs_non_contracting_dims.back().rhs = + dnums.kernel_input_feature_dimension(); + return new_dims_mapping; +} + +DotConvDimsMapping ConvertDimsMappingWithBatchGroupCount( + const DotConvDimsMapping& dims_mapping, HloInstruction* original_hlo) { + const auto& dnums = original_hlo->convolution_dimension_numbers(); + DotConvDimsMapping new_dims_mapping; + new_dims_mapping.batch_dims = dims_mapping.batch_dims; + new_dims_mapping.conv_spatial_dims = dims_mapping.conv_spatial_dims; + new_dims_mapping.contracting_dims = dims_mapping.contracting_dims; + // Append batch dims. + new_dims_mapping.batch_dims.emplace_back(); + new_dims_mapping.batch_dims.back().lhs = dnums.input_batch_dimension(); + new_dims_mapping.batch_dims.back().rhs = + dnums.kernel_output_feature_dimension(); + new_dims_mapping.batch_dims.back().output = dnums.output_feature_dimension(); + new_dims_mapping.batch_dims.back().spatial = -1; + return new_dims_mapping; +} + // Recursive partitioning function. If there are partial dimensions matching in // the operands and output, group the devices and recursively partition the // in-group dot. 
StatusOr PartitionDot( PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, - const HloSharding& output_sharding, - const DotGeneralDimsMapping& dims_mapping, int64 num_partitions, + const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping, + int64 num_partitions, const std::function( - HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot, - HloModule* module, HloInstruction* original_hlo, + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)>& create_sharded_dot, + const Window& conv_window, HloModule* module, HloInstruction* original_hlo, bool require_matching_devices_to_group, - int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b, + const SpmdPartitionerOptions& options, SpmdBuilder* b, std::vector* windowed_dot_general_loops) { // lhs_rhs_or_output: 0 lhs, 1 rhs, 2 output. auto get_partitions_for_dims = [&](const HloSharding& sharding, - absl::Span dims, + absl::Span dims, int lhs_rhs_or_output) { int64 partitions = 1; if (sharding.IsTileMaximal()) { @@ -1167,19 +1241,112 @@ StatusOr PartitionDot( output_sharding, dims_mapping.lhs_non_contracting_dims, 2); const int64 output_rhs_non_contracting_partitions = get_partitions_for_dims( output_sharding, dims_mapping.rhs_non_contracting_dims, 2); + const int64 lhs_conv_spatial_partitions = get_partitions_for_dims( + lhs.sharding(), dims_mapping.conv_spatial_dims, 0); + const int64 rhs_conv_spatial_partitions = get_partitions_for_dims( + rhs.sharding(), dims_mapping.conv_spatial_dims, 1); + const int64 output_conv_spatial_partitions = get_partitions_for_dims( + output_sharding, dims_mapping.conv_spatial_dims, 2); // Before we find partial matches along the dimensions, invoke base case again // without may_reshard_without_detecting_match. + + // Try partition the purely spatially-partitioned convolution with convolution + // spatial dimension partitioned or depthwise parallel dimension partitioned. + bool is_conv_spatial_dim_partitioned = + (lhs_conv_spatial_partitions > 1 || rhs_conv_spatial_partitions > 1 || + output_conv_spatial_partitions > 1); + bool is_conv_batch_or_contracting_dim_partitioned = + (lhs_batch_partitions > 1 || rhs_batch_partitions > 1 || + output_batch_partitions > 1 || + (lhs_contracting_partitions > 1 && rhs_contracting_partitions > 1)); + if ((!dims_mapping.conv_spatial_dims.empty() && + is_conv_spatial_dim_partitioned && + !is_conv_batch_or_contracting_dim_partitioned) || + (original_hlo->opcode() == HloOpcode::kConvolution && + (original_hlo->batch_group_count() > 1 || + original_hlo->feature_group_count() > 1))) { + // Partition with kernel_input_feature_dim > 1 and feature_group_count > 1 + // is not supported. + const auto& dnums = original_hlo->convolution_dimension_numbers(); + if (original_hlo->feature_group_count() > 1 && + rhs.hlo()->shape().dimensions(dnums.kernel_input_feature_dimension()) > + 1) { + return nullptr; + } + + TF_ASSIGN_OR_RETURN( + auto partitioned_conv, + PartitionConvolution(lhs, rhs, output_base_shape, output_sharding, + dims_mapping, create_sharded_dot, conv_window, + original_hlo, num_partitions, options, + lhs.state().partition_id, module, b)); + + if (partitioned_conv) { + return partitioned_conv; + } + + // Recursively partition on different types of dimensions for convolution. + // Case 0.a: Group partitions by feature group count. 
+ if (original_hlo->feature_group_count() > 1 || + original_hlo->batch_group_count() > 1) { + DotConvDimsMapping new_dims_mapping; + if (original_hlo->feature_group_count() > 1) { + new_dims_mapping = + ConvertDimsMappingWithFeatureGroupCount(dims_mapping, original_hlo); + } + + if (original_hlo->batch_group_count() > 1) { + new_dims_mapping = + ConvertDimsMappingWithBatchGroupCount(dims_mapping, original_hlo); + } + + const int64 conv_lhs_contracting_partitions = get_partitions_for_dims( + lhs.sharding(), new_dims_mapping.contracting_dims, 0); + const int64 conv_rhs_contracting_partitions = get_partitions_for_dims( + rhs.sharding(), new_dims_mapping.contracting_dims, 1); + const int64 conv_lhs_non_contracting_partitions = get_partitions_for_dims( + lhs.sharding(), new_dims_mapping.lhs_non_contracting_dims, 0); + const int64 conv_rhs_non_contracting_partitions = get_partitions_for_dims( + rhs.sharding(), new_dims_mapping.rhs_non_contracting_dims, 1); + const int64 conv_lhs_batch_partitions = get_partitions_for_dims( + lhs.sharding(), new_dims_mapping.batch_dims, 0); + const int64 conv_rhs_batch_partitions = get_partitions_for_dims( + rhs.sharding(), new_dims_mapping.batch_dims, 1); + const int64 conv_output_batch_partitions = get_partitions_for_dims( + output_sharding, new_dims_mapping.batch_dims, 2); + if ((conv_lhs_batch_partitions == conv_output_batch_partitions || + conv_rhs_batch_partitions == conv_output_batch_partitions) && + conv_output_batch_partitions > 1) { + TF_ASSIGN_OR_RETURN( + auto try_partitioned_conv, + PartitionDotGroupOnBatch( + lhs, rhs, output_base_shape, output_sharding, new_dims_mapping, + num_partitions, conv_lhs_contracting_partitions, + conv_rhs_contracting_partitions, + conv_lhs_non_contracting_partitions, + conv_rhs_non_contracting_partitions, create_sharded_dot, + conv_window, module, original_hlo, + require_matching_devices_to_group, options, b, + windowed_dot_general_loops)); + if (try_partitioned_conv) { + return try_partitioned_conv; + } + } + return nullptr; + } + } + TF_ASSIGN_OR_RETURN( auto try_partitioned_dot, PartitionBaseCase( lhs, rhs, output_base_shape, output_sharding, dims_mapping, - num_partitions, create_sharded_dot, module, original_hlo, + num_partitions, create_sharded_dot, conv_window, module, original_hlo, lhs_batch_partitions, rhs_batch_partitions, output_batch_partitions, lhs_contracting_partitions, rhs_contracting_partitions, lhs_non_contracting_partitions, rhs_non_contracting_partitions, output_lhs_non_contracting_partitions, - output_rhs_non_contracting_partitions, - threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops, + output_rhs_non_contracting_partitions, options, b, + windowed_dot_general_loops, /*may_reshard_without_detecting_match=*/false)); if (try_partitioned_dot) { return try_partitioned_dot; @@ -1197,9 +1364,9 @@ StatusOr PartitionDot( lhs, rhs, output_base_shape, output_sharding, dims_mapping, num_partitions, lhs_contracting_partitions, rhs_contracting_partitions, lhs_non_contracting_partitions, - rhs_non_contracting_partitions, create_sharded_dot, module, - original_hlo, require_matching_devices_to_group, - threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops)); + rhs_non_contracting_partitions, create_sharded_dot, conv_window, + module, original_hlo, require_matching_devices_to_group, options, b, + windowed_dot_general_loops)); if (dot) { return dot; } @@ -1222,7 +1389,6 @@ StatusOr PartitionDot( ShapeUtil::ByteSizeOf(rhs.hlo()->shape()) <= rhs_non_contracting_partitions * 
ShapeUtil::ByteSizeOf(lhs.hlo()->shape())); - TF_ASSIGN_OR_RETURN( auto dot, PartitionDotGroupOnNonContracting( @@ -1238,9 +1404,9 @@ StatusOr PartitionDot( lhs_matching ? output_rhs_non_contracting_partitions : output_lhs_non_contracting_partitions, output_base_shape, output_sharding, dims_mapping, num_partitions, - create_sharded_dot, module, original_hlo, - require_matching_devices_to_group, - threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops)); + create_sharded_dot, conv_window, module, original_hlo, + require_matching_devices_to_group, options, b, + windowed_dot_general_loops)); if (dot) { return dot; } @@ -1248,7 +1414,7 @@ StatusOr PartitionDot( if (lhs_non_contracting_partitions > 1 && output_lhs_non_contracting_partitions > 1) { // If part of LHS non-contracting dims match output, try them. - std::vector matching_dims; + std::vector matching_dims; for (const auto& dim : dims_mapping.lhs_non_contracting_dims) { int64 lhs_partitions = lhs.sharding().tile_assignment().dim(dim.lhs); if (lhs_partitions > 1 && @@ -1258,16 +1424,15 @@ StatusOr PartitionDot( } if (!matching_dims.empty()) { TF_ASSIGN_OR_RETURN( - auto dot, - PartitionDotGroupOnNonContracting( - /*lhs_matching=*/true, lhs, rhs, lhs_contracting_partitions, - rhs_contracting_partitions, matching_dims, - rhs_non_contracting_partitions, - output_rhs_non_contracting_partitions, output_base_shape, - output_sharding, dims_mapping, num_partitions, create_sharded_dot, - module, original_hlo, require_matching_devices_to_group, - threshold_for_windowed_einsum_mib, b, - windowed_dot_general_loops)); + auto dot, PartitionDotGroupOnNonContracting( + /*lhs_matching=*/true, lhs, rhs, + lhs_contracting_partitions, rhs_contracting_partitions, + matching_dims, rhs_non_contracting_partitions, + output_rhs_non_contracting_partitions, + output_base_shape, output_sharding, dims_mapping, + num_partitions, create_sharded_dot, conv_window, module, + original_hlo, require_matching_devices_to_group, + options, b, windowed_dot_general_loops)); if (dot) { return dot; } @@ -1276,7 +1441,7 @@ StatusOr PartitionDot( if (rhs_non_contracting_partitions > 1 && output_rhs_non_contracting_partitions > 1) { // If part of RHS non-contracting dims match output, try them. 
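A small illustrative sketch of the partial-match selection used in these non-contracting branches: a dimension is kept in matching_dims only when it is partitioned and its tile count agrees with the output sharding on the mapped output dimension. The flat tile-count vectors are an assumed simplification of the real tile assignments.

#include <cstdint>
#include <vector>

struct DimPair { int64_t operand_dim; int64_t output_dim; };

std::vector<DimPair> MatchingNonContractingDims(
    const std::vector<DimPair>& non_contracting_dims,
    const std::vector<int64_t>& operand_tiles_per_dim,
    const std::vector<int64_t>& output_tiles_per_dim) {
  std::vector<DimPair> matching;
  for (const DimPair& dim : non_contracting_dims) {
    const int64_t partitions = operand_tiles_per_dim[dim.operand_dim];
    if (partitions > 1 && partitions == output_tiles_per_dim[dim.output_dim]) {
      matching.push_back(dim);  // grouped partitioning is attempted on these
    }
  }
  return matching;
}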
- std::vector matching_dims; + std::vector matching_dims; for (const auto& dim : dims_mapping.rhs_non_contracting_dims) { int64 rhs_partitions = rhs.sharding().tile_assignment().dim(dim.rhs); if (rhs_partitions > 1 && @@ -1286,16 +1451,15 @@ StatusOr PartitionDot( } if (!matching_dims.empty()) { TF_ASSIGN_OR_RETURN( - auto dot, - PartitionDotGroupOnNonContracting( - /*lhs_matching=*/false, rhs, lhs, rhs_contracting_partitions, - lhs_contracting_partitions, matching_dims, - lhs_non_contracting_partitions, - output_lhs_non_contracting_partitions, output_base_shape, - output_sharding, dims_mapping, num_partitions, create_sharded_dot, - module, original_hlo, require_matching_devices_to_group, - threshold_for_windowed_einsum_mib, b, - windowed_dot_general_loops)); + auto dot, PartitionDotGroupOnNonContracting( + /*lhs_matching=*/false, rhs, lhs, + rhs_contracting_partitions, lhs_contracting_partitions, + matching_dims, lhs_non_contracting_partitions, + output_lhs_non_contracting_partitions, + output_base_shape, output_sharding, dims_mapping, + num_partitions, create_sharded_dot, conv_window, module, + original_hlo, require_matching_devices_to_group, + options, b, windowed_dot_general_loops)); if (dot) { return dot; } @@ -1312,15 +1476,16 @@ StatusOr PartitionDot( output_lhs_non_contracting_partitions, output_rhs_non_contracting_partitions, output_base_shape, output_sharding, dims_mapping, num_partitions, create_sharded_dot, - module, original_hlo, require_matching_devices_to_group, - threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops)); + conv_window, module, original_hlo, + require_matching_devices_to_group, options, b, + windowed_dot_general_loops)); if (dot) { return dot; } } if (lhs_contracting_partitions > 1 && rhs_contracting_partitions > 1) { // If part of contracting dims match, try them. 
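To make the contracting-dimension branches concrete, here is a standalone sketch (plain C++, not XLA code) of why grouping on contracting dimensions ends with a cross-partition reduction: each partition multiplies only its slice of the contracting dimension, and the partial results must be summed element-wise across partitions, which is the all-reduce the grouped path emits.

// C_partial = A[:, k_begin:k_end] x B[k_begin:k_end, :]; summing C_partial
// element-wise across all partitions yields the full product.
void PartialMatmulOverContractingSlice(const float* a, const float* b,
                                       float* c_partial, int m, int n,
                                       int k_total, int k_begin, int k_end) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int k = k_begin; k < k_end; ++k) {
        acc += a[i * k_total + k] * b[k * n + j];
      }
      c_partial[i * n + j] = acc;
    }
  }
}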
- std::vector matching_dims; + std::vector matching_dims; for (const auto& dim : dims_mapping.contracting_dims) { int64 lhs_partitions = lhs.sharding().tile_assignment().dim(dim.lhs); if (lhs_partitions > 1 && @@ -1330,15 +1495,14 @@ StatusOr PartitionDot( } if (!matching_dims.empty()) { TF_ASSIGN_OR_RETURN( - auto dot, - PartitionDotGroupOnContracting( - lhs, rhs, matching_dims, output_batch_partitions, - output_lhs_non_contracting_partitions, - output_rhs_non_contracting_partitions, output_base_shape, - output_sharding, dims_mapping, num_partitions, create_sharded_dot, - module, original_hlo, require_matching_devices_to_group, - threshold_for_windowed_einsum_mib, b, - windowed_dot_general_loops)); + auto dot, PartitionDotGroupOnContracting( + lhs, rhs, matching_dims, output_batch_partitions, + output_lhs_non_contracting_partitions, + output_rhs_non_contracting_partitions, + output_base_shape, output_sharding, dims_mapping, + num_partitions, create_sharded_dot, conv_window, module, + original_hlo, require_matching_devices_to_group, + options, b, windowed_dot_general_loops)); if (dot) { return dot; } @@ -1358,8 +1522,8 @@ StatusOr PartitionDot( PartitionDot(PartitionedHlo(lhs.hlo(), lhs.base_shape(), inner_state), PartitionedHlo(rhs.hlo(), rhs.base_shape(), inner_state), output_base_shape, grouped_output.sharding, dims_mapping, - output_sharding.NumTiles(), create_sharded_dot, module, - original_hlo, threshold_for_windowed_einsum_mib, b, + output_sharding.NumTiles(), create_sharded_dot, + conv_window, module, original_hlo, options, b, windowed_dot_general_loops)); if (dot) { return dot; @@ -1372,13 +1536,13 @@ StatusOr PartitionDot( auto dot, PartitionBaseCase( lhs, rhs, output_base_shape, output_sharding, dims_mapping, - num_partitions, create_sharded_dot, module, original_hlo, + num_partitions, create_sharded_dot, conv_window, module, original_hlo, lhs_batch_partitions, rhs_batch_partitions, output_batch_partitions, lhs_contracting_partitions, rhs_contracting_partitions, lhs_non_contracting_partitions, rhs_non_contracting_partitions, output_lhs_non_contracting_partitions, - output_rhs_non_contracting_partitions, - threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops, + output_rhs_non_contracting_partitions, options, b, + windowed_dot_general_loops, /*may_reshard_without_detecting_match=*/true)); if (dot) { return dot; @@ -1388,12 +1552,13 @@ StatusOr PartitionDot( StatusOr PartitionDot( PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, - const HloSharding& output_sharding, - const DotGeneralDimsMapping& dims_mapping, int64 num_partitions, + const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping, + int64 num_partitions, const std::function( - HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot, - HloModule* module, HloInstruction* original_hlo, - int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b, + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)>& create_sharded_dot, + const Window& conv_window, HloModule* module, HloInstruction* original_hlo, + const SpmdPartitionerOptions& options, SpmdBuilder* b, std::vector* windowed_dot_general_loops) { // First try partitioning without resharding the groups, then try allow @@ -1402,18 +1567,18 @@ StatusOr PartitionDot( TF_ASSIGN_OR_RETURN( auto try_partition, PartitionDot(lhs, rhs, output_base_shape, output_sharding, dims_mapping, - num_partitions, create_sharded_dot, module, original_hlo, - require_matching_devices_to_group, - 
threshold_for_windowed_einsum_mib, b, - windowed_dot_general_loops)); + num_partitions, create_sharded_dot, conv_window, module, + original_hlo, require_matching_devices_to_group, options, + b, windowed_dot_general_loops)); if (try_partition) { return try_partition; } } // Default action. - TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.Replicate().hlo(), - rhs.Replicate().hlo(), b)); + TF_ASSIGN_OR_RETURN( + auto dot, create_sharded_dot(lhs.Replicate().hlo(), rhs.Replicate().hlo(), + b, conv_window)); dot->set_sharding(HloSharding::Replicate()); return PartitionedHlo(dot, output_base_shape, lhs.state()) .Reshard(output_sharding) @@ -1423,17 +1588,22 @@ StatusOr PartitionDot( } // namespace Status SpmdPartitioningVisitor::HandleDotHelper( - HloInstruction* hlo, const DotGeneralDimsMapping& dims_mapping, + HloInstruction* hlo, const DotConvDimsMapping& dims_mapping, const std::function( - HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot) { + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)>& create_sharded_dot) { auto& lhs = GetPartitionedHlo(hlo->operand(0)); auto& rhs = GetPartitionedHlo(hlo->operand(1)); + Window conv_window; + if (hlo->opcode() == HloOpcode::kConvolution) { + conv_window = hlo->window(); + } + TF_ASSIGN_OR_RETURN( auto partitioned_dot, PartitionDot(lhs, rhs, hlo->shape(), hlo->sharding(), dims_mapping, - num_partitions_, create_sharded_dot, module_, hlo, - options_.threshold_for_windowed_einsum_mib, &b_, - &windowed_dot_general_loops_)); + num_partitions_, create_sharded_dot, conv_window, module_, + hlo, options_, &b_, &windowed_dot_general_loops_)); SetPartitionedHlo(hlo, [&] { return partitioned_dot; }); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/spmd/schedule_aware_all_gather_cse.cc b/tensorflow/compiler/xla/service/spmd/schedule_aware_all_gather_cse.cc index cc97d5ebda7..bdc96afba88 100644 --- a/tensorflow/compiler/xla/service/spmd/schedule_aware_all_gather_cse.cc +++ b/tensorflow/compiler/xla/service/spmd/schedule_aware_all_gather_cse.cc @@ -88,7 +88,7 @@ StatusOr RunOnComputation(HloComputation* comp, bool for_replicas, auto& earlier_ags = operand_to_ag[ag->operand(0)]; bool found = false; - int64 lowest_user_h = lowest_user_height(ag); + int64 ag_height = height[ag]; for (auto& eag : earlier_ags) { auto old_channel_id = ag->channel_id(); if (eag->channel_id() && ag->channel_id()) { @@ -100,7 +100,7 @@ StatusOr RunOnComputation(HloComputation* comp, bool for_replicas, } found = true; ag->set_channel_id(old_channel_id); - if (lowest_user_height(eag) > lowest_user_h + distance_threshold) { + if (lowest_user_height(eag) > ag_height + distance_threshold) { eag = ag; continue; } diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc index f16b7bacda3..ceb81330639 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc @@ -23,6 +23,7 @@ limitations under the License. 
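The schedule_aware_all_gather_cse.cc change above moves the baseline of the distance check from the new all-gather's lowest user to the new all-gather itself. A hedged standalone restatement of the resulting condition, with the height values assumed to come from the pass's precomputed height map:

#include <cstdint>

// Reuse an earlier, equivalent all-gather only while the height of its lowest
// user stays within distance_threshold of the new all-gather's own height;
// otherwise keep the new all-gather as a separate instruction and record it
// as the candidate for future reuse.
bool ShouldReuseEarlierAllGather(int64_t earlier_lowest_user_height,
                                 int64_t new_all_gather_height,
                                 int64_t distance_threshold) {
  return earlier_lowest_user_height <=
         new_all_gather_height + distance_threshold;
}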
#include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/types/optional.h" @@ -216,6 +217,125 @@ HloInstruction* SpmdBuilder::AddInstruction( if (visiting_hlo_) { instructions_[visiting_hlo_].push_back(hlo); } + if (hlo->opcode() == HloOpcode::kBroadcast) { + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + if (!absl::c_linear_search(hlo->dimensions(), i)) { + broadcast_dims_[hlo].insert(i); + } + } + } + if (hlo->IsElementwise() && hlo->operand_count() > 0) { + absl::flat_hash_set broadcast_dims; + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + broadcast_dims.insert(i); + } + for (int64 i = 0; i < hlo->operand_count(); ++i) { + auto it = broadcast_dims_.find(hlo->operand(i)); + if (it == broadcast_dims_.end()) { + broadcast_dims.clear(); + break; + } + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + if (!it->second.contains(i)) { + broadcast_dims.erase(i); + } + } + } + if (!broadcast_dims.empty()) { + broadcast_dims_[hlo] = std::move(broadcast_dims); + } + } + if (hlo->opcode() == HloOpcode::kTranspose) { + auto it = broadcast_dims_.find(hlo->operand(0)); + if (it != broadcast_dims_.end()) { + absl::flat_hash_set xpose_broadcast_dims; + std::vector reverse_map(hlo->shape().rank()); + for (int64 i = 0; i < reverse_map.size(); ++i) { + reverse_map[hlo->dimensions(i)] = i; + } + for (int64 dim : it->second) { + xpose_broadcast_dims.insert(reverse_map[dim]); + } + broadcast_dims_[hlo] = std::move(xpose_broadcast_dims); + } + } + if (hlo->opcode() == HloOpcode::kReshape && + Product(hlo->shape().dimensions()) > 0) { + auto it = broadcast_dims_.find(hlo->operand(0)); + if (it != broadcast_dims_.end()) { + absl::flat_hash_set reshape_broadcast_dims; + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + reshape_broadcast_dims.insert(i); + } + std::vector before_dim_size_stack; + std::vector after_dim_size_stack; + for (int64 i = hlo->operand(0)->shape().rank() - 1; i >= 0; --i) { + before_dim_size_stack.push_back(hlo->operand(0)->shape().dimensions(i)); + } + for (int64 i = hlo->shape().rank() - 1; i >= 0; --i) { + after_dim_size_stack.push_back(hlo->shape().dimensions(i)); + } + while (!before_dim_size_stack.empty() && !after_dim_size_stack.empty()) { + int64 before_size = before_dim_size_stack.back(); + int64 after_size = after_dim_size_stack.back(); + int64 current_before_dim = + hlo->operand(0)->shape().rank() - before_dim_size_stack.size(); + int64 current_after_dim = + hlo->shape().rank() - after_dim_size_stack.size(); + before_dim_size_stack.pop_back(); + after_dim_size_stack.pop_back(); + if (!it->second.contains(current_before_dim)) { + reshape_broadcast_dims.erase(current_after_dim); + } + if (before_size == after_size) { + continue; + } + if (before_size % after_size == 0) { + // Split dim. + before_dim_size_stack.push_back(before_size / after_size); + } else if (after_size % before_size == 0) { + // Merge dim. + after_dim_size_stack.push_back(after_size / before_size); + } else { + // Other cases, mark all remaining dims as non-broadcast. 
+ for (int64 i = current_after_dim; i < hlo->shape().rank(); ++i) { + reshape_broadcast_dims.erase(i); + } + break; + } + } + if (!before_dim_size_stack.empty() || !after_dim_size_stack.empty()) { + reshape_broadcast_dims.clear(); + } + if (!reshape_broadcast_dims.empty()) { + broadcast_dims_[hlo] = std::move(reshape_broadcast_dims); + } + } + } + if (hlo->opcode() == HloOpcode::kSlice || + hlo->opcode() == HloOpcode::kDynamicSlice) { + auto it = broadcast_dims_.find(hlo->operand(0)); + if (it != broadcast_dims_.end()) { + auto dims = it->second; + broadcast_dims_[hlo] = std::move(dims); + } + } + if (hlo->opcode() == HloOpcode::kPad) { + auto it = broadcast_dims_.find(hlo->operand(0)); + if (it != broadcast_dims_.end()) { + absl::flat_hash_set pad_broadcast_dims; + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + const auto& dim = hlo->padding_config().dimensions(i); + if (dim.edge_padding_low() == 0 && dim.edge_padding_high() == 0 && + dim.interior_padding() == 0 && it->second.contains(i)) { + pad_broadcast_dims.insert(i); + } + } + if (!pad_broadcast_dims.empty()) { + broadcast_dims_[hlo] = std::move(pad_broadcast_dims); + } + } + } return hlo; } @@ -1099,23 +1219,25 @@ PartitionedHlo PartitionedHlo::ReshardWithCollectivePermute( const HloSharding& target) const { CHECK(CanReshardWithCollectivePermute(sharding(), target)) << sharding().ToString() << " to " << target.ToString(); - if (hlo()->opcode() == HloOpcode::kBroadcast) { - // If hlo() is a broadcast, check if data is already the same between - // source/destination pairs. - std::vector new_dims; - for (int64 i = 0; i < hlo()->shape().rank(); ++i) { - if (!absl::c_linear_search(hlo()->dimensions(), i)) { - new_dims.push_back(i); + if (auto broadcast_dims = state_.b->BroadcastDimsForCreatedHlo(hlo())) { + if (!(*broadcast_dims)->empty()) { + // If hlo() has broadcast dims, check if data is already the same between + // source/destination pairs. + std::vector broadcast_dims_vector; + for (int64 i = 0; i < hlo()->shape().rank(); ++i) { + if ((*broadcast_dims)->contains(i)) { + broadcast_dims_vector.push_back(i); + } + } + if (hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + sharding(), broadcast_dims_vector) == + hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + target, broadcast_dims_vector)) { + auto copy = state_.b->AddInstruction(HloInstruction::CreateUnary( + hlo()->shape(), HloOpcode::kCopy, hlo())); + copy->set_sharding(target); + return PartitionedHlo(copy, base_shape_, state_); } - } - if (hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(sharding(), - new_dims) == - hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(target, - new_dims)) { - auto copy = state_.b->AddInstruction( - HloInstruction::CreateUnary(hlo()->shape(), HloOpcode::kCopy, hlo())); - copy->set_sharding(target); - return PartitionedHlo(copy, base_shape_, state_); } } std::vector> src_dst_pairs; @@ -1289,7 +1411,7 @@ namespace { // gather/scatter slice size 1. 
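Two of the broadcast-dimension propagation rules above are worth restating in isolation. First, the elementwise rule: a dimension remains a broadcast dimension of the result only if it is a broadcast dimension of every operand, since only then are all operands constant along it. This sketch uses std containers for self-containment; the real code uses absl::flat_hash_set, and an operand with no recorded entry clears the whole set.

#include <cstdint>
#include <unordered_set>
#include <vector>

std::unordered_set<int64_t> ElementwiseBroadcastDims(
    const std::vector<std::unordered_set<int64_t>>& operand_broadcast_dims,
    int64_t rank) {
  std::unordered_set<int64_t> result;
  for (int64_t i = 0; i < rank; ++i) result.insert(i);
  for (const auto& dims : operand_broadcast_dims) {
    for (auto it = result.begin(); it != result.end();) {
      if (dims.count(*it)) {
        ++it;
      } else {
        it = result.erase(it);  // not constant along this dim for this operand
      }
    }
  }
  return result;
}

Second, the reshape rule, which is the subtlest: walk the old and new dimension sizes with two stacks, matching, splitting, or merging dimensions as the sizes divide, and keep a new dimension marked as broadcast only while every old dimension feeding it was. A self-contained sketch of the same walk:

#include <cstdint>
#include <set>
#include <vector>

std::set<int64_t> ReshapeBroadcastDims(const std::vector<int64_t>& before,
                                       const std::vector<int64_t>& after,
                                       const std::set<int64_t>& before_broadcast) {
  std::set<int64_t> result;
  for (int64_t i = 0; i < static_cast<int64_t>(after.size()); ++i) result.insert(i);
  std::vector<int64_t> bstack(before.rbegin(), before.rend());
  std::vector<int64_t> astack(after.rbegin(), after.rend());
  while (!bstack.empty() && !astack.empty()) {
    const int64_t bsize = bstack.back();
    const int64_t asize = astack.back();
    const int64_t bdim = static_cast<int64_t>(before.size() - bstack.size());
    const int64_t adim = static_cast<int64_t>(after.size() - astack.size());
    bstack.pop_back();
    astack.pop_back();
    if (!before_broadcast.count(bdim)) result.erase(adim);
    if (bsize == asize) continue;
    if (bsize % asize == 0) {
      bstack.push_back(bsize / asize);  // old dim split across several new dims
    } else if (asize % bsize == 0) {
      astack.push_back(asize / bsize);  // several old dims merge into a new dim
    } else {
      // Sizes do not divide: give up on all remaining new dims.
      for (int64_t i = adim; i < static_cast<int64_t>(after.size()); ++i)
        result.erase(i);
      break;
    }
  }
  if (!bstack.empty() || !astack.empty()) result.clear();
  return result;
}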
bool GatherScatterOperandPartitionedOnlyOnTrivialSliceDims( const PartitionedHlo& operand, absl::Span index_map, - absl::Span slice_size, int64 num_partitions) { + absl::Span slice_size) { if (operand.sharding().IsTileMaximal()) { return false; } @@ -1300,7 +1422,7 @@ bool GatherScatterOperandPartitionedOnlyOnTrivialSliceDims( operand.sharding().tile_assignment().dim(dim); } } - return trivial_slice_dims_partitions == num_partitions; + return trivial_slice_dims_partitions == operand.sharding().NumTiles(); } // Returns the min and max for the indices (replicated) in a scatter/gather @@ -1451,10 +1573,23 @@ Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) { update_dim_to_index_dim); CHECK(new_updates_sharding.has_value()); updates = updates.Reshard(*new_updates_sharding); + // Update collective_ops_creator and partition_id for partial replicate. + auto collective_ops_creator = collective_ops_creator_; + auto partition_id = partition_id_; + if (indices.sharding().ReplicateOnLastTileDim()) { + auto sharding_grouped = GroupShardingOnDims( + indices.sharding(), + {indices.sharding().tile_assignment().num_dimensions() - 1}); + auto per_group_partitioner_state = CreatePerGroupPartitioningState( + indices.state(), sharding_grouped.device_groups, &b_); + collective_ops_creator = + per_group_partitioner_state.collective_ops_creator; + partition_id = per_group_partitioner_state.partition_id; + } // To avoid accumulating the initial operand multiple times during // all-reduce, we use identity operands for all non-zero partitions. auto not_partition_zero = b_.AddInstruction(HloInstruction::CreateConvert( - ShapeUtil::MakeScalarShape(PRED), partition_id_)); + ShapeUtil::MakeScalarShape(PRED), partition_id)); not_partition_zero = b_.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::ChangeElementType(identity->shape(), PRED), not_partition_zero, {})); @@ -1465,7 +1600,7 @@ Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) { auto pscatter = b_.AddInstruction(scatter->CloneWithNewOperands( scatter->shape(), {select_operand, indices.hlo(), updates.hlo()})); auto all_reduce = - collective_ops_creator_.create_cross_partition_all_reduce( + collective_ops_creator.create_cross_partition_all_reduce( &b_, pscatter, scatter->to_apply(), {}, NewChannel()); all_reduce->set_sharding(HloSharding::Replicate()); SetPartitionedHlo(hlo, [&]() { @@ -1495,8 +1630,7 @@ Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) { return Status::OK(); } if (GatherScatterOperandPartitionedOnlyOnTrivialSliceDims( - operand, scatter_dims_to_operand_dims, slice_size, - num_partitions_) && + operand, scatter_dims_to_operand_dims, slice_size) && ShapeSizeInBytes(updates.base_shape()) < ShapeSizeInBytes(scatter->shape())) { // Operand is sharded on trivial slice dims (update slice size 1). We can @@ -2371,8 +2505,7 @@ Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) { return Status::OK(); } if (GatherScatterOperandPartitionedOnlyOnTrivialSliceDims( - operand, start_index_map, gather->gather_slice_sizes(), - num_partitions_) && + operand, start_index_map, gather->gather_slice_sizes()) && ShapeSizeInBytes(gather->shape()) < ShapeSizeInBytes(gather->operand(0)->shape())) { indices = indices.Reshard(HloSharding::Replicate()); @@ -2434,7 +2567,17 @@ Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) { pgather->shape(), HloOpcode::kSelect, broadcast_filter, CreateZero(pgather->shape(), &b_), pgather)); // Combine from different partitions. 
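A plain-C++ sketch of the combine step named in the comment above: each partition gathers from its local operand shard, zeroes out rows whose indices it does not own, and a cross-partition all-reduce (sum) assembles the complete result. Ownership of a contiguous row range is an assumed simplification of the real trivial-slice-dim offsets.

#include <cstdint>
#include <vector>

std::vector<std::vector<float>> LocalMaskedGather(
    const std::vector<std::vector<float>>& local_rows,  // this partition's shard
    const std::vector<int64_t>& indices, int64_t row_begin, int64_t row_end) {
  const size_t width = local_rows.empty() ? 0 : local_rows[0].size();
  std::vector<std::vector<float>> out;
  out.reserve(indices.size());
  for (int64_t idx : indices) {
    if (idx >= row_begin && idx < row_end) {
      out.push_back(local_rows[idx - row_begin]);
    } else {
      out.emplace_back(width, 0.0f);  // masked: another partition owns this row
    }
  }
  return out;  // summing element-wise across partitions yields the full gather
}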
- auto ar = collective_ops_creator_.create_cross_partition_all_reduce( + auto collective_ops_creator = collective_ops_creator_; + if (operand.sharding().ReplicateOnLastTileDim()) { + auto sharding_grouped = GroupShardingOnDims( + operand.sharding(), + {operand.sharding().tile_assignment().num_dimensions() - 1}); + auto per_group_partitioner_state = CreatePerGroupPartitioningState( + operand.state(), sharding_grouped.device_groups, &b_); + collective_ops_creator = + per_group_partitioner_state.collective_ops_creator; + } + auto ar = collective_ops_creator.create_cross_partition_all_reduce( &b_, filtered, MakeBinaryAdd(filtered->shape().element_type(), module_), {}, NewChannel()); @@ -2874,18 +3017,37 @@ Status SpmdPartitioningVisitor::HandleRng(HloInstruction* hlo) { } TF_RET_CHECK(!hlo->sharding().IsTileMaximal()); - SetPartitionedHlo(hlo, [&] { - // Replicate the operands and run partitioned Rng on all devices. - std::vector new_operands; - for (int64 i = 0; i < hlo->operand_count(); ++i) { - new_operands.push_back(GetPartitionedHlo(hlo->operand(i)) - .Reshard(HloSharding::Replicate()) - .hlo()); - } - return b_.AddInstruction(HloInstruction::CreateRng( + // Replicate the operands and run partitioned Rng on all devices. + std::vector new_operands; + for (int64 i = 0; i < hlo->operand_count(); ++i) { + new_operands.push_back(GetPartitionedHlo(hlo->operand(i)) + .Reshard(HloSharding::Replicate()) + .hlo()); + } + + if (!hlo->sharding().ReplicateOnLastTileDim()) { + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction(HloInstruction::CreateRng( + MakePartitionedShape(hlo->shape(), hlo->sharding()), + hlo->random_distribution(), new_operands)); + }); + } else { + std::vector group_dims( + hlo->sharding().tile_assignment().num_dimensions() - 1); + std::iota(group_dims.begin(), group_dims.end(), 0); + auto sharding_grouped = GroupShardingOnDims(hlo->sharding(), group_dims); + auto per_group_state = CreatePerGroupPartitioningState( + MakePartitioningState(), sharding_grouped.device_groups, &b_); + auto rng = b_.AddInstruction(HloInstruction::CreateRng( MakePartitionedShape(hlo->shape(), hlo->sharding()), hlo->random_distribution(), new_operands)); - }); + rng->set_sharding(HloSharding::AssignDevice(0)); + SetPartitionedHlo(hlo, [&]() { + return PartitionedHlo(rng, rng->shape(), per_group_state) + .Replicate() + .hlo(); + }); + } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h index 6447d08be41..86c1a97b0d2 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -74,6 +75,16 @@ class SpmdBuilder : public HloComputation::Builder { HloInstruction* visiting_hlo() const { return visiting_hlo_; } + // Wrapper of queries to broadcast_dims_. + absl::optional*> BroadcastDimsForCreatedHlo( + const HloInstruction* hlo) { + auto it = broadcast_dims_.find(hlo); + if (it == broadcast_dims_.end()) { + return absl::nullopt; + } + return &it->second; + } + private: // Currently visiting instruction. 
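The HandleRng branch above makes a partially replicated Rng consistent within each replication group by generating on one partition per group and replicating the result, the lowering exercised by the PartialReplicatedRng test later in this diff. A standalone sketch of that pattern, with the in-group all-reduce passed in as an assumed helper:

#include <cstdint>
#include <functional>
#include <vector>

// Only the first partition of each replication group contributes its random
// values; an in-group all-reduce (sum with zeros elsewhere) then gives every
// member of the group the same result.
std::vector<float> ReplicateRngWithinGroup(
    const std::vector<float>& local_rng, int64_t partition_id_in_group,
    const std::function<std::vector<float>(const std::vector<float>&)>&
        all_reduce_in_group) {
  std::vector<float> contribution(local_rng.size(), 0.0f);
  if (partition_id_in_group == 0) contribution = local_rng;
  return all_reduce_in_group(contribution);
}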
HloInstruction* visiting_hlo_; @@ -81,6 +92,12 @@ class SpmdBuilder : public HloComputation::Builder { // Map from the currently visiting (old) instruction to new instructions // created during SPMD partitioning. HloInstructionMap> instructions_; + + // Maps from each created instruction to a set of dimensions that are from + // broadcasts or elementwise ops over broadcasts. This means elements along + // these dimensions have the same value. + absl::flat_hash_map> + broadcast_dims_; }; // A set of functions that create the cross-partition collective ops. @@ -330,27 +347,11 @@ class PartitionedHlo { PartitioningState state_; }; -struct DotGeneralDimsMapping { +struct DotConvDimsMapping { // The dimension numbers for the operands and output corresponding to a // logical dimension (e.g., batch, contracting, non-contracting). If an // operand or the output doesn't have the logical dimension, it is set to // -1. - struct DimsMapping { - int64 lhs; - int64 rhs; - int64 output; - }; - std::vector batch_dims; - std::vector contracting_dims; - std::vector lhs_non_contracting_dims; - std::vector rhs_non_contracting_dims; -}; - -struct ConvolutionDimsMapping { - // The dimension numbers for the operands and output corresponding to a - // logical dimension (e.g., batch, parallel, non-parallel). If an - // operand or the output doesn't have the logical dimension, it is set to - // -1. struct DimsMapping { int64 lhs; int64 rhs; @@ -358,8 +359,11 @@ struct ConvolutionDimsMapping { // input mapped to index in input_spatial_dimensions(). int64 spatial; }; - std::vector parallel_spatial_dims; - std::vector non_parallel_spatial_dims; + std::vector batch_dims; + std::vector contracting_dims; + std::vector lhs_non_contracting_dims; + std::vector rhs_non_contracting_dims; + std::vector conv_spatial_dims; }; class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault { @@ -403,10 +407,11 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault { Status HandlePartitionId(HloInstruction* hlo) override; // Implementation of dot partitioning given DotGeneralDimsMapping. - Status HandleDotHelper( - HloInstruction* hlo, const DotGeneralDimsMapping& dims_mapping, - const std::function( - HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot); + Status HandleDotHelper(HloInstruction* hlo, + const DotConvDimsMapping& dims_mapping, + const std::function( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)>& create_sharded_dot); // Common handle for elementwise HLOs. 
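A self-contained sketch of the bookkeeping the new SpmdBuilder members implement: a map from created instructions to their broadcast dimensions, with an optional accessor so callers can distinguish "not tracked" from "tracked but empty". The types here are standard-library stand-ins for the absl containers and HloInstruction pointers used in the header.

#include <cstdint>
#include <optional>
#include <unordered_map>
#include <unordered_set>
#include <utility>

class BroadcastDimTracker {
 public:
  void Record(const void* instruction, std::unordered_set<int64_t> dims) {
    broadcast_dims_[instruction] = std::move(dims);
  }

  // Mirrors BroadcastDimsForCreatedHlo: nullopt when nothing was recorded.
  std::optional<const std::unordered_set<int64_t>*> DimsFor(
      const void* instruction) const {
    auto it = broadcast_dims_.find(instruction);
    if (it == broadcast_dims_.end()) return std::nullopt;
    return &it->second;
  }

 private:
  std::unordered_map<const void*, std::unordered_set<int64_t>> broadcast_dims_;
};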
Status HandleElementwise(HloInstruction* hlo); diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc index 089c4c339a4..a4dd0e5441b 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc @@ -2003,6 +2003,36 @@ ENTRY entry { EXPECT_THAT(root, op::DynamicSlice(pad, _)); } +TEST_F(SpmdPartitioningTest, PartialReplicatePad) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[11,7] parameter(0), + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} + %param1 = f32[] parameter(1), sharding={replicated} + ROOT %pad = f32[27,22] pad(%param0, %param1), padding=2_4_1x2_1_2, + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + + auto param0 = AllOf(op::Parameter(), op::Shape("f32[11,4]")); + auto after_halo_exchange = + AllOf(op::Shape("f32[11,4]"), + op::DynamicSlice( + AllOf(op::Shape("f32[11,5]"), + op::Concatenate(op::CollectivePermute(op::Slice(param0)), + param0)), + op::Constant(), _)); + auto pad = op::Pad(after_halo_exchange, op::Parameter(1)); + EXPECT_THAT(root, AllOf(op::DynamicSlice(pad, op::Constant(), _), + op::Shape("f32[27,11]"))); +} + TEST_F(SpmdPartitioningTest, SliceAlongNonPartitionedDimension) { const char* const hlo_string = R"( HloModule module @@ -2060,6 +2090,61 @@ ENTRY entry { op::Shape("f32[63,14,126]"))); } +TEST_F(SpmdPartitioningTest, + PartialReplicateSliceAlongNonPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[128,14,257] parameter(0), sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate} + ROOT %slice = f32[128,11,257] slice(%param0), + slice={[0:128:1], [2:13:1], [0:257:1]}, sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf(op::Parameter(), op::Shape("f32[128,14,129]")); + EXPECT_THAT(root, AllOf(op::Slice(param0), op::Shape("f32[128,11,129]"))); +} + +TEST_F(SpmdPartitioningTest, PartialReplicateSliceAlongPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[128,14,257] parameter(0), sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate} + ROOT %slice = f32[63,14,251] slice(%param0), + slice={[2:128:2], [0:14:1], [5:256:1]}, sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf(op::Parameter(), op::Shape("f32[128,14,129]")); + EXPECT_THAT( + root, + AllOf( + op::Slice(AllOf( + op::DynamicSlice( + AllOf(op::Concatenate( + param0, + AllOf(op::CollectivePermute(op::Slice(param0)), + op::Shape("f32[128,14,2]"))), + op::Shape("f32[128,14,131]")), + op::Constant(), op::Constant(), + op::Add(op::Multiply(op::Reshape(op::DynamicSlice( + op::Constant(), op::PartitionId())), + op::Constant()), + op::Constant())), + op::Shape("f32[128,14,126]"))), + 
op::Shape("f32[63,14,126]"))); +} + TEST_F(SpmdPartitioningTest, SortAlongNonPartitionedDimension) { const char* const hlo_string = R"( HloModule module @@ -3293,6 +3378,30 @@ ENTRY entry { EXPECT_THAT(root, AllOf(op::Dot(lhs, rhs), op::Shape("f32[24,19648]"))); } +TEST_F(SpmdPartitioningTest, DotPartialDeviceOrder) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[16,256,4096] parameter(0), sharding={devices=[1,1,2,2]1,3,0,2 last_tile_dim_replicate} + %rhs = f32[4096,2048] parameter(1), sharding={devices=[2,2]3,1,2,0} + ROOT %dot = f32[16,256,2048] dot(%lhs, %rhs), + lhs_batch_dims={}, rhs_batch_dims={}, + lhs_contracting_dims={2}, rhs_contracting_dims={0}, + sharding={devices=[1,1,2,2]2,3,0,1 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Parameter(0), op::Shape("f32[16,256,2048]")); + auto rhs = AllOf(op::Parameter(1), op::Shape("f32[2048,1024]")); + EXPECT_THAT(root, AllOf(op::AllReduce(op::Dot(lhs, rhs)), + op::Shape("f32[16,256,1024]"))); +} + TEST_F(SpmdPartitioningTest, EinsumBatchPartitioned) { const char* const hlo_string = R"( HloModule module @@ -3843,6 +3952,35 @@ ENTRY entry { op::Shape("s32[2]"))); } +TEST_F(SpmdPartitioningTest, PartialReplicatedRng) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = s32[] parameter(0), sharding={replicated} + %rhs = s32[] parameter(1), sharding={replicated} + ROOT %rng = s32[8]{0} rng(%lhs, %rhs), + distribution=rng_uniform, + sharding={devices=[2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Parameter(0), op::Shape("s32[]")); + auto rhs = AllOf(op::Parameter(1), op::Shape("s32[]")); + auto partition_id = + AllOf(op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())), + op::Shape("u32[]")); + EXPECT_THAT( + root, AllOf(op::AllReduce(op::Select( + op::Broadcast(op::Compare(partition_id, op::Constant())), + op::Rng(lhs, rhs), op::Broadcast(op::Constant()))), + op::Shape("s32[4]"))); +} + TEST_F(SpmdPartitioningTest, DynamicSliceAlongNonPartitionedDimension) { const char* const hlo_string = R"( HloModule module @@ -3920,6 +4058,26 @@ ENTRY entry { op::Shape("f32[3,5]"))); } +TEST_F(SpmdPartitioningTest, PassthroughGather_PartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[2,9] parameter(0), + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} + %indices = s32[3] parameter(1), sharding={replicated} + ROOT %gather = f32[3,9] gather(%input, %indices), offset_dims={1}, + collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, + slice_sizes={1,9}, sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Gather(op::Parameter(0), op::Parameter(1)), + op::Shape("f32[3,5]"))); +} + TEST_F(SpmdPartitioningTest, IndexPassthroughGather) { const char* const hlo_string = R"( HloModule module @@ -3939,6 +4097,27 @@ ENTRY entry { op::Shape("f32[8,2,2]"))); } 
+TEST_F(SpmdPartitioningTest, IndexPassthroughGather_PartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[2,9,8] parameter(0), sharding={replicated} + %indices = s32[4,2,4] parameter(1), + sharding={devices=[2,1,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + ROOT %gather = f32[8,4,4] gather(%input, %indices), offset_dims={0}, + collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=1, + slice_sizes={1,1,8}, + sharding={devices=[1,2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Gather(op::Parameter(0), op::Parameter(1)), + op::Shape("f32[8,2,2]"))); +} + TEST_F(SpmdPartitioningTest, GatherPartitionedOnTrivialSliceDims) { const char* const hlo_string = R"( HloModule module @@ -3968,6 +4147,37 @@ ENTRY entry { EXPECT_THAT(root, AllOf(op::AllReduce(masked), op::Shape("f32[2,3,9]"))); } +TEST_F(SpmdPartitioningTest, + GatherPartitionedOnTrivialSliceDims_PartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[17,9] parameter(0), + sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate} + %indices = s32[2,3] parameter(1), sharding={replicated} + ROOT %gather = f32[2,3,9] gather(%input, %indices), offset_dims={2}, + collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, + slice_sizes={1,9}, sharding={replicated} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto offset = + op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())); + auto min = AllOf(op::Broadcast(offset), op::Shape("s32[2,3]")); + auto max = AllOf(op::Broadcast(op::Add(offset, op::Constant())), + op::Shape("s32[2,3]")); + auto clamp = op::Clamp(min, op::Parameter(1), max); + auto gather = op::Gather(op::Parameter(0), op::Subtract(clamp, min)); + auto mask = + op::Or(op::Lt(op::Parameter(1), min), op::Gt(op::Parameter(1), max)); + auto masked = + op::Select(op::Broadcast(mask), op::Broadcast(op::Constant()), gather); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::AllReduce(masked), op::Shape("f32[2,3,9]"))); +} + TEST_F(SpmdPartitioningTest, PassthroughScatter) { const char* const hlo_string = R"( HloModule module @@ -3998,6 +4208,39 @@ ENTRY entry { op::Shape("f32[2,5]"))); } +TEST_F(SpmdPartitioningTest, PassthroughScatter_PartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[2,9] parameter(0), + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} + %indices = s32[3] parameter(1), sharding={replicated} + %updates = f32[3,9] parameter(2), + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} + ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + HloInstruction* root = 
module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Scatter(op::Parameter(0), op::Parameter(1), + op::Parameter(2)), + op::Shape("f32[2,5]"))); +} + TEST_F(SpmdPartitioningTest, IndexPassthroughScatter) { const char* const hlo_string = R"( HloModule module @@ -4032,6 +4275,42 @@ ENTRY entry { op::Shape("f32[2,9,8]"))); } +TEST_F(SpmdPartitioningTest, IndexPassthroughScatter_PartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[2,9,8] parameter(0), sharding={replicated} + %indices = s32[4,2,4] parameter(1), + sharding={devices=[2,1,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + %updates = f32[4,4,8] parameter(2), + sharding={devices=[2,2,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + ROOT %scatter = f32[2,9,8] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={2}, + inserted_window_dims={0,1}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1, sharding={replicated} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf(op::AllReduce(op::Scatter( + op::Select(op::Broadcast(op::Convert(op::Reshape())), + op::Broadcast(op::Constant()), op::Parameter(0)), + op::Parameter(1), op::Parameter(2))), + op::Shape("f32[2,9,8]"))); +} + TEST_F(SpmdPartitioningTest, IndexPassthroughScatter_Min) { const char* const hlo_string = R"( HloModule module @@ -4100,6 +4379,43 @@ ENTRY entry { op::Shape("f32[9,9]"))); } +TEST_F(SpmdPartitioningTest, + ScatterPartitionedOnTrivialSliceDims_PartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[17,9] parameter(0), + sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate} + %indices = s32[2,3] parameter(1), sharding={replicated} + %updates = f32[2,3,9] parameter(2), sharding={replicated} + ROOT %scatter = f32[17,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={2}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=2, + sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto offset = + op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())); + auto indices = op::Subtract( + op::Parameter(1), AllOf(op::Broadcast(offset), op::Shape("s32[2,3]"))); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + AllOf(op::Scatter(op::Parameter(0), indices, op::Parameter(2)), + op::Shape("f32[9,9]"))); +} + TEST_F(SpmdPartitioningTest, TiledReversePassthrough) { const char* const hlo_string = R"( HloModule module @@ -5091,6 +5407,733 @@ ENTRY entry { EXPECT_THAT(root, partially_replicated); } +TEST_F(SpmdPartitioningTest, PartitionConvWithBathGroupCount) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[16,801,1,1024] parameter(0) + %lhs.copy = f32[16,801,1,1024] copy(%lhs), + sharding={devices=[1,1,1,2]0,1} + %rhs = f32[16,801,1,1024] parameter(1) + %rhs.copy = f32[16,801,1,1024] copy(%rhs), 
+ sharding={devices=[1,1,1,2]0,1} + ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy), + dim_labels=f01b_i01o->01bf,batch_group_count=1024, + window={size=801x1 pad=2_2x0_0}, + sharding={devices=[1,1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[16,801,1,512]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[16,801,1,512]")); + EXPECT_THAT(root, + AllOf(op::Convolution(lhs, rhs), op::Shape("f32[5,1,1,512]"))); +} + +TEST_F(SpmdPartitioningTest, PartitionConvWithBathGroupCountRHSAlignWithLHS) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[16,801,1,1024] parameter(0) + %lhs.copy = f32[16,801,1,1024] copy(%lhs), + sharding={devices=[1,1,1,2]0,1} + %rhs = f32[16,801,1,1024] parameter(1) + %rhs.copy = f32[16,801,1,1024] copy(%rhs), + sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy), + dim_labels=f01b_i01o->01bf,batch_group_count=1024, + window={size=801x1 pad=2_2x0_0}, + sharding={devices=[1,1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[16,801,1,512]")); + auto rhs = AllOf(op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Constant(), + op::Reshape(), op::Constant(), op::Constant())), + op::Shape("f32[16,401,1,1024]")); + auto resharded_rhs = AllOf( + op::Slice(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(rhs))))), + op::Shape("f32[16,801,1,512]")); + EXPECT_THAT(root, AllOf(op::Convolution(lhs, resharded_rhs), + op::Shape("f32[5,1,1,512]"))); +} + +TEST_F(SpmdPartitioningTest, PartitionConvWithBathGroupCountLHSAlignWithRHS) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[16,801,1,1024] parameter(0) + %lhs.copy = f32[16,801,1,1024] copy(%lhs), + sharding={devices=[1,2,1,1]0,1} + %rhs = f32[16,801,1,1024] parameter(1) + %rhs.copy = f32[16,801,1,1024] copy(%rhs), + sharding={devices=[1,1,1,2]0,1} + ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy), + dim_labels=f01b_i01o->01bf,batch_group_count=1024, + window={size=801x1 pad=2_2x0_0}, + sharding={devices=[1,1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[16,801,1,512]")); + auto lhs = AllOf(op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Constant(), + op::Reshape(), op::Constant(), op::Constant())), + op::Shape("f32[16,401,1,1024]")); + auto resharded_lhs = AllOf( + op::Slice(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(lhs))))), + op::Shape("f32[16,801,1,512]")); + EXPECT_THAT(root, 
AllOf(op::Convolution(resharded_lhs, rhs), + op::Shape("f32[5,1,1,512]"))); +} + +TEST_F(SpmdPartitioningTest, + PartitionConvWithBathGroupCountOutputAlignWithLHS) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[16,801,1,1024] parameter(0) + %lhs.copy = f32[16,801,1,1024] copy(%lhs), + sharding={devices=[1,1,1,2]0,1} + %rhs = f32[16,801,1,1024] parameter(1) + %rhs.copy = f32[16,801,1,1024] copy(%rhs), + sharding={devices=[1,1,1,2]0,1} + ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy), + dim_labels=f01b_i01o->01bf,batch_group_count=1024, + window={size=801x1 pad=2_2x0_0}, + sharding={devices=[2,1,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[16,801,1,512]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[16,801,1,512]")); + auto conv = AllOf(op::Convolution(lhs, rhs), op::Shape("f32[5,1,1,512]")); + EXPECT_THAT(root, AllOf(op::Reshape(op::Transpose(op::AllToAll( + op::Reshape(op::Pad(conv, op::Constant()))))), + op::Shape("f32[3,1,1,1024]"))); +} + +TEST_F(SpmdPartitioningTest, + PartitionConvWithBathGroupCountOutputAlignWithRHS) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[16,801,1,1024] parameter(0) + %lhs.copy = f32[16,801,1,1024] copy(%lhs), + sharding={devices=[1,2,1,1]0,1} + %rhs = f32[16,801,1,1024] parameter(1) + %rhs.copy = f32[16,801,1,1024] copy(%rhs), + sharding={devices=[1,1,1,2]0,1} + ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy), + dim_labels=f01b_i01o->01bf,batch_group_count=1024, + window={size=801x1 pad=2_2x0_0}, + sharding={devices=[2,1,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[16,801,1,512]")); + auto lhs = AllOf(op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Constant(), + op::Reshape(), op::Constant(), op::Constant())), + op::Shape("f32[16,401,1,1024]")); + auto resharded_lhs = AllOf( + op::Slice(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(lhs))))), + op::Shape("f32[16,801,1,512]")); + auto conv = + AllOf(op::Convolution(resharded_lhs, rhs), op::Shape("f32[5,1,1,512]")); + EXPECT_THAT(root, AllOf(op::Reshape(op::Transpose(op::AllToAll( + op::Reshape(op::Pad(conv, op::Constant()))))), + op::Shape("f32[3,1,1,1024]"))); +} + +TEST_F(SpmdPartitioningTest, PartitionConvWithFeatureGroupCount) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[16,801,1,1024] parameter(0) + %lhs.copy = f32[16,801,1,1024] copy(%lhs), + sharding={devices=[1,1,1,2]0,1} + %rhs = f32[5,1,1,1024] parameter(1) + %rhs.copy = f32[5,1,1,1024] copy(%rhs), + sharding={devices=[1,1,1,2]0,1} + ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy), + dim_labels=b01f_01io->b01f,feature_group_count=1024, + window={size=5x1 pad=2_2x0_0}, + sharding={devices=[1,1,1,2]0,1} +})"; + + 
TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[16,801,1,512]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[5,1,1,512]")); + EXPECT_THAT(root, + AllOf(op::Convolution(lhs, rhs), op::Shape("f32[16,801,1,512]"))); +} + +TEST_F(SpmdPartitioningTest, + PartitionConvWithFeatureGroupCountRHSAlignWithLHS) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[16,801,1,1024] parameter(0) + %lhs.copy = f32[16,801,1,1024] copy(%lhs), + sharding={devices=[1,1,1,2]0,1} + %rhs = f32[5,1,1,1024] parameter(1) + %rhs.copy = f32[5,1,1,1024] copy(%rhs), + sharding={devices=[2,1,1,1]0,1} + ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy), + dim_labels=b01f_01io->b01f,feature_group_count=1024, + window={size=5x1 pad=2_2x0_0}, + sharding={devices=[1,1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[16,801,1,512]")); + auto rhs = AllOf(op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Reshape(), + op::Constant(), op::Constant(), op::Constant())), + op::Shape("f32[3,1,1,1024]")); + auto resharded_rhs = AllOf( + op::Slice(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(rhs))))), + op::Shape("f32[5,1,1,512]")); + EXPECT_THAT(root, AllOf(op::Convolution(lhs, resharded_rhs), + op::Shape("f32[16,801,1,512]"))); +} + +TEST_F(SpmdPartitioningTest, + PartitionConvWithFeatureGroupCountLHSAlignWithRHS) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[16,801,1,1024] parameter(0) + %lhs.copy = f32[16,801,1,1024] copy(%lhs), + sharding={devices=[1,2,1,1]0,1} + %rhs = f32[5,1,1,1024] parameter(1) + %rhs.copy = f32[5,1,1,1024] copy(%rhs), + sharding={devices=[1,1,1,2]0,1} + ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy), + dim_labels=b01f_01io->b01f,feature_group_count=1024, + window={size=5x1 pad=2_2x0_0}, + sharding={devices=[1,1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Constant(), + op::Reshape(), op::Constant(), op::Constant())), + op::Shape("f32[16,401,1,1024]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[5,1,1,512]")); + auto resharded_lhs = AllOf( + op::Slice(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(lhs))))), + op::Shape("f32[16,801,1,512]")); + EXPECT_THAT(root, AllOf(op::Convolution(resharded_lhs, rhs), + op::Shape("f32[16,801,1,512]"))); +} + +TEST_F(SpmdPartitioningTest, + PartitionConvWithFeatureGroupCountAlignOuputWithLHS) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = 
f32[16,801,1,1024] parameter(0) + %lhs.copy = f32[16,801,1,1024] copy(%lhs), + sharding={devices=[1,1,1,2]0,1} + %rhs = f32[5,1,1,1024] parameter(1) + %rhs.copy = f32[5,1,1,1024] copy(%rhs), + sharding={devices=[1,1,1,2]0,1} + ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy), + dim_labels=b01f_01io->b01f,feature_group_count=1024, + window={size=5x1 pad=2_2x0_0}, + sharding={devices=[2,1,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[16,801,1,512]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[5,1,1,512]")); + auto conv = AllOf(op::Convolution(lhs, rhs), op::Shape("f32[16,801,1,512]")); + EXPECT_THAT(root, + AllOf(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(conv)))), + op::Shape("f32[8,801,1,1024]"))); +} + +TEST_F(SpmdPartitioningTest, + PartitionConvGroupOnFeatureGroupCount_RHSPartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[16,801,1,1024] parameter(0) + %lhs.copy = f32[16,801,1,1024] copy(%lhs), + sharding={devices=[1,2,1,2]0,1,2,3} + %rhs = f32[5,1,1,1024] parameter(1) + %rhs.copy = f32[5,1,1,1024] copy(%rhs), + sharding={devices=[1,1,1,2,2]0,2,1,3 last_tile_dim_replicate} + ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy), + dim_labels=b01f_01io->b01f,feature_group_count=1024, + window={size=5x1 pad=2_2x0_0}, + sharding={devices=[1,2,1,2]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Constant(), + op::Reshape(), op::Constant(), op::Reshape())), + op::Shape("f32[16,401,1,512]")); + auto left_halo = AllOf(op::Shape("f32[16,2, 1, 512]"), + op::CollectivePermute(op::Slice(lhs))); + auto right_halo = AllOf(op::Shape("f32[16,2, 1, 512]"), + op::CollectivePermute(op::Slice(lhs))); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[5,1,1,512]")); + EXPECT_THAT( + root, + AllOf(op::Convolution( + op::Select(_, op::Concatenate(left_halo, lhs, right_halo), _), + rhs), + op::Shape("f32[16, 401, 1, 512]"))); +} + +TEST_F(SpmdPartitioningTest, + PartitionConvGroupOnFeatureGroupCount_RHSAlignWithOutput) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[16,801,1,1024] parameter(0) + %lhs.copy = f32[16,801,1,1024] copy(%lhs), + sharding={devices=[1,2,1,2]0,1,2,3} + %rhs = f32[5,1,1,1024] parameter(1), sharding={replicated} + ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs), + dim_labels=b01f_01io->b01f,feature_group_count=1024, + window={size=5x1 pad=2_2x0_0}, + sharding={devices=[1,2,1,2]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Constant(), + 
op::Reshape(), op::Constant(), op::Reshape())), + op::Shape("f32[16,401,1,512]")); + auto left_halo = AllOf(op::Shape("f32[16,2, 1, 512]"), + op::CollectivePermute(op::Slice(lhs))); + auto right_halo = AllOf(op::Shape("f32[16,2, 1, 512]"), + op::CollectivePermute(op::Slice(lhs))); + auto rhs = + AllOf(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape()), + op::Shape("f32[5,1,1,512]")); + EXPECT_THAT( + root, + AllOf(op::Convolution( + op::Select(_, op::Concatenate(left_halo, lhs, right_halo), _), + rhs), + op::Shape("f32[16, 401, 1, 512]"))); +} + +TEST_F(SpmdPartitioningTest, + PartitionConvGroupOnFeatureGroupCount_LHSAlignWithOutput) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[16,801,1,1024] parameter(0) + %lhs.copy = f32[16,801,1,1024] copy(%lhs), + sharding={devices=[2,1,1,1,2]0,1,2,3 last_tile_dim_replicate} + %rhs = f32[5,1,1,1024] parameter(1) + %rhs.copy = f32[5,1,1,1024] copy(%rhs), + sharding={devices=[1,1,1,2,2]0,2,1,3 last_tile_dim_replicate} + ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy), + dim_labels=b01f_01io->b01f,feature_group_count=1024, + window={size=5x1 pad=2_2x0_0}, + sharding={devices=[1,2,1,2]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(), + op::Constant(), op::Constant())), + op::Shape("f32[8,801,1,1024]")); + auto resharded_lhs = + AllOf(op::Reshape(op::Transpose(op::AllToAll(op::Reshape( + op::Pad(op::DynamicSlice(lhs, op::Constant(), op::Constant(), + op::Constant(), op::Reshape()), + op::Constant()))))), + op::Shape("f32[16,401,1,512]")); + auto left_halo = AllOf(op::Shape("f32[16,2, 1, 512]"), + op::CollectivePermute(op::Slice(resharded_lhs))); + auto right_halo = AllOf(op::Shape("f32[16,2, 1, 512]"), + op::CollectivePermute(op::Slice(resharded_lhs))); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[5,1,1,512]")); + EXPECT_THAT( + root, + AllOf( + op::Convolution( + op::Select( + _, op::Concatenate(left_halo, resharded_lhs, right_halo), _), + rhs), + op::Shape("f32[16, 401, 1, 512]"))); +} + +TEST_F(SpmdPartitioningTest, PartitionConvGroupOnBatchGroupCount) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[16,801,1,1024] parameter(0) + %lhs.copy = f32[16,801,1,1024] copy(%lhs), + sharding={devices=[1,2,1,2]0,1,2,3} + %rhs = f32[16,801,1,1024] parameter(1) + %rhs.copy = f32[16,801,1,1024] copy(%rhs), + sharding={devices=[1,2,1,2]0,1,2,3} + ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy), + dim_labels=f01b_i01o->01bf,batch_group_count=1024, + window={size=801x1 pad=2_2x0_0}, + sharding={devices=[1,1,1,2,2]0,1,2,3 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Select(_, + op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Constant(), + op::Reshape(), op::Constant(), op::Reshape())), + _), + op::Shape("f32[16,401,1,512]")); + auto left_halo = AllOf(op::Shape("f32[16,2, 1, 512]"), + op::CollectivePermute(op::Slice(lhs))); + auto 
right_halo = AllOf(op::Shape("f32[16,2, 1, 512]"), + op::CollectivePermute(op::Slice(lhs))); + auto rhs = AllOf(op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Constant(), + op::Reshape(), op::Constant(), op::Reshape())), + op::Shape("f32[16,401,1,512]")); + auto conv = AllOf(op::Convolution(op::Concatenate(left_halo, lhs, right_halo), + op::Select(_, rhs, _)), + op::Shape("f32[5,1,1,512]")); + EXPECT_THAT(root, AllOf(op::CollectivePermute(op::AllReduce(conv)), + op::Shape("f32[5,1,1,512]"))); +} + +TEST_F(SpmdPartitioningTest, + PartitionConvWithFeatureGroupCountAlignOuputWithRHS) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[16,801,1,1024] parameter(0) + %lhs.copy = f32[16,801,1,1024] copy(%lhs), + sharding={devices=[1,2,1,1]0,1} + %rhs = f32[5,1,1,1024] parameter(1) + %rhs.copy = f32[5,1,1,1024] copy(%rhs), + sharding={devices=[1,1,1,2]0,1} + ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy), + dim_labels=b01f_01io->b01f,feature_group_count=1024, + window={size=5x1 pad=2_2x0_0}, + sharding={devices=[2,1,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Constant(), + op::Reshape(), op::Constant(), op::Constant())), + op::Shape("f32[16,401,1,1024]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[5,1,1,512]")); + auto resharded_lhs = AllOf( + op::Slice(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(lhs))))), + op::Shape("f32[16,801,1,512]")); + auto conv = AllOf(op::Convolution(resharded_lhs, rhs), + op::Shape("f32[16,801,1,512]")); + EXPECT_THAT(root, + AllOf(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(conv)))), + op::Shape("f32[8,801,1,1024]"))); +} + +TEST_F(SpmdPartitioningTest, PartitionConvWithFeatureGroupCountBackProp) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[16,801,1,1024] parameter(0) + %lhs.copy = f32[16,801,1,1024] copy(%lhs), + sharding={devices=[1,1,1,2]0,1} + %rhs = f32[5,1,1024,1] parameter(1) + %rhs.copy = f32[5,1,1024,1] copy(%rhs), + sharding={devices=[1,1,2,1]0,1} + ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy), + dim_labels=b01f_01oi->b01f,feature_group_count=1024, + window={size=5x1 pad=2_2x0_0 rhs_reversal=1x1}, + sharding={devices=[1,1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[16,801,1,512]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Reshape(), op::Constant())), + op::Shape("f32[5,1,512,1]")); + EXPECT_THAT(root, + AllOf(op::Convolution(lhs, rhs), op::Shape("f32[16,801,1,512]"))); +} + +TEST_F(SpmdPartitioningTest, NoReshardOnBroadcastDims) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[2,3] parameter(0) + %param1 = f32[2,3,20] parameter(1) + %br0 = f32[20,2,20,3,20] broadcast(%param0), dimensions={1,3}, 
sharding={devices=[2,1,2,1,2]0,1,2,3,4,5,6,7} + %br1 = f32[20,2,20,3,20] broadcast(%param1), dimensions={1,3,4}, sharding={devices=[2,1,2,1,2]0,1,2,3,4,5,6,7} + %add = f32[20,2,20,3,20] add(%br0, %br1), sharding={devices=[2,1,2,1,2]0,1,2,3,4,5,6,7} + %reshape = f32[10,4,10,6,20] reshape(%br0), sharding={devices=[2,1,2,1,2]0,1,2,3,4,5,6,7} + %transpose = f32[2,3,20,20,20] transpose(%br0), dimensions={1,3,0,2,4}, sharding={devices=[1,1,2,2,2]0,1,2,3,4,5,6,7} + %copy_add0 = f32[20,2,20,3,20] copy(%add), sharding={devices=[2,1,2,1,2]6,7,2,3,4,5,0,1} + %copy_add1 = f32[20,2,20,3,20] copy(%add), sharding={devices=[2,1,2,1,2]7,6,3,2,5,4,0,1} + %copy_reshape = f32[10,4,10,6,20] copy(%reshape), sharding={devices=[2,1,2,1,2]7,6,3,2,5,4,0,1} + %copy_transpose = f32[2,3,20,20,20] copy(%transpose), sharding={devices=[1,1,2,2,2]7,6,3,2,5,4,0,1} + ROOT %tuple = (f32[20,2,20,3,20], f32[20,2,20,3,20], f32[10,4,10,6,20], f32[2,3,20,20,20]) + tuple(%copy_add0, %copy_add1, %copy_reshape, %copy_transpose), + sharding={{devices=[2,1,2,1,2]6,7,2,3,4,5,0,1},{devices=[2,1,2,1,2]7,6,3,2,5,4,0,1},{devices=[2,1,2,1,2]7,6,3,2,5,4,0,1},{devices=[1,1,2,2,2]7,6,3,2,5,4,0,1}} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + // Reshard on copy_add0 only happens on broadcast dims, can be skipped. + auto copy_add0 = + op::Copy(op::Copy(op::Add(op::Broadcast(_), op::Broadcast(_)))); + // Reshard on copy_add1 also happens on non-broadcast dims. + auto copy_add1 = op::Copy( + op::CollectivePermute(op::Add(op::Broadcast(_), op::Broadcast(_)))); + // Reshard on copy_reshape only happens on broadcast dims, can be skipped. + auto copy_reshape = op::Copy(op::Copy(op::Reshape(op::Broadcast(_)))); + // Reshard on copy_transpose only happens on broadcast dims, can be skipped. 
+ auto copy_transpose = op::Copy(op::Copy(op::Transpose(op::Broadcast(_)))); + EXPECT_THAT(root, + op::Tuple(copy_add0, copy_add1, copy_reshape, copy_transpose)); +} + +TEST_F(SpmdPartitioningTest, + ConvolutionFilterIFOFPartitionedInputPartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,112,112,12] parameter(0) + %lhs.copy = f32[128,112,112,12] copy(f32[128,112,112,12] %lhs), + sharding={devices=[1,1,1,2,2]0,1,2,3 last_tile_dim_replicate} + %rhs = f32[7,7,12,64] parameter(1) + %rhs.copy = f32[7,7,12,64] copy(f32[7,7,12,64] %rhs), + sharding={devices=[1,1,2,2]0,1,2,3} + ROOT %conv = f32[128,56,56,64] convolution( + f32[128,112,112,12] %lhs.copy, + f32[7,7,12,64] %rhs.copy), + window={size=7x7 stride=2x2 pad=3_3x3_3}, + dim_labels=b01f_01io->b01f, + sharding={devices=[1,1,1,2,2]0,1,2,3 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[128,112,112,6]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Reshape(), op::Reshape())), + op::Shape("f32[7,7,6,32]")); + + EXPECT_THAT( + root, + AllOf(op::CollectivePermute(op::AllReduce(op::Convolution(lhs, rhs))), + op::Shape("f32[128,56,56,32]"))); +} + +TEST_F(SpmdPartitioningTest, + ConvolutionInputKernelNonContractingDimPartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,56,56,256] parameter(0) + %lhs.copy = f32[128,56,56,256] copy(%lhs), + sharding={devices=[1,1,1,2,2]0,1,2,3 last_tile_dim_replicate} + %rhs = f32[128,28,28,512] parameter(1) + %rhs.copy = f32[128,28,28,512] copy(%rhs), + sharding={devices=[1,1,1,2,2]0,1,2,3 last_tile_dim_replicate} + ROOT %conv = f32[1,1,256,512] convolution(%lhs.copy, %rhs.copy), + window={size=28x28 pad=0_-1x0_-1 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, + sharding={devices=[1,1,2,2]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[128,56,56,128]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[128,28,28,256]")); + + EXPECT_THAT(root, AllOf(op::Convolution(lhs, op::CollectivePermute(rhs)), + op::Shape("f32[1,1,128,256]"))); +} + +TEST_F(SpmdPartitioningTest, + ConvolutionInputSpatialDimAndFeatureDimParttiioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[8,210,210,12] parameter(0) + %lhs.copy = f32[8,210,210,12] copy(f32[8,210,210,12] %lhs), + sharding={devices=[1,2,1,2]0,1,2,3} + %rhs = f32[3,3,12,32] parameter(1) + %rhs.copy = f32[3,3,12,32] copy(f32[3,3,12,32] %rhs), + sharding={devices=[1,1,2,1,2]0,1,2,3 last_tile_dim_replicate} + ROOT %conv = f32[8,210,210,32] convolution( + f32[8,210,210,12] %lhs.copy, + f32[3,3,12,32] %rhs.copy), + window={size=3x3 pad=1_1x1_1}, + dim_labels=b01f_01io->b01f, + sharding={devices=[1,2,1,1,2]0,1,2,3 last_tile_dim_replicate} +})"; + 
TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Reshape())), + op::Shape("f32[8,105,210,6]")); + auto left_halo = + AllOf(op::CollectivePermute(op::Slice(lhs)), op::Shape("f32[8,1,210,6]")); + auto right_halo = + AllOf(op::CollectivePermute(op::Slice(lhs)), op::Shape("f32[8,1,210,6]")); + auto exchanged_lhs = AllOf( + op::Select(op::And(_, _), op::Concatenate(left_halo, lhs, right_halo), + op::Broadcast(_)), + op::Shape("f32[8,107,210,6]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Reshape(), op::Constant())), + op::Shape("f32[3,3,6,32]")); + EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution( + exchanged_lhs, op::CollectivePermute(rhs))), + op::Shape("f32[8,105,210,32]"))); +} + } // namespace } // namespace spmd } // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc index 0edbd4f2b8d..f3f3a95ea0a 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc @@ -576,8 +576,8 @@ absl::optional PadFromPartialReplicateShape( int64 max_right_halo_size = right_halo_size_function.MaxInRange(0, src_shard_count - 1); pad_config.mutable_dimensions(dim)->set_edge_padding_high(std::max( - 0LL, padded_dst_shape.dimensions(dim) - - padded_src_shape.dimensions(dim) - max_right_halo_size)); + int64{0}, padded_dst_shape.dimensions(dim) - + padded_src_shape.dimensions(dim) - max_right_halo_size)); auto padded_concat_shape = ShapeInference::InferPadShape( concat->shape(), zero->shape(), pad_config) .ValueOrDie(); diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h index f6f15481b55..4fc193d9622 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h @@ -362,8 +362,8 @@ absl::optional PadFromPartialReplicateShape( // dimensions by dynamic slice. // For example, if partial_sharding is // {devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} -// Target tile dims is {2, 2}, the returned compatible sharding will be -// sharding={devices=[1,2,2]0,2,1,3 last_tile_dim_replicate}. +// Target sharding is {devices=[2,2]0,1,2,3}, the returned compatible sharding +// will be sharding={devices=[2,2]0,2,1,3}. // If patial replicate sharding is not partial replicate or can't reshard to // target_tile_dims by dynamic slice, return absl::nullopt. // If target_sharding is already compatible, returns it. 
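Editor's illustrative note on the spmd_partitioner_util.h comment change above (not part of the patch): the sketch below is a minimal, standalone C++ check of why the compatible sharding for the documented example is {devices=[2,2]0,2,1,3} rather than the target's own device order {devices=[2,2]0,1,2,3}. The reordered device list lets every device keep the dim-1 shard it already holds under the partial sharding, so resharding only needs a local dynamic-slice along dim 0. All names in the sketch are hypothetical.

#include <array>
#include <cassert>

int main() {
  // Partial sharding {devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}:
  // data dim 1 is split in two and the last (size-2) tile dim is replication.
  // Reading the row-major tile assignment gives each device's dim-1 shard.
  std::array<int, 4> dim1_shard_before = {/*dev0=*/0, /*dev1=*/0,
                                          /*dev2=*/1, /*dev3=*/1};

  // Compatible sharding {devices=[2,2]0,2,1,3}: a (2,2) tile assignment
  // stored row-major as [0,2,1,3]; the column index is the dim-1 shard.
  const std::array<int, 4> compatible = {0, 2, 1, 3};
  std::array<int, 4> dim1_shard_after = {0, 0, 0, 0};
  for (int pos = 0; pos < 4; ++pos) {
    dim1_shard_after[compatible[pos]] = pos % 2;
  }

  // Every device keeps the dim-1 shard it already owns, so only a local
  // dynamic-slice on dim 0 is needed. With the target's own order [0,1,2,3],
  // devices 1 and 2 would have to exchange dim-1 shards across devices.
  for (int device = 0; device < 4; ++device) {
    assert(dim1_shard_after[device] == dim1_shard_before[device]);
  }
  return 0;
}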
diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index 0fd64209152..913bfed926a 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -169,8 +169,7 @@ Status TransferManager::TransferArrayToDeviceAsync( "%d < %d", dest.size(), GetByteSizeRequirement(on_device_shape)); } - ShapedBuffer shaped_buffer(/*on_host_shape=*/literal.shape(), on_device_shape, - stream->parent()->platform(), + ShapedBuffer shaped_buffer(on_device_shape, stream->parent()->platform(), stream->parent()->device_ordinal()); shaped_buffer.set_buffer(dest, /*index=*/{}); return TransferLiteralToDevice(stream, literal, shaped_buffer, @@ -194,8 +193,7 @@ void TransferManager::TransferArrayFromDevice( "%d < %d", source.size(), GetByteSizeRequirement(shape))); } - ShapedBuffer shaped_buffer(/*on_host_shape=*/shape, shape, - stream->parent()->platform(), + ShapedBuffer shaped_buffer(shape, stream->parent()->platform(), stream->parent()->device_ordinal()); shaped_buffer.set_buffer(source, /*index=*/{}); return TransferLiteralFromDevice(stream, shaped_buffer, literal, @@ -406,8 +404,8 @@ StatusOr<ScopedShapedBuffer> TransferManager::AllocateScopedShapedBuffer( Shape on_device_shape = HostShapeToDeviceShape(on_host_shape); TF_RET_CHECK(LayoutUtil::HasLayout(on_device_shape)); - ScopedShapedBuffer shaped_buffer(on_host_shape, std::move(on_device_shape), - allocator, device_ordinal); + ScopedShapedBuffer shaped_buffer(std::move(on_device_shape), allocator, + device_ordinal); // Allocate an appropriate sized buffer for each element in the shape // including the tuple pointer arrays. diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index c0670d26eee..c49d7d899e7 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -51,7 +51,11 @@ class TransferManager { // pre-allocated by the host, e.g. TransferLiteralToDevice, without the user // needing to consider device-specific behaviors. virtual Shape HostShapeToDeviceShape(const Shape& host_shape) const { - return host_shape; + // Strips off any preexisting tiling or memory space information. + // TODO(phawkins): fix clients not to include tiling or memory space + // information in shapes passed to this function and turn this into an + // assertion. + return ShapeUtil::DeviceShapeToHostShape(host_shape); } // Base class for specifying platform specific transfer metadata that can be @@ -189,6 +193,7 @@ class TransferManager { // shapes, and returns static shapes with dynamic shapes updated. // The shape of the buffer also has to be compatible with the host shape and // device shape. + // TODO(b/170310047): remove host_shape.
virtual Status ReadDynamicShapes(se::Stream* stream, ShapedBuffer* device_buffer, Shape* host_shape, Shape* device_shape); diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index c9799453939..614dfc4ffe6 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -35,17 +35,46 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoDot( const HloInstruction& dot, const TransposeFolding::TransposableGemmOperandsFn& transposable_gemm_operands) { - if (HloOpcode::kDot != dot.opcode() || - dot.dot_dimension_numbers().lhs_batch_dimensions_size() != 0) { + if (HloOpcode::kDot != dot.opcode()) { return {}; } + if (!absl::c_equal(dot.dot_dimension_numbers().lhs_batch_dimensions(), + dot.dot_dimension_numbers().rhs_batch_dimensions())) { + return {}; + } + + int64 num_batch_dims = + dot.dot_dimension_numbers().lhs_batch_dimensions_size(); + int64 expected_rank = 2 + num_batch_dims; + auto is_r2_transpose = [&](const HloInstruction& transpose) { + if (transpose.opcode() != HloOpcode::kTranspose) { + return false; + } + const auto& transpose_dims = transpose.dimensions(); + if (transpose_dims.size() != expected_rank) { + return false; + } + + // Check that the transpose doesn't touch any batch dimensions, but does + // transpose the non-batch ones. + for (int64 i = 0; i != expected_rank; ++i) { + bool is_batch = absl::c_linear_search( + dot.dot_dimension_numbers().lhs_batch_dimensions(), + transpose_dims[i]); + if ((transpose_dims[i] == i) != is_batch) { + return false; + } + } + return true; + }; + TransposeFolding::OperandIndices operand_set; for (int64 i = 0; i < dot.operand_count(); ++i) { auto& operand = *dot.operand(i); - if (operand.IsRank2Transpose()) { + if (is_r2_transpose(operand)) { operand_set.push_back(i); - } else if (operand.shape().rank() != 2) { + } else if (operand.shape().rank() != expected_rank) { return {}; } } @@ -84,25 +113,25 @@ Status FoldTransposeIntoDot(InstructionOperandsPair pair) { HloInstruction* new_lhs = dot->mutable_operand(0); HloInstruction* new_rhs = dot->mutable_operand(1); - CHECK_EQ(new_dim_numbers.lhs_batch_dimensions_size(), 0); - CHECK_EQ(new_dim_numbers.rhs_batch_dimensions_size(), 0); CHECK_EQ(new_dim_numbers.lhs_contracting_dimensions_size(), 1); CHECK_EQ(new_dim_numbers.rhs_contracting_dimensions_size(), 1); for (int64 operand_index : pair.second) { - // We've checked that there aren't any batch dimensions and that the inputs - // are rank 2, and shape inference guarantees that there is exactly one - // contracting dimension. + // We checked that the batch dimensions are not touched by the transpose, + // and shape inference guarantees that there is exactly one contracting + // dimension. 
if (operand_index == 0) { CHECK_EQ(new_lhs->opcode(), HloOpcode::kTranspose); new_dim_numbers.set_lhs_contracting_dimensions( - 0, 1 - new_dim_numbers.lhs_contracting_dimensions(0)); + 0, + new_lhs->dimensions(new_dim_numbers.lhs_contracting_dimensions(0))); new_lhs = new_lhs->mutable_operand(0); } else { CHECK_EQ(operand_index, 1); CHECK_EQ(new_rhs->opcode(), HloOpcode::kTranspose); new_dim_numbers.set_rhs_contracting_dimensions( - 0, 1 - new_dim_numbers.rhs_contracting_dimensions(0)); + 0, + new_rhs->dimensions(new_dim_numbers.rhs_contracting_dimensions(0))); new_rhs = new_rhs->mutable_operand(0); } } diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 8a2112c87dc..3fe69d22e9c 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -42,7 +42,7 @@ namespace { class TransposeFoldingTest : public HloTestBase { protected: - void FoldTranspose(HloModule* module) { + bool FoldTranspose(HloModule* module) { TransposeFolding transpose_folding( [](const HloInstruction& dot, const TransposeFolding::OperandIndices& candidate_operands) { @@ -52,7 +52,9 @@ class TransposeFoldingTest : public HloTestBase { const TransposeFolding::OperandIndices& candidate_operands) { return candidate_operands; }); - EXPECT_IS_OK(transpose_folding.Run(module).status()); + auto folded = transpose_folding.Run(module); + EXPECT_IS_OK(folded.status()); + return *folded; } }; @@ -465,5 +467,81 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { new_conv->convolution_dimension_numbers().output_spatial_dimensions(1)); } +TEST_F(TransposeFoldingTest, FoldBatchDotTranspose) { + string hlo_string = R"( +HloModule FoldBatchDotTranspose + +ENTRY entry_computation { + x = f32[7,7,2,3]{3,2,1,0} parameter(0) + y = f32[7,7,2,3]{3,2,1,0} parameter(1) + transpose = f32[7,7,3,2]{3,2,1,0} transpose(y), dimensions={0,1,3,2} + ROOT dot = f32[7,7,2,2]{3,2,1,0} dot(x, transpose), lhs_contracting_dims={3}, + rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_TRUE(FoldTranspose(module.get())); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Dot(op::Parameter(0), op::Parameter(1), + /*lhs_contracting_dim=*/3, /*rhs_contracting_dim=*/3)); +} + +TEST_F(TransposeFoldingTest, NoFoldBatchDotTransposeBatch) { + string hlo_string = R"( +HloModule NoFoldBatchDotTransposeBatch + +ENTRY entry_computation { + x = f32[7,7,2,3]{3,2,1,0} parameter(0) + y = f32[7,7,2,3]{3,2,1,0} parameter(1) + transpose = f32[7,7,3,2]{3,2,1,0} transpose(y), dimensions={1,0,3,2} + ROOT dot = f32[7,7,2,2]{3,2,1,0} dot(x, transpose), lhs_contracting_dims={3}, + rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + EXPECT_FALSE(FoldTranspose(module.get())); +} + +TEST_F(TransposeFoldingTest, FoldBatchDotTransposeNonContiguousBatch) { + string hlo_string = R"( +HloModule FoldBatchDotTransposeNonContiguousBatch + +ENTRY entry_computation { + x = f32[7,2,7,3]{3,2,1,0} parameter(0) + y = f32[7,2,7,3]{3,2,1,0} parameter(1) + transpose = f32[7,3,7,2]{3,2,1,0} transpose(y), dimensions={0,3,2,1} + ROOT dot = f32[7,7,2,2]{3,2,1,0} dot(x, transpose), lhs_contracting_dims={3}, + rhs_contracting_dims={1}, lhs_batch_dims={0,2}, rhs_batch_dims={0,2} +} +)"; + 
TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_TRUE(FoldTranspose(module.get())); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Dot(op::Parameter(0), op::Parameter(1), + /*lhs_contracting_dim=*/3, /*rhs_contracting_dim=*/3)); +} + +TEST_F(TransposeFoldingTest, NoFoldBatchDotTransposeIdentity) { + string hlo_string = R"( +HloModule NoFoldBatchDotTransposeIdentity + +ENTRY entry_computation { + x = f32[7,7,2,3]{3,2,1,0} parameter(0) + y = f32[7,7,3,2]{3,2,1,0} parameter(1) + transpose = f32[7,7,3,2]{3,2,1,0} transpose(y), dimensions={0,1,2,3} + ROOT dot = f32[7,7,2,2]{3,2,1,0} dot(x, transpose), lhs_contracting_dims={3}, + rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + EXPECT_FALSE(FoldTranspose(module.get())); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/triangular_solve_expander.cc b/tensorflow/compiler/xla/service/triangular_solve_expander.cc index d54eb9e78c3..4015c69e3e2 100644 --- a/tensorflow/compiler/xla/service/triangular_solve_expander.cc +++ b/tensorflow/compiler/xla/service/triangular_solve_expander.cc @@ -89,16 +89,23 @@ XlaOp DiagonalBlocks(XlaOp a, int64 block_size) { // The last block might be smaller than the block size, // so we will need to pad it if (n % block_size != 0) { - // Pad with zeros + // Pad with identity matrix. auto last_blocks = SliceInMinorDims(a, {n - n % block_size, n - n % block_size}, {n, n}); PaddingConfig config = MakeNoPaddingConfig(ndims); int64 padding = block_size - n % block_size; - config.mutable_dimensions(ndims - 1)->set_edge_padding_high(padding); config.mutable_dimensions(ndims - 2)->set_edge_padding_high(padding); last_blocks = Pad(last_blocks, Zero(builder, shape.element_type()), config); + auto eye = + IdentityMatrix(builder, shape.element_type(), padding, padding); + config = MakeNoPaddingConfig(ndims); + config.mutable_dimensions(ndims - 2)->set_edge_padding_low(n % + block_size); + eye = Pad(eye, Zero(builder, shape.element_type()), config); + last_blocks = ConcatInDim(builder, {last_blocks, eye}, ndims - 1); + // Add a singleton dimension // i.e. [..., block_size, block_size] -> [..., 1, block_size, block_size] TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(last_blocks)); @@ -121,134 +128,6 @@ XlaOp DiagonalBlocks(XlaOp a, int64 block_size) { }); } -XlaOp InvertDiagonalBlocks(XlaOp diag_blocks, bool lower, bool transpose_a, - bool conjugate_a, - PrecisionConfig::Precision precision) { - XlaBuilder* builder = diag_blocks.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { - // Input is a batch of square lower triangular square matrices. Its shape is - // (..., size, size). We resize this to (num_blocks, size, size). 
- TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(diag_blocks)); - int64 block_size = ShapeUtil::GetDimension(shape, -1); - int64 num_blocks = ShapeUtil::ElementsIn(shape) / - tensorflow::MathUtil::IPow(block_size, 2); - diag_blocks = Reshape(diag_blocks, {num_blocks, block_size, block_size}); - - // The input must be triangular because we rely on that when doing - // multiplications later on - diag_blocks = Triangle(diag_blocks, /*lower=*/lower); - - // Rescale blocks to be unit triangular, but avoid dividing by - // zero (which can happen if the last block was padded) otherwise it will - // introduce nans which will propagate - auto diags = GetMatrixDiagonal(diag_blocks); - auto ones = FullLike(diags, 1); - diags = Select(Eq(diags, Zero(builder, shape.element_type())), ones, diags); - auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2}); - - // We can now use the fact that for an upper triangular matrix - // [[L11, 0], [L21, L22]], given the inverses L11' and L22', we have - // L22' = -L22' * L21 * L11'. In our case, L21 is a vector and our blocks - // have been rescaled to be unit triangular, so L22 = L22' = 1. - - // Initialize the output matrix with -1s on the diagonal. We use -1 instead - // of 1 because we cannot do matrix-vector multiplies with variable shapes - // inside of a loop, or do irregularly shaped in-place updates. Hence, - // L21 <- -L22 * L21 * L11 cannot be done naively. Instead, we update the - // entire row i.e. we calculate - // [L21 L22 0] <- -[L21 L22 0] @ diag_blocks([L11', -I, -I]) - // which means [L21 L22 0] <- [-L21 * L11', L22, 0]. - auto identity = - IdentityMatrix(builder, shape.element_type(), block_size, block_size); - auto neg_identity = -identity; - - // The first or last diagonal element should be set to 1 instead of -1 - // though, since we never update it - auto pos_one = Reshape(One(builder, shape.element_type()), {1, 1}); - auto start_index = ConstantR0(builder, (lower) ? 0 : block_size - 1); - auto output_block = - DynamicUpdateSlice(neg_identity, pos_one, - /*start_indices=*/{start_index, start_index}); - - // Broadcast diag([1, -1, -1, ...]) to every block - XlaOp output = Broadcast(output_block, - /*broadcast_sizes=*/{num_blocks}); - - // Now we construct a loop that performs matrix-vector multiplications - // inverting the blocks one row at a time - std::vector tuple_shapes = { - // The loop iteration counter is a scalar, incremented each iteration. - ShapeUtil::MakeShape(S32, {}), - // The output has the shape of A, with one row updated each iteration. - ShapeUtil::MakeShape(shape.element_type(), - {num_blocks, block_size, block_size}), - // The input is a loop invariant. - ShapeUtil::MakeShape(shape.element_type(), - {num_blocks, block_size, block_size})}; - Shape tuple_shape = ShapeUtil::MakeTupleShape(tuple_shapes); - - auto init_i = One(builder, S32); - auto init = Tuple(builder, {init_i, output, scaled_diag_blocks}); - - // Construct the loop condition function. - std::unique_ptr condb = - builder->CreateSubBuilder("InvertDiagCond"); - { - auto i = GetTupleElement( - Parameter(condb.get(), 0, tuple_shape, "InvertDiagCondTuple"), 0); - Lt(i, ConstantR0(condb.get(), block_size)); - } - TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); - - // Construct the loop body function. 
- std::unique_ptr bodyb = - builder->CreateSubBuilder("InvertDiagBody"); - { - auto input_tuple = - Parameter(bodyb.get(), 0, tuple_shape, "InvertDiagBodyTuple"); - - auto i = GetTupleElement(input_tuple, 0); - auto body_out = GetTupleElement(input_tuple, 1); - auto body_input = GetTupleElement(input_tuple, 2); - - auto zero = ConstantR0(bodyb.get(), 0); - auto j = (lower) ? i : ScalarLike(i, block_size - 1) - i; - auto input_row = - DynamicSlice(body_input, {zero, j, zero}, - /*slice_sizes=*/{num_blocks, 1, block_size}); - - // We want -L21 L11^{-1} - DotDimensionNumbers dnums; - dnums.add_lhs_batch_dimensions(0); - dnums.add_rhs_batch_dimensions(0); - dnums.add_lhs_contracting_dimensions(2); - dnums.add_rhs_contracting_dimensions(1); - PrecisionConfig precision_proto; - precision_proto.add_operand_precision(precision); - precision_proto.add_operand_precision(precision); - auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto); - - body_out = DynamicUpdateSlice(body_out, update, {zero, j, zero}); - - auto next_i = i + ScalarLike(i, 1); - Tuple(bodyb.get(), {next_i, body_out, body_input}); - } - TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); - - // Construct the While loop and return the result, - // return while_loop(cond_fun, body_fun, init)[1] - auto invert_while = While(cond, body, init); - auto inv_diag_blocks = GetTupleElement(invert_while, 1); - - // Undo the scaling - inv_diag_blocks = Div(inv_diag_blocks, diags, - /*broadcast_dimensions=*/{0, 1}); - - // Reshape back to original batch major dimensions - return Reshape(inv_diag_blocks, AsInt64Slice(shape.dimensions())); - }); -} - XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks, bool left_side, bool lower, bool transpose_a, bool conjugate_a, @@ -357,10 +236,140 @@ XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks, }); } -XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, - bool transpose_a, bool conjugate_a, - bool unit_diagonal, int64 block_size, - PrecisionConfig::Precision precision) { +} // namespace + +XlaOp TriangularSolveExpander::InvertDiagonalBlocks( + XlaOp diag_blocks, bool lower_triangular, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = diag_blocks.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + // Input is a batch of square lower triangular square matrices. Its shape is + // (..., size, size). We resize this to (num_blocks, size, size). + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(diag_blocks)); + int64 block_size = ShapeUtil::GetDimension(shape, -1); + int64 num_blocks = ShapeUtil::ElementsIn(shape) / + tensorflow::MathUtil::IPow(block_size, 2); + diag_blocks = Reshape(diag_blocks, {num_blocks, block_size, block_size}); + + // The input must be triangular because we rely on that when doing + // multiplications later on + diag_blocks = Triangle(diag_blocks, /*lower=*/lower_triangular); + + // Rescale blocks to be unit triangular, but avoid dividing by + // zero (which can happen if the last block was padded) otherwise it will + // introduce nans which will propagate + auto diags = GetMatrixDiagonal(diag_blocks); + auto ones = FullLike(diags, 1); + diags = Select(Eq(diags, Zero(builder, shape.element_type())), ones, diags); + auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2}); + + // We can now use the fact that for an upper triangular matrix + // [[L11, 0], [L21, L22]], given the inverses L11' and L22', we have + // L22' = -L22' * L21 * L11'. 
In our case, L21 is a vector and our blocks + // have been rescaled to be unit triangular, so L22 = L22' = 1. + + // Initialize the output matrix with -1s on the diagonal. We use -1 instead + // of 1 because we cannot do matrix-vector multiplies with variable shapes + // inside of a loop, or do irregularly shaped in-place updates. Hence, + // L21 <- -L22 * L21 * L11 cannot be done naively. Instead, we update the + // entire row i.e. we calculate + // [L21 L22 0] <- -[L21 L22 0] @ diag_blocks([L11', -I, -I]) + // which means [L21 L22 0] <- [-L21 * L11', L22, 0]. + auto identity = + IdentityMatrix(builder, shape.element_type(), block_size, block_size); + auto neg_identity = -identity; + + // The first or last diagonal element should be set to 1 instead of -1 + // though, since we never update it + auto pos_one = Reshape(One(builder, shape.element_type()), {1, 1}); + auto start_index = + ConstantR0(builder, lower_triangular ? 0 : block_size - 1); + auto output_block = + DynamicUpdateSlice(neg_identity, pos_one, + /*start_indices=*/{start_index, start_index}); + + // Broadcast diag([1, -1, -1, ...]) to every block + XlaOp output = Broadcast(output_block, + /*broadcast_sizes=*/{num_blocks}); + + // Now we construct a loop that performs matrix-vector multiplications + // inverting the blocks one row at a time + std::vector tuple_shapes = { + // The loop iteration counter is a scalar, incremented each iteration. + ShapeUtil::MakeShape(S32, {}), + // The output has the shape of A, with one row updated each iteration. + ShapeUtil::MakeShape(shape.element_type(), + {num_blocks, block_size, block_size}), + // The input is a loop invariant. + ShapeUtil::MakeShape(shape.element_type(), + {num_blocks, block_size, block_size})}; + Shape tuple_shape = ShapeUtil::MakeTupleShape(tuple_shapes); + + auto init_i = One(builder, S32); + auto init = Tuple(builder, {init_i, output, scaled_diag_blocks}); + + // Construct the loop condition function. + std::unique_ptr condb = + builder->CreateSubBuilder("InvertDiagCond"); + { + auto i = GetTupleElement( + Parameter(condb.get(), 0, tuple_shape, "InvertDiagCondTuple"), 0); + Lt(i, ConstantR0(condb.get(), block_size)); + } + TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); + + // Construct the loop body function. + std::unique_ptr bodyb = + builder->CreateSubBuilder("InvertDiagBody"); + { + auto input_tuple = + Parameter(bodyb.get(), 0, tuple_shape, "InvertDiagBodyTuple"); + + auto i = GetTupleElement(input_tuple, 0); + auto body_out = GetTupleElement(input_tuple, 1); + auto body_input = GetTupleElement(input_tuple, 2); + + auto zero = ConstantR0(bodyb.get(), 0); + auto j = lower_triangular ? 
i : ScalarLike(i, block_size - 1) - i; + auto input_row = + DynamicSlice(body_input, {zero, j, zero}, + /*slice_sizes=*/{num_blocks, 1, block_size}); + + // We want -L21 L11^{-1} + DotDimensionNumbers dnums; + dnums.add_lhs_batch_dimensions(0); + dnums.add_rhs_batch_dimensions(0); + dnums.add_lhs_contracting_dimensions(2); + dnums.add_rhs_contracting_dimensions(1); + PrecisionConfig precision_proto; + precision_proto.add_operand_precision(precision); + precision_proto.add_operand_precision(precision); + auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto); + + body_out = DynamicUpdateSlice(body_out, update, {zero, j, zero}); + + auto next_i = i + ScalarLike(i, 1); + Tuple(bodyb.get(), {next_i, body_out, body_input}); + } + TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); + + // Construct the While loop and return the result, + // return while_loop(cond_fun, body_fun, init)[1] + auto invert_while = While(cond, body, init); + auto inv_diag_blocks = GetTupleElement(invert_while, 1); + // Undo the scaling + inv_diag_blocks = Div(inv_diag_blocks, diags, + /*broadcast_dimensions=*/{0, 1}); + + // Reshape back to original batch major dimensions + return Reshape(inv_diag_blocks, AsInt64Slice(shape.dimensions())); + }); +} + +XlaOp TriangularSolveExpander::BuildTriangularSolve( + XlaOp a, XlaOp b, bool left_side, bool lower, bool transpose_a, + bool conjugate_a, bool unit_diagonal, int64 block_size, + PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); @@ -422,6 +431,11 @@ XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, return b; } + // Degenerate case: 1x1 matrices. + if (ShapeUtil::GetDimension(a_shape, -1) == 1) { + return unit_diagonal ? b : Div(b, MaybeConjugate(a, conjugate_a)); + } + // TODO(phawkins): consider pushing triangle masking into // InvertDiagonalBlocks. if (unit_diagonal) { @@ -440,8 +454,7 @@ XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, auto diag_blocks = DiagonalBlocks(a, block_size); // We invert these blocks in parallel using batched matrix-vector products - auto inv_diag_blocks = InvertDiagonalBlocks(diag_blocks, lower, transpose_a, - conjugate_a, precision); + auto inv_diag_blocks = InvertDiagonalBlocks(diag_blocks, lower, precision); // We now find the solution using GEMMs auto x = @@ -452,8 +465,6 @@ XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, }); } -} // namespace - TriangularSolveExpander::TriangularSolveExpander(int64 block_size) : block_size_(block_size) {} diff --git a/tensorflow/compiler/xla/service/triangular_solve_expander.h b/tensorflow/compiler/xla/service/triangular_solve_expander.h index 362e8557229..3f9e58a3246 100644 --- a/tensorflow/compiler/xla/service/triangular_solve_expander.h +++ b/tensorflow/compiler/xla/service/triangular_solve_expander.h @@ -17,6 +17,7 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_XLA_SERVICE_TRIANGULAR_SOLVE_EXPANDER_H_ #include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/service/op_expander_pass.h" namespace xla { @@ -35,6 +36,14 @@ class TriangularSolveExpander : public OpExpanderPass { StatusOr ExpandInstruction( HloInstruction* instruction) override; + virtual XlaOp InvertDiagonalBlocks(XlaOp diag_blocks, bool lower_triangular, + PrecisionConfig::Precision precision); + + XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, + bool transpose_a, bool conjugate_a, + bool unit_diagonal, int64 block_size, + PrecisionConfig::Precision precision); + private: // Block size for BuildTriangularSolve const int64 block_size_; diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index c80123bcd50..785fdecbfa0 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -37,23 +37,15 @@ namespace m = match; using absl::optional; using hlo_query::ContainsInstrWithOpcode; -// Tries to remove elements in a while loop's tuple that aren't used within the -// loop. -// -// Specifically, if a loop is tuple-shaped, and there exists some element of -// that tuple that is not used by the loop condition and is not used by the loop -// body except to pass it to the next iteration of the loop, then we can remove -// that element from the loop's tuples. -static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { - CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); - - // Don't try this transformation if the while loop isn't removable, since if - // it succeeds ultimately we're going to have to replace the old while loop - // with a new one. - if (!while_op->parent()->IsSafelyRemovable(while_op)) { - VLOG(2) << "Can't remove dead parameters from non-removable while op."; - return false; - } +// This is a utility function that removes the given tuple indices from the +// while loop init, body, and condition. The final shape returned is still the +// same as before. +static StatusOr RemoveDeadTupleIndices( + HloInstruction* while_op, absl::flat_hash_set& used_tuple_indices) { + // Build up maps from the old/new to the new/old tuple indices. + std::vector new_to_old_tuple_idx(used_tuple_indices.begin(), + used_tuple_indices.end()); + absl::c_sort(new_to_old_tuple_idx); HloModule* module = while_op->GetModule(); HloComputation* computation = while_op->parent(); @@ -62,107 +54,8 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { HloComputation* while_body = while_op->while_body(); HloInstruction* while_body_root = while_body->root_instruction(); - if (!while_init->shape().IsTuple()) { - VLOG(2) << "While op's carried value isn't tuple shaped."; - return false; - } - - if (while_body_root->opcode() != HloOpcode::kTuple) { - VLOG(2) << "While body's root is not a tuple(...) instruction."; - return false; - } - auto print_no_metadata = HloPrintOptions().set_print_metadata(false); - // Bail if param0 of while_cond or while_body has users which aren't of type - // get-tuple-element. 
- for (const HloInstruction* instr : {while_body->parameter_instruction(0), - while_cond->parameter_instruction(0)}) { - for (const HloInstruction* user : instr->users()) { - if (user->opcode() != HloOpcode::kGetTupleElement) { - VLOG(2) << "Cowardly refusing to analyze while loop with " - << instr->ToString(print_no_metadata) - << " used by non-GTE instruction " - << user->ToString(print_no_metadata) << " in computation " - << instr->parent()->name(); - return false; - } - } - } - - const int64 tuple_size = ShapeUtil::TupleElementCount(while_init->shape()); - if (tuple_size == 0) { - VLOG(2) << "Can't remove elements from while loop's tuple -- it's already " - "empty."; - return false; - } - - absl::flat_hash_set used_tuple_indices; - for (HloComputation* comp : {while_body, while_cond}) { - // The HLO verifier ensures that while_input's shape matches while_init's - // shape, which we verified above is a tuple. - HloInstruction* while_input = comp->parameter_instruction(0); - - for (const HloInstruction* user : while_input->users()) { - // This user doesn't count if it's only used by the while body's root, and - // the root places the tuple element into the same index of the tuple as - // it came from. That just amounts to us carrying the variable through - // the loop. - // - // Careful: HloInstruction::operand_index returns the first index the - // operand appears in, but it may appear more than once! - if (user->user_count() == 1 && user->users().front() == while_body_root && - while_body_root->operand_index(user) == user->tuple_index() && - absl::c_count(while_body_root->operands(), user) == 1) { - continue; - } - - used_tuple_indices.insert(user->tuple_index()); - if (used_tuple_indices.size() == tuple_size) { - VLOG(2) << "Loop " << while_op->ToString(print_no_metadata) - << " uses all of its inputs; no simplification possible."; - return false; - } - } - } - - // If a tuple element is not passed unmodified from the while body's param0 - // through to the while body's root, count that element as "used", since - // removing that element would be observable. - for (int64 i = 0; i < while_body_root->operand_count(); ++i) { - if (used_tuple_indices.contains(i)) { - continue; - } - - auto* operand = while_body_root->operand(i); - if (operand->opcode() != HloOpcode::kGetTupleElement || - operand->operand(0) != while_body->parameter_instruction(0) || - operand->tuple_index() != i) { - VLOG(2) << "Tuple index " << i - << " is not passed through loop body unmodified."; - used_tuple_indices.insert(i); - - if (used_tuple_indices.size() == tuple_size) { - VLOG(2) << "Loop " << while_op->ToString(print_no_metadata) - << " uses all of its inputs; no simplification possible."; - return false; - } - } - } - - // If we got here, used_tuple_indices.size() < tuple_size, meaning some - // elements of the loop's tuple aren't used by while_body or while_cond. - CHECK_LT(used_tuple_indices.size(), tuple_size); - - VLOG(1) << "Eliminating " << tuple_size - used_tuple_indices.size() - << " elements from tuple of " - << while_op->ToString(print_no_metadata); - - // Build up maps from the old/new to the new/old tuple indices. 
- std::vector new_to_old_tuple_idx(used_tuple_indices.begin(), - used_tuple_indices.end()); - absl::c_sort(new_to_old_tuple_idx); - absl::flat_hash_map old_to_new_tuple_idx; for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) { int64 old_idx = new_to_old_tuple_idx[new_idx]; @@ -288,6 +181,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // The tuple simplifier will then simplify this if possible, removing // new_tuple and while_init. std::vector new_tuple_elems; + const int64 tuple_size = ShapeUtil::TupleElementCount(while_init->shape()); for (int64 old_idx = 0; old_idx < tuple_size; ++old_idx) { auto new_tuple_idx_it = old_to_new_tuple_idx.find(old_idx); if (new_tuple_idx_it != old_to_new_tuple_idx.end()) { @@ -305,9 +199,293 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { HloInstruction* new_tuple = computation->AddInstruction(HloInstruction::CreateTuple(new_tuple_elems)); TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, new_tuple)); + + return new_while_op; +} + +// Tries to remove elements in a while loop's tuple that aren't used within the +// loop. +// +// Specifically, if a loop is tuple-shaped, and there exists some element of +// that tuple that is not used by the loop condition and is not used by the loop +// body except to pass it to the next iteration of the loop, then we can remove +// that element from the loop's tuples. +static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { + CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); + + // Don't try this transformation if the while loop isn't removable, since if + // it succeeds ultimately we're going to have to replace the old while loop + // with a new one. + if (!while_op->parent()->IsSafelyRemovable(while_op)) { + VLOG(2) << "Can't remove dead parameters from non-removable while op."; + return false; + } + + HloInstruction* while_init = while_op->mutable_operand(0); + HloComputation* while_cond = while_op->while_condition(); + HloComputation* while_body = while_op->while_body(); + HloInstruction* while_body_root = while_body->root_instruction(); + + if (!while_init->shape().IsTuple()) { + VLOG(2) << "While op's carried value isn't tuple shaped."; + return false; + } + + if (while_body_root->opcode() != HloOpcode::kTuple) { + VLOG(2) << "While body's root is not a tuple(...) instruction."; + return false; + } + + auto print_no_metadata = HloPrintOptions().set_print_metadata(false); + + // Bail if param0 of while_cond or while_body has users which aren't of type + // get-tuple-element. + for (const HloInstruction* instr : {while_body->parameter_instruction(0), + while_cond->parameter_instruction(0)}) { + for (const HloInstruction* user : instr->users()) { + if (user->opcode() != HloOpcode::kGetTupleElement) { + VLOG(2) << "Cowardly refusing to analyze while loop with " + << instr->ToString(print_no_metadata) + << " used by non-GTE instruction " + << user->ToString(print_no_metadata) << " in computation " + << instr->parent()->name(); + return false; + } + } + } + + const int64 tuple_size = ShapeUtil::TupleElementCount(while_init->shape()); + if (tuple_size == 0) { + VLOG(2) << "Can't remove elements from while loop's tuple -- it's already " + "empty."; + return false; + } + + absl::flat_hash_set used_tuple_indices; + for (HloComputation* comp : {while_body, while_cond}) { + // The HLO verifier ensures that while_input's shape matches while_init's + // shape, which we verified above is a tuple. 
+ HloInstruction* while_input = comp->parameter_instruction(0); + + for (const HloInstruction* user : while_input->users()) { + // This user doesn't count if it's only used by the while body's root, and + // the root places the tuple element into the same index of the tuple as + // it came from. That just amounts to us carrying the variable through + // the loop. + // + // Careful: HloInstruction::operand_index returns the first index the + // operand appears in, but it may appear more than once! + if (user->user_count() == 1 && user->users().front() == while_body_root && + while_body_root->operand_index(user) == user->tuple_index() && + absl::c_count(while_body_root->operands(), user) == 1) { + continue; + } + + used_tuple_indices.insert(user->tuple_index()); + if (used_tuple_indices.size() == tuple_size) { + VLOG(2) << "Loop " << while_op->ToString(print_no_metadata) + << " uses all of its inputs; no simplification possible."; + return false; + } + } + } + + // If a tuple element is not passed unmodified from the while body's param0 + // through to the while body's root, count that element as "used", since + // removing that element would be observable. + for (int64 i = 0; i < while_body_root->operand_count(); ++i) { + if (used_tuple_indices.contains(i)) { + continue; + } + + auto* operand = while_body_root->operand(i); + if (operand->opcode() != HloOpcode::kGetTupleElement || + operand->operand(0) != while_body->parameter_instruction(0) || + operand->tuple_index() != i) { + VLOG(2) << "Tuple index " << i + << " is not passed through loop body unmodified."; + used_tuple_indices.insert(i); + + if (used_tuple_indices.size() == tuple_size) { + VLOG(2) << "Loop " << while_op->ToString(print_no_metadata) + << " uses all of its inputs; no simplification possible."; + return false; + } + } + } + + // If we got here, used_tuple_indices.size() < tuple_size, meaning some + // elements of the loop's tuple aren't used by while_body or while_cond. + CHECK_LT(used_tuple_indices.size(), tuple_size); + + VLOG(1) << "Eliminating " << tuple_size - used_tuple_indices.size() + << " elements from tuple of " + << while_op->ToString(print_no_metadata); + + TF_ASSIGN_OR_RETURN(while_op, + RemoveDeadTupleIndices(while_op, used_tuple_indices)); + return true; } +// This is a helper function for TryRemoveRepeatedWhileTupleIndices. It removes +// duplicates by replacing them with tuple_index, followed by a call to +// RemoveDeadTupleIndices. +static StatusOr TryRemoveRepeatedWhileTupleIndicesHelper( + HloInstruction* while_op, const int64 tuple_index, + absl::flat_hash_set& duplicates) { + HloComputation* while_cond = while_op->while_condition(); + HloComputation* while_body = while_op->while_body(); + HloInstruction* while_init = while_op->mutable_operand(0); + + VLOG(2) << "while_init " << while_init->ToString() << " operands " + << while_init->operand_count(); + VLOG(2) << "while_body_root " << while_body->root_instruction()->ToString() + << " operands " << while_body->root_instruction()->operand_count(); + + // Change the loop body and condition such that uses of the duplicates are + // replaced with the original tuple element. 
+ for (HloComputation* comp : {while_body, while_cond}) { + auto new_get = comp->AddInstruction(HloInstruction::CreateGetTupleElement( + comp->parameter_instruction(0)->shape().tuple_shapes(tuple_index), + comp->parameter_instruction(0), tuple_index)); + + std::vector instrs_to_replace; + for (auto* instr : comp->instructions()) { + if (instr->opcode() == HloOpcode::kGetTupleElement && + duplicates.contains(instr->tuple_index()) && + instr->operand(0) == comp->parameter_instruction(0)) { + instrs_to_replace.push_back(instr); + } + } + + for (auto instr : instrs_to_replace) { + TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_get)); + } + } + + // We know which tuple indices are useful; i.e, those which aren't duplicates. + absl::flat_hash_set used_tuple_indices; + for (int index = 0; index < while_init->shape().tuple_shapes_size(); + ++index) { + if (!duplicates.count(index)) { + used_tuple_indices.insert(index); + } + } + + // Remove the duplicate tuple elements. + TF_ASSIGN_OR_RETURN(while_op, + RemoveDeadTupleIndices(while_op, used_tuple_indices)); + + return while_op; +} + +// If the while loop init passes the same values to several tuple indices, and +// if the body keeps on passing them through, we can remove the duplicates. +static StatusOr TryRemoveRepeatedWhileTupleIndices( + HloInstruction* while_op) { + CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); + + int index_to_investigate = 0; + // Don't try this transformation if the while loop isn't removable, since if + // it succeeds ultimately we're going to have to replace the old while loop + // with a new one. + if (!while_op->parent()->IsSafelyRemovable(while_op)) { + VLOG(2) << "Can't remove dead parameters from non-removable while op."; + return false; + } + + HloInstruction* while_init = while_op->mutable_operand(0); + HloComputation* while_cond = while_op->while_condition(); + HloComputation* while_body = while_op->while_body(); + HloInstruction* while_body_root = while_body->root_instruction(); + + if (!while_init->shape().IsTuple()) { + VLOG(2) << "While op's carried value isn't tuple shaped."; + return false; + } + + bool changed = false; + while (index_to_investigate < while_init->shape().tuple_shapes_size()) { + if (!while_init->shape().IsTuple() || + while_init->opcode() != HloOpcode::kTuple) { + VLOG(2) << "While op's carried value isn't tuple shaped."; + return false; + } + + if (while_body_root->opcode() != HloOpcode::kTuple) { + VLOG(2) << "While body's root is not a tuple(...) instruction."; + return false; + } + + auto& while_shape = while_init->shape(); + VLOG(2) << "Iterating " << index_to_investigate; + + absl::flat_hash_set duplicates; + auto* pivot_init_elem = while_init->operand(index_to_investigate); + auto* pivot_body_elem = while_body_root->operand(index_to_investigate); + if (pivot_body_elem->opcode() == HloOpcode::kGetTupleElement && + pivot_body_elem->operand(0) == while_body->parameter_instruction(0)) { + if (pivot_body_elem->tuple_index() != index_to_investigate) { + VLOG(2) << "Mismatch between pivot_body_elem->tuple_index() " + << pivot_body_elem->tuple_index() << " index_to_investigate " + << index_to_investigate; + index_to_investigate++; + continue; + } + } else { + index_to_investigate++; + continue; + } + + // Look from index_to_investigate onwards to see if it is repeated. 
+ for (int64 i = index_to_investigate + 1; + i < while_shape.tuple_shapes_size(); ++i) { + auto* init_elem = while_init->operand(i); + auto* body_elem = while_body_root->operand(i); + if (body_elem->opcode() == HloOpcode::kGetTupleElement && + body_elem->operand(0) == while_body->parameter_instruction(0)) { + if (body_elem->tuple_index() != i) { + VLOG(2) << "Mismatch between body_elem->tuple_index() " + << body_elem->tuple_index() << " i " << i; + continue; + } + } else { + continue; + } + + if (pivot_init_elem == init_elem) { + VLOG(2) << "init_elem " << init_elem->ToString() << " pivot_init_elem " + << pivot_init_elem->ToString(); + VLOG(2) << "body_elem " << body_elem->ToString() << " pivot_body_elem " + << pivot_body_elem->ToString(); + duplicates.insert(i); + } + } + + // If duplicates are found, call the helper to remove them. + if (!duplicates.empty()) { + VLOG(2) << "Duplicate found " << duplicates.size() << " pivot_init " + << pivot_init_elem->ToString(); + TF_ASSIGN_OR_RETURN(while_op, + TryRemoveRepeatedWhileTupleIndicesHelper( + while_op, index_to_investigate, duplicates)); + changed = true; + VLOG(2) << "Changed while_op " << while_op->ToString() + << " while_op operand count " << while_op->operand_count(); + // Update the while loop variables so we can continue looking for + // duplicates of a different index. + while_init = while_op->mutable_operand(0); + while_cond = while_op->while_condition(); + while_body = while_op->while_body(); + while_body_root = while_body->root_instruction(); + } + index_to_investigate++; + } + + return changed; +} + // Removes each loop parameter (i.e. member of the while loop tuple) that is a // constant and is the same in the while loop body and the while loop init. static StatusOr TryRemoveConstantParams(HloInstruction* while_op) { @@ -1048,6 +1226,7 @@ StatusOr WhileLoopSimplifier::Run(HloModule* module) { TF_ASSIGN_OR_RETURN(result, TryRemoveWhileLoop(while_op)); changed |= result; + if (result) { // Don't continue simplifying after successfully removing the while loop // -- that would result in use-after-free nastiness. @@ -1067,6 +1246,12 @@ StatusOr WhileLoopSimplifier::Run(HloModule* module) { // successful, meaning that `while_op` is no longer valid after one of these // transformations returns true. 
+ TF_ASSIGN_OR_RETURN(result, TryRemoveRepeatedWhileTupleIndices(while_op)); + changed |= result; + if (result) { + continue; + } + TF_ASSIGN_OR_RETURN(result, TryFlattenNestedTuples(while_op)); changed |= result; if (result) { @@ -1074,6 +1259,7 @@ StatusOr WhileLoopSimplifier::Run(HloModule* module) { } TF_ASSIGN_OR_RETURN(result, TryRemoveDeadWhileParams(while_op)); + changed |= result; if (result) { continue; diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index d715fb3857a..c93cb5dc347 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -794,5 +794,51 @@ TEST_F(WhileLoopSimplifierTest, MergeInductionVariables_SkipS16) { .ValueOrDie()); } +TEST_F(WhileLoopSimplifierTest, RemoveRepeatedParams) { + const string hlo_string = R"( + HloModule SwappingTupleElements + + SwappingTupleElements.body { + loop_var = (s32[], s32[], s32[]) parameter(0) + get-tuple-element = s32[] get-tuple-element(loop_var), index=0 + get-tuple-element.1 = s32[] get-tuple-element(loop_var), index=1 + get-tuple-element.2 = s32[] get-tuple-element(loop_var), index=2 + y = s32[] add(get-tuple-element.1, get-tuple-element.2) + ROOT tuple = (s32[], s32[], s32[]) tuple(s32[] get-tuple-element, y, + s32[] get-tuple-element.2) + } + + SwappingTupleElements.always_true { + param = (s32[], s32[], s32[]) parameter(0) + get-tuple-element = s32[] get-tuple-element(param), index=0 + get-tuple-element.1 = s32[] get-tuple-element(param), index=1 + ROOT less-than = pred[] compare(get-tuple-element, get-tuple-element.1), direction=LT + } + + ENTRY SwappingTupleElements { + x = s32[] parameter(0) + y = s32[] parameter(1) + tuple.1 = (s32[], s32[], s32[]) tuple(s32[] x, s32[] y, s32[] x) + ROOT while = (s32[], s32[], s32[]) while(tuple.1), + condition=SwappingTupleElements.always_true, + body=SwappingTupleElements.body + } + )"; + + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + HloInstruction* new_while = FindFirstWhile(m.get()); + Shape new_while_shape = ParseShape("(s32[], s32[])").ValueOrDie(); + EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_body()->root_instruction()->shape(), new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_body()->parameter_instruction(0)->shape(), + new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_condition()->parameter_instruction(0)->shape(), + new_while_shape)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index 73bb3327784..b1c96e9becf 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -70,6 +70,8 @@ struct IndexTableEntry { template class ShapeTreeIterator; +template +class ShapeTreeLeafIterator; // A ShapeTree is a recursive data structure which mirrors the structure of a // XLA shape and holds a value of type T for each subshape (i.e. 
tuple or array) @@ -158,23 +160,25 @@ class ShapeTree { using reverse_iterator = std::reverse_iterator; using const_reverse_iterator = std::reverse_iterator; + using leaf_iterator = + ShapeTreeLeafIterator, + typename std::vector::iterator, + std::pair>; + using const_leaf_iterator = + ShapeTreeLeafIterator, + typename std::vector::const_iterator, + const std::pair>; + using reverse_leaf_iterator = std::reverse_iterator; + using const_reverse_leaf_iterator = + std::reverse_iterator; + // begin/end for iterating over all nodes. - iterator begin() { - return iterator(&nodes_, nodes_.begin(), - /*iterate_leaves_only=*/false); - } - iterator end() { - return iterator(&nodes_, nodes_.end(), - /*iterate_leaves_only=*/false); - } + iterator begin() { return iterator(&nodes_, nodes_.begin()); } + iterator end() { return iterator(&nodes_, nodes_.end()); } const_iterator begin() const { - return const_iterator(&nodes_, nodes_.begin(), - /*iterate_leaves_only=*/false); - } - const_iterator end() const { - return const_iterator(&nodes_, nodes_.end(), - /*iterate_leaves_only=*/false); + return const_iterator(&nodes_, nodes_.begin()); } + const_iterator end() const { return const_iterator(&nodes_, nodes_.end()); } // rbegin/rend for iterating over all nodes in reverse. reverse_iterator rbegin() { return reverse_iterator(end()); } @@ -188,37 +192,33 @@ class ShapeTree { // leaf_begin()/leaf_end() iterates over all leaf nodes (nodes with no // children). - iterator leaf_begin() { - return iterator(&nodes_, nodes_.begin(), - /*iterate_leaves_only=*/true); + leaf_iterator leaf_begin() { return leaf_iterator(&nodes_, nodes_.begin()); } + leaf_iterator leaf_end() { return leaf_iterator(&nodes_, nodes_.end()); } + const_leaf_iterator leaf_begin() const { + return const_leaf_iterator(&nodes_, nodes_.begin()); } - iterator leaf_end() { - return iterator(&nodes_, nodes_.end(), - /*iterate_leaves_only=*/true); - } - const_iterator leaf_begin() const { - return const_iterator(&nodes_, nodes_.begin(), - /*iterate_leaves_only=*/true); - } - const_iterator leaf_end() const { - return const_iterator(&nodes_, nodes_.end(), - /*iterate_leaves_only=*/true); + const_leaf_iterator leaf_end() const { + return const_leaf_iterator(&nodes_, nodes_.end()); } // range-based iterator for leaf_begin()/leaf_end(). - tensorflow::gtl::iterator_range leaves() { + tensorflow::gtl::iterator_range leaves() { return tensorflow::gtl::make_range(leaf_begin(), leaf_end()); } - tensorflow::gtl::iterator_range leaves() const { + tensorflow::gtl::iterator_range leaves() const { return tensorflow::gtl::make_range(leaf_begin(), leaf_end()); } - reverse_iterator leaf_rbegin() { return reverse_iterator(leaf_end()); } - reverse_iterator leaf_rend() { return reverse_iterator(leaf_begin()); } - const_reverse_iterator leaf_rbegin() const { - return const_reverse_iterator(leaf_end()); + reverse_leaf_iterator leaf_rbegin() { + return reverse_leaf_iterator(leaf_end()); } - const_reverse_iterator leaf_rend() const { - return const_reverse_iterator(leaf_begin()); + reverse_leaf_iterator leaf_rend() { + return reverse_leaf_iterator(leaf_begin()); + } + const_reverse_leaf_iterator leaf_rbegin() const { + return const_reverse_leaf_iterator(leaf_end()); + } + const_reverse_leaf_iterator leaf_rend() const { + return const_reverse_leaf_iterator(leaf_begin()); } // Returns an iterator pointing to the given ShapeIndex. 
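
For context on the leaf-iterator change above: the filtering of interior nodes moves out of ShapeTreeIterator's runtime flag and into the dedicated ShapeTreeLeafIterator type, so existing leaves()/leaf_begin()/leaf_end() call sites keep working unchanged. A minimal usage sketch follows (not part of this change; the helper name and the GetLeafCount remark are illustrative only, assuming the ShapeTree value-initializing constructor shown in the tests below):

#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Sums the values stored at the array (leaf) subshapes of a tuple shape.
int SumLeafValues(const xla::Shape& tuple_shape) {
  xla::ShapeTree<int> tree(tuple_shape, /*init_value=*/1);
  int total = 0;
  // leaves() now yields (const_)leaf_iterator; each element is still a
  // std::pair<ShapeIndex, T>, so loop bodies are unchanged.
  for (const auto& index_and_value : tree.leaves()) {
    total += index_and_value.second;
  }
  // With every leaf initialized to 1, total equals the number of leaves,
  // i.e. ShapeUtil::GetLeafCount(tuple_shape).
  return total;
}
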
@@ -226,12 +226,12 @@ class ShapeTree { iterator find(ShapeIndexView index) { Node* element = Lookup(index); auto element_iter = nodes_.begin() + (element - &nodes_[0]); - return iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false); + return iterator(&nodes_, element_iter); } const_iterator find(ShapeIndexView index) const { - Node* element = Lookup(index); + const Node* element = Lookup(index); auto element_iter = nodes_.cbegin() + (element - &nodes_[0]); - return const_iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false); + return const_iterator(&nodes_, element_iter); } // Returns the number of leaf nodes in the tree. @@ -343,21 +343,11 @@ template class ShapeTreeIterator : public std::iterator { public: - ShapeTreeIterator(ContainerType* nodes, IteratorType node, - bool iterate_leaves_only) - : nodes_(nodes), - node_(std::move(node)), - iterate_leaves_only_(iterate_leaves_only) { - while (iterate_leaves_only && node_ != nodes_->end() && !node_->is_leaf) { - ++node_; - } - } + ShapeTreeIterator(ContainerType* nodes, IteratorType node) + : nodes_(nodes), node_(std::move(node)) {} ShapeTreeIterator& operator++() { ++node_; - while (iterate_leaves_only_ && node_ != nodes_->end() && !node_->is_leaf) { - ++node_; - } return *this; } ShapeTreeIterator operator++(int) { @@ -368,9 +358,6 @@ class ShapeTreeIterator ShapeTreeIterator& operator--() { --node_; - while (iterate_leaves_only_ && node_ > nodes_->begin() && !node_->is_leaf) { - --node_; - } return *this; } ShapeTreeIterator operator--(int) { @@ -385,14 +372,66 @@ class ShapeTreeIterator bool operator!=(const ShapeTreeIterator& other) const { return node_ != other.node_; } - ValueType& operator*() { return node_->data; } - ValueType* operator->() { return &node_->data; } + ValueType& operator*() const { return node_->data; } + ValueType* operator->() const { return &node_->data; } + + private: + ContainerType* nodes_; + IteratorType node_; +}; + +// Internal iterator that performs a pre-order walk of the leaves. This is cheap +// to copy. The iterator value_type is equivalent to a std::pair&, +// similar to std::map. +template +class ShapeTreeLeafIterator + : public std::iterator { + public: + ShapeTreeLeafIterator(ContainerType* nodes, IteratorType node) + : nodes_(nodes), node_(std::move(node)) { + while (node_ != nodes_->end() && !node_->is_leaf) { + ++node_; + } + } + + ShapeTreeLeafIterator& operator++() { + ++node_; + while (node_ != nodes_->end() && !node_->is_leaf) { + ++node_; + } + return *this; + } + ShapeTreeLeafIterator operator++(int) { + auto i = *this; + ++(*this); + return i; + } + + ShapeTreeLeafIterator& operator--() { + --node_; + while (node_ > nodes_->begin() && !node_->is_leaf) { + --node_; + } + return *this; + } + ShapeTreeLeafIterator operator--(int) { + auto i = *this; + --(*this); + return i; + } + + bool operator==(const ShapeTreeLeafIterator& other) const { + return node_ == other.node_; + } + bool operator!=(const ShapeTreeLeafIterator& other) const { + return node_ != other.node_; + } + ValueType& operator*() const { return node_->data; } + ValueType* operator->() const { return &node_->data; } private: ContainerType* nodes_; IteratorType node_; - // True if we should not include interior nodes in our walk. 
- const bool iterate_leaves_only_; }; template @@ -648,7 +687,9 @@ void ShapeTree::CopySubtreeFrom(const ShapeTree& other, const ShapeIndex& target_base_index) { CHECK(ShapeUtil::Compatible( ShapeUtil::GetSubshape(shape(), target_base_index), - ShapeUtil::GetSubshape(other.shape(), source_base_index))); + ShapeUtil::GetSubshape(other.shape(), source_base_index))) + << ShapeUtil::GetSubshape(shape(), target_base_index) << " vs " + << ShapeUtil::GetSubshape(other.shape(), source_base_index); ForEachMutableElement([this, &other, &source_base_index, &target_base_index]( const ShapeIndex& index, T* data) { // Copy the data element only if index is in the diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc index 2b6c484bc4f..c294355e269 100644 --- a/tensorflow/compiler/xla/shape_tree_test.cc +++ b/tensorflow/compiler/xla/shape_tree_test.cc @@ -485,6 +485,30 @@ TEST_F(ShapeTreeTest, ReverseIterateOrder) { })); } +// Ensures that we can find an element at an index that we know ahead of time to +// be occupied in a 'ShapeTree' via the 'find' API. +TEST_F(ShapeTreeTest, Find) { + ShapeTree t(nested_tuple_shape_, 42); + auto found = t.find({1, 0}); + EXPECT_NE(found, t.end()); + // The found key must be the same key we searched for. + EXPECT_EQ(found->first, ShapeIndex({1, 0})); + // The 'ShapeTree' has 42 at every position. + EXPECT_EQ(found->second, 42); +} + +// Ensures that we can find an element at an index that we know ahead of time to +// be occupied in a 'const ShapeTree' via the 'find' API. +TEST_F(ShapeTreeTest, ConstFind) { + const ShapeTree t(nested_tuple_shape_, 42); + auto found = t.find({1, 0}); + EXPECT_NE(found, t.end()); + // The found key must be the same key we searched for. + EXPECT_EQ(found->first, ShapeIndex({1, 0})); + // The 'ShapeTree' has 42 at every position. + EXPECT_EQ(found->second, 42); +} + TEST_F(ShapeTreeTest, IterateOrderLeaves) { ShapeTree t(nested_tuple_shape_, 42); std::vector v; diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 0833919b124..0c877bf6102 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -1623,4 +1623,14 @@ static Shape MergeDimensions(absl::Span segs, return absl::nullopt; } +Shape ShapeUtil::DeviceShapeToHostShape(Shape s) { + ForEachMutableSubshape(&s, [](Shape* subshape, const ShapeIndex& index) { + if (subshape->IsArray()) { + subshape->mutable_layout()->clear_tiles(); + subshape->mutable_layout()->set_memory_space(Layout::kDefaultMemorySpace); + } + }); + return s; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 3f69a8b0aca..5a5695d32ee 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -783,6 +783,10 @@ class ShapeUtil { static absl::optional> FindTranspose021(const Shape& a, const Shape& b); + // Strips device-specific information, namely tiling and memory-space + // information, from a shape. + static Shape DeviceShapeToHostShape(Shape s); + private: // Validates the shape size is sane. This makes sure it's safe to do // calculations in int64 without overflowing. 
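
The ShapeUtil::DeviceShapeToHostShape helper declared above only touches layout metadata: for every array subshape it clears tiling and resets the memory space, leaving element types, dimensions, and minor-to-major order intact. A hedged sketch of its effect (not part of this change; the memory-space value 1 is arbitrary and purely illustrative):

#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"

void DeviceShapeToHostShapeSketch() {
  // Build an f32[128,64] shape and mark it as living in a non-default
  // memory space, roughly as a transfer manager might for a device buffer.
  xla::Shape device_shape = xla::ShapeUtil::MakeShapeWithLayout(
      xla::F32, /*dimensions=*/{128, 64}, /*minor_to_major=*/{1, 0});
  device_shape.mutable_layout()->set_memory_space(1);

  xla::Shape host_shape = xla::ShapeUtil::DeviceShapeToHostShape(device_shape);
  // Tiles are cleared and the memory space is back to the default; the
  // dimensions and minor-to-major order are unchanged.
  CHECK(host_shape.layout().tiles().empty());
  CHECK_EQ(host_shape.layout().memory_space(),
           xla::Layout::kDefaultMemorySpace);
}
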
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 17444c042e7..98ed49ad76a 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -6,7 +6,14 @@ load( "//tensorflow/core/platform:build_config_root.bzl", "tf_cuda_tests_tags", ) + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "filegroup") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "genrule") load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( default_visibility = [":friends"], @@ -64,9 +71,9 @@ cc_library( hdrs = ["manifest_checking_test.h"], deps = [ ":test_macros_header", - "//tensorflow/core:regexp_internal", "//tensorflow/core:test", "//tensorflow/core/platform:logging", + "//tensorflow/core/platform:regexp", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", ], @@ -164,8 +171,8 @@ cc_library( "//tensorflow/compiler/xla/service:interpreter_plugin", # reference backend "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "//tensorflow/core/platform:stream_executor_no_cuda", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", @@ -226,8 +233,8 @@ cc_library( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "//tensorflow/core/platform:stream_executor_no_cuda", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -305,7 +312,7 @@ cc_library( "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor:device_memory_allocator", "//third_party/eigen3", "@com_google_absl//absl/memory", @@ -380,12 +387,7 @@ xla_test( name = "conv_depthwise_backprop_filter_test", timeout = "long", srcs = ["conv_depthwise_backprop_filter_test.cc"], - # these backends do not natively handle batch group counts. 
- disabled_backends = [ - "gpu", - "cpu", - ], - shard_count = 6, + shard_count = 40, deps = [ ":test_macros_header", "//tensorflow/compiler/xla:execution_options_util", @@ -507,8 +509,8 @@ xla_test( "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/service:stream_pool", "//tensorflow/core:lib", - "//tensorflow/core:regexp_internal", "//tensorflow/core:test", + "//tensorflow/core/platform:regexp", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", @@ -553,8 +555,8 @@ xla_test( "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "//tensorflow/core/platform:stream_executor_no_cuda", ], ) @@ -1456,8 +1458,8 @@ xla_test( "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor:device_memory_allocator", ], ) @@ -1913,8 +1915,8 @@ xla_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "//tensorflow/core/platform:stream_executor_no_cuda", ], ) @@ -1922,9 +1924,14 @@ xla_test( name = "concat_test", srcs = ["concat_test.cc"], deps = [ + ":client_library_test_base", + ":hlo_test_base", + ":literal_test_util", ":test_macros_header", + ":xla_internal_test_main", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -1932,9 +1939,6 @@ xla_test( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], ) @@ -1952,8 +1956,8 @@ xla_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "//tensorflow/core/platform:stream_executor_no_cuda", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", ], @@ -1982,8 +1986,8 @@ xla_test( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "//tensorflow/core/platform:stream_executor_no_cuda", ], ) @@ -2043,8 +2047,8 @@ xla_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "//tensorflow/core/platform:stream_executor_no_cuda", ], ) @@ -2398,8 +2402,8 @@ xla_test( "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + 
"//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor:device_memory_allocator", ], ) @@ -2521,8 +2525,8 @@ xla_test( "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:stream_pool", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor:device_memory_allocator", ], ) @@ -2681,6 +2685,7 @@ xla_test( xla_test( name = "cholesky_test", srcs = ["cholesky_test.cc"], + real_hardware_only = True, tags = [ "no_rocm", "optonly", @@ -2699,5 +2704,6 @@ xla_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", + "//tensorflow/core/platform:tensor_float_32_utils", ], ) diff --git a/tensorflow/compiler/xla/tests/buffer_donation_test.cc b/tensorflow/compiler/xla/tests/buffer_donation_test.cc index f78083fe2af..7915737178d 100644 --- a/tensorflow/compiler/xla/tests/buffer_donation_test.cc +++ b/tensorflow/compiler/xla/tests/buffer_donation_test.cc @@ -119,8 +119,7 @@ class BufferDonationTest : public HloTestBase { } }); - args.emplace_back( - ExecutionInput(std::move(owned_buffers), argument_literal.shape())); + args.emplace_back(ExecutionInput(std::move(owned_buffers))); } StatusOr output_status = diff --git a/tensorflow/compiler/xla/tests/cholesky_test.cc b/tensorflow/compiler/xla/tests/cholesky_test.cc index e7f5ca5ed8e..4fa28736d4d 100644 --- a/tensorflow/compiler/xla/tests/cholesky_test.cc +++ b/tensorflow/compiler/xla/tests/cholesky_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/tensor_float_32_utils.h" namespace xla { namespace { @@ -60,6 +61,44 @@ XLA_TEST_F(CholeskyTest, NonPSDInput) { ErrorSpec(1e-4, 1e-4)); } +XLA_TEST_F(CholeskyTest, NonPSDBatched) { + XlaBuilder builder(TestName()); + + Array3D a_vals({ + { + {10, 0, 0}, + {1, 20, 0}, + {1, 1, 30}, + }, + { + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + }, + }); + + XlaOp a; + auto a_data = CreateR3Parameter(a_vals, 0, "a", &builder, &a); + Cholesky(a, /*lower=*/true); + + float nan = std::numeric_limits::quiet_NaN(); + Array3D expected({ + { + {3.16227766, 0., 0.}, + {0.31622777, 4.4609416, 0.}, + {0.31622777, 0.20175113, 5.46436606}, + }, + { + {nan, nan, nan}, + {nan, nan, nan}, + {nan, nan, nan}, + }, + }); + + ComputeAndCompareR3(&builder, expected, {a_data.get()}, + ErrorSpec(1e-4, 1e-4)); +} + XLA_TEST_F(CholeskyTest, Lower) { XlaBuilder builder(TestName()); @@ -180,7 +219,9 @@ class RandomCholeskyTest : public ClientLibraryTestBase, public ::testing::WithParamInterface {}; -XLA_TEST_P(RandomCholeskyTest, Random) { +XLA_TEST_P(RandomCholeskyTest, Real) { + // Test fails with TensorFloat-32 enabled + tensorflow::enable_tensor_float_32_execution(false); XlaBuilder builder(TestName()); auto test_params = GetParam(); @@ -217,14 +258,65 @@ XLA_TEST_P(RandomCholeskyTest, Random) { ErrorSpec(1e-4, 1e-4)); } +XLA_TEST_P(RandomCholeskyTest, Complex) { + // Test fails with TensorFloat-32 enabled + tensorflow::enable_tensor_float_32_execution(false); + XlaBuilder builder(TestName()); + + auto test_params = GetParam(); + std::vector dimensions = {std::get<0>(test_params), + std::get<1>(test_params), + std::get<1>(test_params)}; + bool lower = std::get<2>(test_params); + Shape shape 
= ShapeUtil::MakeShape(F32, dimensions); + TF_ASSERT_OK_AND_ASSIGN( + auto literal_real, + LiteralUtil::CreateRandomLiteral(shape, 0.0, 1.0)); + TF_ASSERT_OK_AND_ASSIGN( + auto literal_imag, + LiteralUtil::CreateRandomLiteral(shape, 0.0, 1.0)); + + auto input_real = Parameter(&builder, 0, shape, "input_real"); + auto input_imag = Parameter(&builder, 1, shape, "input_imag"); + auto input = Complex(input_real, input_imag); + // Form a random positive definite matrix. + auto matrix = BatchDot(input, TransposeInMinorDims(Conj(input)), + PrecisionConfig::HIGHEST); + + auto cholesky = Triangle(Cholesky(matrix, lower), lower); + + // Verify that ||matrix - cholesky * cholesky_t||_2 ~= 0 + XlaOp verification; + if (lower) { + verification = BatchDot(cholesky, TransposeInMinorDims(Conj(cholesky)), + PrecisionConfig::HIGHEST); + } else { + verification = BatchDot(TransposeInMinorDims(Conj(cholesky)), cholesky, + PrecisionConfig::HIGHEST); + } + auto delta = matrix - verification; + Reduce(Abs(delta * Conj(delta)), ConstantR0(&builder, 0.0), + CreateScalarAddComputation(F32, &builder), {0, 1, 2}); + + TF_ASSERT_OK_AND_ASSIGN(auto input_data_real, + client_->TransferToServer(literal_real)); + TF_ASSERT_OK_AND_ASSIGN(auto input_data_imag, + client_->TransferToServer(literal_imag)); + ComputeAndCompareR0(&builder, 0.0, + {input_data_real.get(), input_data_imag.get()}, + ErrorSpec(1e-4, 1e-4)); +} + INSTANTIATE_TEST_SUITE_P(RandomCholeskyTestInstance, RandomCholeskyTest, ::testing::Values(CholeskyTestCase{1, 1, true}, CholeskyTestCase{1, 2, true}, CholeskyTestCase{1, 50, true}, CholeskyTestCase{1, 50, false}, + CholeskyTestCase{1, 255, false}, CholeskyTestCase{10, 5, true}, CholeskyTestCase{5, 10, false}, - CholeskyTestCase{2, 20, true})); + CholeskyTestCase{2, 20, true}, + CholeskyTestCase{2, 129, true})); } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 0e99ede5d01..6acbb7a9cf0 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -605,7 +605,7 @@ XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal, : LiteralSlice(literal)); } -std::unique_ptr +StatusOr> ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number, const Literal& literal, const string& name, @@ -637,15 +637,14 @@ Literal ClientLibraryTestBase::MaybeConvertLiteralToBfloat16( return literal.Clone(); } -std::unique_ptr +StatusOr> ClientLibraryTestBase::CreateParameterAndTransferLiteral( int64 parameter_number, const Literal& literal, const string& name, const DeviceHandle* device_handle, XlaBuilder* builder, XlaOp* data_handle) { Literal param_literal = MaybeConvertLiteralToBfloat16(literal); - std::unique_ptr data = - client_->TransferToServer(param_literal, device_handle) - .ConsumeValueOrDie(); + TF_ASSIGN_OR_RETURN(auto data, + client_->TransferToServer(param_literal, device_handle)); *data_handle = Parameter(builder, parameter_number, param_literal.shape(), name); return data; diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 17bb70bdb42..3c9e37b8fa4 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -270,14 +270,14 @@ class ClientLibraryTestBase : public ManifestCheckingTest { // server, then stores into "data_handle" 
the global handle for that // parameter. When the use_bfloat16 flag is set but the literal has F32 // elements, the literal will be converted to BF16 before being transferred. - std::unique_ptr CreateParameterAndTransferLiteral( + StatusOr> CreateParameterAndTransferLiteral( int64 parameter_number, const Literal& literal, const string& name, XlaBuilder* builder, XlaOp* data_handle); // As above, but the caller can specify the device that the literal is // transferred to. If device_handle is nullptr, the literal will be // transferred to the default device. - std::unique_ptr CreateParameterAndTransferLiteral( + StatusOr> CreateParameterAndTransferLiteral( int64 parameter_number, const Literal& literal, const string& name, const DeviceHandle* device_handle, XlaBuilder* builder, XlaOp* data_handle); diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index 4f5b525a342..9df83e30ad4 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -21,11 +21,13 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/core/platform/test.h" @@ -34,6 +36,7 @@ namespace xla { namespace { using ConcatTest = ClientLibraryTestBase; +using ConcatTestHlo = HloTestBase; using ::testing::HasSubstr; // Concatenate expects at least one argument. @@ -518,6 +521,250 @@ XLA_TEST_F(ConcatTest, ConcatDeeplyNested) { ComputeAndCompareR1(&builder, expected, {a_data.get()}); } +// TODO(b/169314478): Enable the test when the slow compilation is fixed. 
+XLA_TEST_F(ConcatTestHlo, DISABLED_ConcatWithBitcast) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule jit_broken.874 + +primitive_computation_add.866 { + parameter.867 = f32[] parameter(0) + parameter.868 = f32[] parameter(1) + ROOT add.869 = f32[] add(parameter.867, parameter.868) +} + +ENTRY jit_broken.874 { + parameter.38 = f32[4,2]{1,0} parameter(0) + reshape.723 = f32[4,2,1]{2,1,0} reshape(parameter.38) + reshape.724 = f32[4,2,1]{2,1,0} reshape(parameter.38) + concatenate.42 = f32[4,2,2]{2,1,0} concatenate(reshape.723, reshape.724), dimensions={2} + slice.351 = f32[4,1,2]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:2]} + reshape.1058 = f32[4,2]{1,0} reshape(slice.351) + slice.352 = f32[4,1]{1,0} slice(reshape.1058), slice={[0:4], [1:2]} + reshape.1059 = f32[4]{0} reshape(slice.352) + slice.353 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]} + reshape.1060 = f32[4]{0} reshape(slice.353) + add.124 = f32[4]{0} add(reshape.1059, reshape.1060) + slice.354 = f32[4,1]{1,0} slice(reshape.1058), slice={[0:4], [0:1]} + reshape.1061 = f32[4]{0} reshape(slice.354) + slice.379 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]} + reshape.1062 = f32[4]{0} reshape(slice.379) + add.89 = f32[4]{0} add(reshape.1061, reshape.1062) + subtract.126 = f32[4]{0} subtract(add.124, add.89) + is-finite.127 = pred[4]{0} is-finite(subtract.126) + not.128 = pred[4]{0} not(is-finite.127) + abs.129 = f32[4]{0} abs(subtract.126) + constant.130 = f32[] constant(inf) + broadcast.131 = f32[4]{0} broadcast(constant.130), dimensions={} + compare.132 = pred[4]{0} compare(abs.129, broadcast.131), direction=EQ, type=UNSIGNED + not.133 = pred[4]{0} not(compare.132) + and.134 = pred[4]{0} and(not.128, not.133) + add.135 = f32[4]{0} add(add.124, add.89) + maximum.125 = f32[4]{0} maximum(add.124, add.89) + abs.136 = f32[4]{0} abs(subtract.126) + negate.137 = f32[4]{0} negate(abs.136) + exponential.138 = f32[4]{0} exponential(negate.137) + log-plus-one.139 = f32[4]{0} log-plus-one(exponential.138) + add.140 = f32[4]{0} add(maximum.125, log-plus-one.139) + select.141 = f32[4]{0} select(and.134, add.135, add.140) + slice.356 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]} + reshape.1064 = f32[4]{0} reshape(slice.356) + add.214 = f32[4]{0} add(select.141, reshape.1064) + slice.380 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]} + reshape.1066 = f32[4]{0} reshape(slice.380) + add.179 = f32[4]{0} add(select.141, reshape.1066) + subtract.216 = f32[4]{0} subtract(add.214, add.179) + is-finite.217 = pred[4]{0} is-finite(subtract.216) + not.218 = pred[4]{0} not(is-finite.217) + abs.219 = f32[4]{0} abs(subtract.216) + constant.220 = f32[] constant(inf) + broadcast.221 = f32[4]{0} broadcast(constant.220), dimensions={} + compare.222 = pred[4]{0} compare(abs.219, broadcast.221), direction=EQ, type=UNSIGNED + not.223 = pred[4]{0} not(compare.222) + and.224 = pred[4]{0} and(not.218, not.223) + add.225 = f32[4]{0} add(add.214, add.179) + maximum.215 = f32[4]{0} maximum(add.214, add.179) + abs.226 = f32[4]{0} abs(subtract.216) + negate.227 = f32[4]{0} negate(abs.226) + exponential.228 = f32[4]{0} exponential(negate.227) + log-plus-one.229 = f32[4]{0} log-plus-one(exponential.228) + add.230 = f32[4]{0} add(maximum.215, log-plus-one.229) + select.231 = f32[4]{0} select(and.224, add.225, add.230) + slice.359 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]} + reshape.1068 = f32[4]{0} reshape(slice.359) + add.304 = 
f32[4]{0} add(select.231, reshape.1068) + slice.381 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]} + reshape.1070 = f32[4]{0} reshape(slice.381) + add.269 = f32[4]{0} add(select.231, reshape.1070) + subtract.306 = f32[4]{0} subtract(add.304, add.269) + is-finite.307 = pred[4]{0} is-finite(subtract.306) + not.308 = pred[4]{0} not(is-finite.307) + abs.309 = f32[4]{0} abs(subtract.306) + constant.310 = f32[] constant(inf) + broadcast.311 = f32[4]{0} broadcast(constant.310), dimensions={} + compare.312 = pred[4]{0} compare(abs.309, broadcast.311), direction=EQ, type=UNSIGNED + not.313 = pred[4]{0} not(compare.312) + and.314 = pred[4]{0} and(not.308, not.313) + add.315 = f32[4]{0} add(add.304, add.269) + maximum.305 = f32[4]{0} maximum(add.304, add.269) + abs.316 = f32[4]{0} abs(subtract.306) + negate.317 = f32[4]{0} negate(abs.316) + exponential.318 = f32[4]{0} exponential(negate.317) + log-plus-one.319 = f32[4]{0} log-plus-one(exponential.318) + add.320 = f32[4]{0} add(maximum.305, log-plus-one.319) + select.321 = f32[4]{0} select(and.314, add.315, add.320) + slice.362 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]} + reshape.1072 = f32[4]{0} reshape(slice.362) + add.394 = f32[4]{0} add(select.321, reshape.1072) + slice.382 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]} + reshape.1074 = f32[4]{0} reshape(slice.382) + add.359 = f32[4]{0} add(select.321, reshape.1074) + subtract.396 = f32[4]{0} subtract(add.394, add.359) + is-finite.397 = pred[4]{0} is-finite(subtract.396) + not.398 = pred[4]{0} not(is-finite.397) + abs.399 = f32[4]{0} abs(subtract.396) + constant.400 = f32[] constant(inf) + broadcast.401 = f32[4]{0} broadcast(constant.400), dimensions={} + compare.402 = pred[4]{0} compare(abs.399, broadcast.401), direction=EQ, type=UNSIGNED + not.403 = pred[4]{0} not(compare.402) + and.404 = pred[4]{0} and(not.398, not.403) + add.405 = f32[4]{0} add(add.394, add.359) + maximum.395 = f32[4]{0} maximum(add.394, add.359) + abs.406 = f32[4]{0} abs(subtract.396) + negate.407 = f32[4]{0} negate(abs.406) + exponential.408 = f32[4]{0} exponential(negate.407) + log-plus-one.409 = f32[4]{0} log-plus-one(exponential.408) + add.410 = f32[4]{0} add(maximum.395, log-plus-one.409) + select.411 = f32[4]{0} select(and.404, add.405, add.410) + slice.365 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]} + reshape.1076 = f32[4]{0} reshape(slice.365) + add.484 = f32[4]{0} add(select.411, reshape.1076) + slice.383 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]} + reshape.1078 = f32[4]{0} reshape(slice.383) + add.449 = f32[4]{0} add(select.411, reshape.1078) + subtract.486 = f32[4]{0} subtract(add.484, add.449) + is-finite.487 = pred[4]{0} is-finite(subtract.486) + not.488 = pred[4]{0} not(is-finite.487) + abs.489 = f32[4]{0} abs(subtract.486) + constant.490 = f32[] constant(inf) + broadcast.491 = f32[4]{0} broadcast(constant.490), dimensions={} + compare.492 = pred[4]{0} compare(abs.489, broadcast.491), direction=EQ, type=UNSIGNED + not.493 = pred[4]{0} not(compare.492) + and.494 = pred[4]{0} and(not.488, not.493) + add.495 = f32[4]{0} add(add.484, add.449) + maximum.485 = f32[4]{0} maximum(add.484, add.449) + abs.496 = f32[4]{0} abs(subtract.486) + negate.497 = f32[4]{0} negate(abs.496) + exponential.498 = f32[4]{0} exponential(negate.497) + log-plus-one.499 = f32[4]{0} log-plus-one(exponential.498) + add.500 = f32[4]{0} add(maximum.485, log-plus-one.499) + select.501 = f32[4]{0} select(and.494, 
add.495, add.500) + slice.368 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]} + reshape.1080 = f32[4]{0} reshape(slice.368) + add.574 = f32[4]{0} add(select.501, reshape.1080) + slice.384 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]} + reshape.1082 = f32[4]{0} reshape(slice.384) + add.539 = f32[4]{0} add(select.501, reshape.1082) + subtract.576 = f32[4]{0} subtract(add.574, add.539) + is-finite.577 = pred[4]{0} is-finite(subtract.576) + not.578 = pred[4]{0} not(is-finite.577) + abs.579 = f32[4]{0} abs(subtract.576) + constant.580 = f32[] constant(inf) + broadcast.581 = f32[4]{0} broadcast(constant.580), dimensions={} + compare.582 = pred[4]{0} compare(abs.579, broadcast.581), direction=EQ, type=UNSIGNED + not.583 = pred[4]{0} not(compare.582) + and.584 = pred[4]{0} and(not.578, not.583) + add.585 = f32[4]{0} add(add.574, add.539) + maximum.575 = f32[4]{0} maximum(add.574, add.539) + abs.586 = f32[4]{0} abs(subtract.576) + negate.587 = f32[4]{0} negate(abs.586) + exponential.588 = f32[4]{0} exponential(negate.587) + log-plus-one.589 = f32[4]{0} log-plus-one(exponential.588) + add.590 = f32[4]{0} add(maximum.575, log-plus-one.589) + select.591 = f32[4]{0} select(and.584, add.585, add.590) + slice.371 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]} + reshape.1084 = f32[4]{0} reshape(slice.371) + add.664 = f32[4]{0} add(select.591, reshape.1084) + slice.385 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]} + reshape.1086 = f32[4]{0} reshape(slice.385) + add.629 = f32[4]{0} add(select.591, reshape.1086) + subtract.666 = f32[4]{0} subtract(add.664, add.629) + is-finite.667 = pred[4]{0} is-finite(subtract.666) + not.668 = pred[4]{0} not(is-finite.667) + abs.669 = f32[4]{0} abs(subtract.666) + constant.670 = f32[] constant(inf) + broadcast.671 = f32[4]{0} broadcast(constant.670), dimensions={} + compare.672 = pred[4]{0} compare(abs.669, broadcast.671), direction=EQ, type=UNSIGNED + not.673 = pred[4]{0} not(compare.672) + and.674 = pred[4]{0} and(not.668, not.673) + add.675 = f32[4]{0} add(add.664, add.629) + maximum.665 = f32[4]{0} maximum(add.664, add.629) + abs.676 = f32[4]{0} abs(subtract.666) + negate.677 = f32[4]{0} negate(abs.676) + exponential.678 = f32[4]{0} exponential(negate.677) + log-plus-one.679 = f32[4]{0} log-plus-one(exponential.678) + add.680 = f32[4]{0} add(maximum.665, log-plus-one.679) + select.681 = f32[4]{0} select(and.674, add.675, add.680) + slice.374 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]} + reshape.1088 = f32[4]{0} reshape(slice.374) + add.754 = f32[4]{0} add(select.681, reshape.1088) + slice.386 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]} + reshape.1090 = f32[4]{0} reshape(slice.386) + add.719 = f32[4]{0} add(select.681, reshape.1090) + subtract.756 = f32[4]{0} subtract(add.754, add.719) + is-finite.757 = pred[4]{0} is-finite(subtract.756) + not.758 = pred[4]{0} not(is-finite.757) + abs.759 = f32[4]{0} abs(subtract.756) + constant.760 = f32[] constant(inf) + broadcast.761 = f32[4]{0} broadcast(constant.760), dimensions={} + compare.762 = pred[4]{0} compare(abs.759, broadcast.761), direction=EQ, type=UNSIGNED + not.763 = pred[4]{0} not(compare.762) + and.764 = pred[4]{0} and(not.758, not.763) + add.765 = f32[4]{0} add(add.754, add.719) + maximum.755 = f32[4]{0} maximum(add.754, add.719) + abs.766 = f32[4]{0} abs(subtract.756) + negate.767 = f32[4]{0} negate(abs.766) + exponential.768 = f32[4]{0} exponential(negate.767) + 
log-plus-one.769 = f32[4]{0} log-plus-one(exponential.768) + add.770 = f32[4]{0} add(maximum.755, log-plus-one.769) + select.771 = f32[4]{0} select(and.764, add.765, add.770) + slice.377 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]} + reshape.1092 = f32[4]{0} reshape(slice.377) + add.844 = f32[4]{0} add(select.771, reshape.1092) + slice.387 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]} + reshape.1094 = f32[4]{0} reshape(slice.387) + add.809 = f32[4]{0} add(select.771, reshape.1094) + subtract.846 = f32[4]{0} subtract(add.844, add.809) + is-finite.847 = pred[4]{0} is-finite(subtract.846) + not.848 = pred[4]{0} not(is-finite.847) + abs.849 = f32[4]{0} abs(subtract.846) + constant.850 = f32[] constant(inf) + broadcast.851 = f32[4]{0} broadcast(constant.850), dimensions={} + compare.852 = pred[4]{0} compare(abs.849, broadcast.851), direction=EQ, type=UNSIGNED + not.853 = pred[4]{0} not(compare.852) + and.854 = pred[4]{0} and(not.848, not.853) + add.855 = f32[4]{0} add(add.844, add.809) + maximum.845 = f32[4]{0} maximum(add.844, add.809) + abs.856 = f32[4]{0} abs(subtract.846) + negate.857 = f32[4]{0} negate(abs.856) + exponential.858 = f32[4]{0} exponential(negate.857) + log-plus-one.859 = f32[4]{0} log-plus-one(exponential.858) + add.860 = f32[4]{0} add(maximum.845, log-plus-one.859) + select.861 = f32[4]{0} select(and.854, add.855, add.860) + constant.865 = f32[] constant(0) + reduce.2 = f32[] reduce(select.861, constant.865), dimensions={0}, to_apply=primitive_computation_add.866 + reduce.3 = f32[] reduce(select.861, constant.865), dimensions={0}, to_apply=primitive_computation_add.866 + add.77 = f32[] add(reduce.2, reduce.3) + constant.719 = f32[] constant(0.125) + multiply = f32[] multiply(add.77, constant.719) + ROOT tuple.873 = (f32[]) tuple(multiply) +})") + .ConsumeValueOrDie(); + auto input_array = absl::make_unique>(4, 2); + input_array->FillUnique(1.0f); + auto input = LiteralUtil::CreateR2FromArray2D(*input_array); + EXPECT_TRUE(RunAndCompare(std::move(module), {&input}, absl::nullopt)); +} + // Describes a binary rank-2 concatenation test. struct R2BinarySpec { int64 lhs_dim0; @@ -578,7 +825,7 @@ XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) { {x_data.get(), y_data.get()}, ErrorSpec(1e-4)); } -// Test that the HLO optimization to replace a concat of a bradcasted scalar +// Test that the HLO optimization to replace a concat of a broadcasted scalar // produces the correct result in rank 1. XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) { auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {}); @@ -604,7 +851,7 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) { {x_data.get(), y_data.get(), z_data.get()}, ErrorSpec(1e-4)); } -// Test that the HLO optimization to replace a concat of a bradcasted scalar +// Test that the HLO optimization to replace a concat of a broadcasted scalar // produces the correct result in rank 3 with both high and low padding in // different dimensions. 
XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) { diff --git a/tensorflow/compiler/xla/tests/conv_depthwise_backprop_filter_test.cc b/tensorflow/compiler/xla/tests/conv_depthwise_backprop_filter_test.cc index ff7e7955876..4a7070a32f3 100644 --- a/tensorflow/compiler/xla/tests/conv_depthwise_backprop_filter_test.cc +++ b/tensorflow/compiler/xla/tests/conv_depthwise_backprop_filter_test.cc @@ -45,13 +45,20 @@ class BatchGroupedConvolution2DTest public ::testing::WithParamInterface< ::testing::tuple> {}; -static std::vector GetConv2DTestCases() { +class BatchGroupedConvolution2DDepthTest + : public HloTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> {}; + +static std::vector GetConv2DTestCases( + bool use_depth_multiplier) { std::vector config_set; std::vector> config_options = { - {8, 5, 3, 2}, {4, 5, 5, 2}, {8, 7, 4, 128}, - {16, 20, 20, 256}, {256, 7, 5, 4}, {256, 6, 6, 4}, - {256, 8, 8, 512}, {64, 7, 7, 960}, {64, 14, 14, 576}}; + {129, 10, 3, 2}, {4, 3, 3, 258}, {8, 4, 2, 128}, + {8, 3, 2, 256}, {256, 7, 5, 4}, {128, 6, 6, 4}, + {32, 5, 2, 129}, {16, 4, 3, 2}, {16, 3, 2, 64}}; + int64 counter = 2; for (auto option : config_options) { int64 feature = option[3]; int64 activation_size = option[1]; @@ -65,10 +72,16 @@ static std::vector GetConv2DTestCases() { config.activation_dims = {batch, activation_size, activation_size, feature}; - config.kernel_dims = {batch, kernel_size, kernel_size, feature}; - + const int64 depthwise_multiplier = use_depth_multiplier ? counter++ : 1; + config.kernel_dims = {batch, kernel_size, kernel_size, + feature * depthwise_multiplier}; + // Don't let the counter grow too much, else the compute demand will grow. + if (counter == 4) { + counter = 2; + } int64 output_space_size = 3 + activation_size - kernel_size; - config.output_dims = {output_space_size, output_space_size, feature, 1}; + config.output_dims = {output_space_size, output_space_size, + feature * depthwise_multiplier, 1}; config.activation_and_kernel_layout = {0, 3, 1, 2}; config.output_layout = {2, 3, 0, 1}; @@ -123,11 +136,13 @@ string BatchGroupedConvolution2DTestDataToString( } string BuildHloTextBatchGroupedConvolution2D( - const BatchGroupedConvolution2DSpec& spec, bool use_bfloat16) { + const BatchGroupedConvolution2DSpec& spec, bool use_bfloat16, + bool scheduled = false) { const string data_type = GetFloatDataType(use_bfloat16); + const string scheduled_tag = scheduled ? 
",is_scheduled=true" : ""; return absl::StrFormat( R"( - HloModule TensorFlowDepthwiseConv, is_scheduled=true + HloModule TensorFlowDepthwiseConv %s ENTRY main { activation = %s[%s]{%s} parameter(0) @@ -137,7 +152,7 @@ string BuildHloTextBatchGroupedConvolution2D( batch_group_count=%d } )", - data_type, absl::StrJoin(spec.activation_dims, ","), + scheduled_tag, data_type, absl::StrJoin(spec.activation_dims, ","), absl::StrJoin(spec.activation_and_kernel_layout, ","), data_type, absl::StrJoin(spec.kernel_dims, ","), absl::StrJoin(spec.activation_and_kernel_layout, ","), data_type, @@ -161,23 +176,26 @@ XLA_TEST_P(BatchGroupedConvolution2DTest, DoIt) { } #endif - const string hlo_text = - BuildHloTextBatchGroupedConvolution2D(spec, use_bfloat16); + const string hlo_text = BuildHloTextBatchGroupedConvolution2D( + spec, use_bfloat16, /*scheduled=*/false); - EXPECT_TRUE(RunAndCompareNoHloPasses( - hlo_text, ErrorSpec{0.01, 0.01}, [](HloModule* module) -> Status { - BFloat16MixedPrecisionRemoval remover; - TF_RETURN_IF_ERROR(remover.Run(module).status()); - Despecializer despecializer; - return despecializer.Run(module).status(); - })); + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0.01, 0.01})); } INSTANTIATE_TEST_CASE_P( BatchGroupedConvolution2DTestWithRandomIndices, BatchGroupedConvolution2DTest, - ::testing::Combine(::testing::ValuesIn(GetConv2DTestCases()), - ::testing::Bool()), + ::testing::Combine( + ::testing::ValuesIn(GetConv2DTestCases(/*use_depth_multiplier=*/false)), + ::testing::Bool()), + BatchGroupedConvolution2DTestDataToString); + +INSTANTIATE_TEST_CASE_P( + BatchGroupedConvolution2DDepthMultiplierTestWithRandomIndices, + BatchGroupedConvolution2DTest, + ::testing::Combine( + ::testing::ValuesIn(GetConv2DTestCases(/*use_depth_multiplier=*/true)), + ::testing::Bool()), BatchGroupedConvolution2DTestDataToString); } // namespace diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 60ba27b2050..e06e2972f1c 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -69,12 +69,14 @@ XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) { XlaBuilder builder(TestName()); XlaOp param; - auto param_data = CreateParameterAndTransferLiteral( - 0, - LiteralUtil::MakeTupleFromSlices( - {LiteralUtil::CreateR2({{1, 2}, {3, 4}}), - LiteralUtil::CreateR2({{5, 6}, {7, 8}})}), - "arg0", &builder, ¶m); + TF_ASSERT_OK_AND_ASSIGN( + auto param_data, + CreateParameterAndTransferLiteral( + 0, + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1, 2}, {3, 4}}), + LiteralUtil::CreateR2({{5, 6}, {7, 8}})}), + "arg0", &builder, ¶m)); auto lhs = GetTupleElement(param, 0); auto rhs = GetTupleElement(param, 1); Dot(lhs, rhs); diff --git a/tensorflow/compiler/xla/tests/exhaustive_binary_16_bit_test.cc b/tensorflow/compiler/xla/tests/exhaustive_binary_16_bit_test.cc index dca8e31e792..f884bb9c0e0 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_binary_16_bit_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_binary_16_bit_test.cc @@ -123,7 +123,7 @@ BINARY_TEST_16BIT(Min, { }) // TODO(bixia): Pow fails with bfloat16 on CPU. -BINARY_TEST_16BIT(DISABLED_ON_CPU(Pow), { +BINARY_TEST_16BIT(DISABLED_ON_GPU(DISABLED_ON_CPU(Pow)), { // See b/162664705. 
known_incorrect_fn_ = [](int64 val) { Eigen::bfloat16 f; diff --git a/tensorflow/compiler/xla/tests/exhaustive_binary_test_f32_f64.cc b/tensorflow/compiler/xla/tests/exhaustive_binary_test_f32_f64.cc index 14d3b343b6c..c6feedf9e7f 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_binary_test_f32_f64.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_binary_test_f32_f64.cc @@ -114,6 +114,10 @@ BINARY_TEST_FLOAT_32(Min, { // // TODO(bixia): Need to investigate the failure on CPU and file bugs. BINARY_TEST_FLOAT_32(DISABLED_ON_CPU(AbsComplex), { + // TODO(timshen): see b/162664705. + known_incorrect_fn_ = [this](int64 val) { + return std::isnan(this->ConvertValue(val)); + }; auto host_abs_complex = [](float x, float y) { return std::abs(std::complex(x, y)); }; @@ -198,6 +202,10 @@ BINARY_TEST_FLOAT_64(Min, { // TODO(bixia): Need to investigate the failure on CPU and file bugs. BINARY_TEST_FLOAT_64(DISABLED_ON_CPU(AbsComplex), { + // TODO(timshen): see b/162664705. + known_incorrect_fn_ = [this](int64 val) { + return std::isnan(this->ConvertValue(val)); + }; auto host_abs_complex = [](double x, double y) { return std::abs(std::complex(x, y)); }; diff --git a/tensorflow/compiler/xla/tests/exhaustive_unary_test_complex.cc b/tensorflow/compiler/xla/tests/exhaustive_unary_test_complex.cc index b361bf94a6d..6a638d2106f 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_unary_test_complex.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_unary_test_complex.cc @@ -97,6 +97,10 @@ using ExhaustiveC128UnaryTest = ExhaustiveComplexUnaryTestBase; // TODO(b/138578594): Enable the test for the CPU backend after fixing the bug. UNARY_TEST_COMPLEX_64(DISABLED_ON_CPU(Log), { + // TODO(timshen): see b/162664705. + known_incorrect_fn_ = [this](int64 val) { + return std::isnan(this->ConvertValue(val)); + }; Run(Log, [](complex64 x) { return std::log(x); }); }) diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 663e7d81006..6c062deb363 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -414,6 +414,47 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( : ::testing::AssertionFailure() << output.status().error_message(); } +::testing::AssertionResult HloTestBase::RunReplicated(string_view hlo_string, + bool run_hlo_passes, + int64 num_replicas, + string backend_config) { + auto module_or_status = + ParseAndReturnVerifiedModule(hlo_string, num_replicas); + if (!module_or_status.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module_or_status.status().ToString(); + } + + std::unique_ptr module = std::move(module_or_status.ValueOrDie()); + const auto& fake_arguments = + MakeFakeArguments(module.get()).ConsumeValueOrDie(); + std::vector fake_argument_ptrs; + absl::c_transform( + fake_arguments, std::back_inserter(fake_argument_ptrs), + [](const Literal& literal) { return const_cast(&literal); }); + + if (!backend_config.empty()) { + // Set backend configuration if it is given. 
+ HloInstruction* instruction = + module->entry_computation()->root_instruction(); + instruction->set_raw_backend_config_string(backend_config); + } + + HloRunner::ReplicatedExecuteOptions options; + options.num_replicas = num_replicas; + options.run_hlo_passes = run_hlo_passes; + options.use_threads = true; + for (auto argument : fake_argument_ptrs) { + options.arguments.push_back(argument); + } + auto output = test_runner_.ExecuteReplicated(std::move(module), options); + + return output.ok() + ? ::testing::AssertionSuccess() + : ::testing::AssertionFailure() << output.status().error_message(); +} + ::testing::AssertionResult HloTestBase::RunMultipleTimes( string_view hlo_string, bool run_hlo_passes, std::vector* profiles, string backend_config, diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index fc680e39682..e15c1dd5f55 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -234,6 +234,11 @@ class HloTestBase : public ManifestCheckingTest { ExecutionProfile* profile = nullptr, string backend_config = "") TF_MUST_USE_RESULT; + // Executes an hlo module with fake inputs on multiple replicas. + ::testing::AssertionResult RunReplicated( + const absl::string_view hlo_string, bool run_hlo_passes = true, + int64 num_replicas = 1, string backend_config = "") TF_MUST_USE_RESULT; + // If assert_determinism is true, the assertion will fail unless all runs // produce exactly the same output. ::testing::AssertionResult RunMultipleTimes( diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 201c0da87f1..1a95f2fb549 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -122,7 +122,7 @@ XLA_TEST_F(PrngTest, DISABLED_ON_INTERPRETER(DISABLED_ON_GPU( bfloat16 interval = static_cast(0.25); std::vector counts(static_cast((high - low) / interval), 0); - constexpr int64 count = 100; + constexpr int64 count = 1000; for (int64 seed = 0; seed < count; ++seed) { auto result = UniformTest(low, high, {}, /*seed=*/seed); result.EachCell([&](absl::Span, bfloat16 value) { diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index b209669715e..7e5b699d5e2 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -365,8 +365,9 @@ XLA_TEST_P(ReduceWindowTest, R4UnitWindow) { Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({0, 3, 2, 1})); XlaOp input; - auto input_data = CreateParameterAndTransferLiteral( - 0, input_literal, "parameter", &builder_, &input); + TF_ASSERT_OK_AND_ASSIGN( + auto input_data, CreateParameterAndTransferLiteral( + 0, input_literal, "parameter", &builder_, &input)); Padding padding = Padding::kSame; ReduceWindowAdd(input, {1, 1, 7, 1}, {1, 4, 1, 1}, padding); @@ -423,8 +424,9 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) { Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp input; - auto input_data = CreateParameterAndTransferLiteral( - 0, input_literal, "parameter", &builder_, &input); + TF_ASSERT_OK_AND_ASSIGN( + auto input_data, CreateParameterAndTransferLiteral( + 0, input_literal, "parameter", &builder_, &input)); int win_len = 1; int stride = 8; @@ -444,8 +446,9 @@ XLA_TEST_P(ReduceWindowTest, 
R4SecondMinorUnitStride) { Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp input; - auto input_data = CreateParameterAndTransferLiteral( - 0, input_literal, "parameter", &builder_, &input); + TF_ASSERT_OK_AND_ASSIGN( + auto input_data, CreateParameterAndTransferLiteral( + 0, input_literal, "parameter", &builder_, &input)); int win_len = 3; int stride = 1; @@ -465,8 +468,9 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) { Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp input; - auto input_data = CreateParameterAndTransferLiteral( - 0, input_literal, "parameter", &builder_, &input); + TF_ASSERT_OK_AND_ASSIGN( + auto input_data, CreateParameterAndTransferLiteral( + 0, input_literal, "parameter", &builder_, &input)); int win_len = 8; int stride = 5; @@ -631,8 +635,9 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); XlaOp parameter; - auto input_arg = CreateParameterAndTransferLiteral(0, input_literal, "p0", - &b, ¶meter); + TF_ASSERT_OK_AND_ASSIGN(auto input_arg, + CreateParameterAndTransferLiteral( + 0, input_literal, "p0", &b, ¶meter)); std::vector> padding(4); for (int i = 0; i < 4; ++i) { @@ -1243,7 +1248,9 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, input, LayoutUtil::MakeLayout(param.layout)); XlaOp parameter; - CreateParameterAndTransferLiteral(0, input_literal, "p0", &b, ¶meter); + TF_ASSERT_OK(CreateParameterAndTransferLiteral(0, input_literal, "p0", &b, + ¶meter) + .status()); std::vector> padding(2); for (int i = 0; i < 2; ++i) { @@ -1443,8 +1450,9 @@ XLA_TEST_P(R1ReduceWindowTest, DoIt) { Literal input_literal = LiteralUtil::CreateR1(absl::Span(input_vector)); XlaOp parameter; - auto input_arg = - CreateParameterAndTransferLiteral(0, input_literal, "p0", &b, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input_arg, CreateParameterAndTransferLiteral(0, input_literal, "p0", + &b, ¶meter)); std::vector> padding(1); padding[0] = {param.pad_low[0], param.pad_high[0]}; diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index 298136002e9..890156cc650 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -57,8 +57,9 @@ XLA_TEST_P(ReshapeTest, CollapseTrivial1x1) { input_array.Fill(1.0f); auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral( + 0, input_literal, "parameter", &builder, ¶meter)); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({1.0f}); @@ -70,8 +71,9 @@ XLA_TEST_P(ReshapeTest, CollapseTrivialR1EmptyDims) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral( + 0, input_literal, "parameter", &builder, ¶meter)); Collapse(/*operand=*/parameter, /*dimensions=*/{}); auto expected_literal = LiteralUtil::CreateR1({1.0f}); @@ -83,8 +85,9 @@ XLA_TEST_P(ReshapeTest, CollapseTrivialR1OnlyDim) { XlaBuilder 
builder(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral( + 0, input_literal, "parameter", &builder, ¶meter)); Collapse(/*operand=*/parameter, /*dimensions=*/{0}); auto expected_literal = LiteralUtil::CreateR1({1.0f}); @@ -99,8 +102,9 @@ XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) { input_array.Fill(1.0f); auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral( + 0, input_literal, "parameter", &builder, ¶meter)); auto reshape = Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{}); auto new_shape = builder.GetShape(reshape).ConsumeValueOrDie(); @@ -115,8 +119,9 @@ XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) { Literal param0_literal = LiteralUtil::CreateR0(1.0f); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, param0_literal, "param0", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, param0_literal, "param0", + &builder, ¶meter)); auto a = Neg(parameter); Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1}); @@ -130,8 +135,9 @@ XLA_TEST_P(ReshapeTest, Trivial0x3) { Array2D input_array(0, 3); auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({}); ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, @@ -144,8 +150,9 @@ XLA_TEST_P(ReshapeTest, Trivial0x3WithParameter) { Literal param0_literal = LiteralUtil::CreateR2FromArray2D(Array2D(0, 3)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, param0_literal, "param0", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, param0_literal, "param0", + &builder, ¶meter)); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({}); ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, @@ -157,8 +164,9 @@ XLA_TEST_P(ReshapeTest, Trivial3x0) { Array2D input_array(3, 0); auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({}); ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, @@ -170,8 +178,9 @@ XLA_TEST_P(ReshapeTest, Trivial1x3) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR2({{1.0f, 2.0f, 3.0f}}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); 
Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f}); ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, @@ -183,8 +192,9 @@ XLA_TEST_P(ReshapeTest, Trivial3x1) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR2({{1.0f}, {2.0f}, {3.0f}}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f}); ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, @@ -196,8 +206,9 @@ XLA_TEST_P(ReshapeTest, R1ToR2_0_To_2x0) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR1({}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(/*operand=*/parameter, /*dimensions=*/{0}, /*new_sizes=*/{2, 0}); auto expected_literal = LiteralUtil::CreateR2({{}, {}}); @@ -211,8 +222,9 @@ XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) { auto input_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(/*operand=*/parameter, /*dimensions=*/{0}, /*new_sizes=*/{2, 3}); auto expected_literal = @@ -226,8 +238,9 @@ XLA_TEST_P(ReshapeTest, Reshape0x2To2x0) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array2D(0, 2)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{2, 0}); auto expected_literal = LiteralUtil::CreateR2({{}, {}}); @@ -241,8 +254,9 @@ XLA_TEST_P(ReshapeTest, ReshapeRowToCol) { auto simple = MakeLinspaceArray2D(1.0f, 3.0f, 1, 3); auto input_literal = LiteralUtil::CreateFromArray(*simple); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 1}); @@ -258,8 +272,9 @@ XLA_TEST_P(ReshapeTest, TransposeAsReshape) { auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 4}); @@ -274,8 +289,9 @@ XLA_TEST_P(ReshapeTest, Transpose0x4) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array2D(0, 4)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, 
¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Transpose(parameter, {1, 0}); auto expected_literal = LiteralUtil::CreateR2({{}, {}, {}, {}}); ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, @@ -288,8 +304,9 @@ XLA_TEST_P(ReshapeTest, Transpose4x3) { auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Transpose(parameter, {1, 0}); auto expected = ReferenceUtil::TransposeArray2D(*a4x3); @@ -304,8 +321,9 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffleZeroElements) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array2D(6, 0)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{2, 3, 0, 0}); auto expected_literal = @@ -318,8 +336,9 @@ XLA_TEST_P(ReshapeTest, ReshapeR4ToR2ZeroElements) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array4D(2, 3, 4, 0)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{24, 0}); auto expected_literal = LiteralUtil::CreateFromArray(Array2D(24, 0)); @@ -334,8 +353,9 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) { auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{2, 6}); @@ -349,8 +369,9 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffleZeroElements) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array2D(0, 6)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 0}); auto expected_literal = LiteralUtil::CreateFromArray(Array2D(3, 0)); @@ -365,8 +386,9 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffle) { auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{2, 6}); Array2D expected({{1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f}, @@ -391,8 +413,9 @@ XLA_TEST_P(ReshapeTest, 
DocR3_R1_Collapse_012) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, /*new_sizes=*/{24}); auto expected_literal = LiteralUtil::CreateR1( @@ -406,8 +429,9 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, /*new_sizes=*/{8, 3}); auto expected_literal = LiteralUtil::CreateR2({{10, 11, 12}, @@ -426,8 +450,9 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, /*new_sizes=*/{24}); auto expected_literal = LiteralUtil::CreateR1( @@ -441,8 +466,9 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, /*new_sizes=*/{8, 3}); auto expected_literal = LiteralUtil::CreateR2({{10, 20, 30}, @@ -461,8 +487,9 @@ XLA_TEST_P(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, /*new_sizes=*/{2, 6, 2}); auto expected_literal = LiteralUtil::CreateR3( @@ -494,8 +521,9 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapse) { t2x2x2x3.FillWithYX(*filler2x3); auto input_literal = LiteralUtil::CreateFromArray(t2x2x2x3); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Collapse(/*operand=*/parameter, /*dimensions=*/{1, 2, 3}); auto expected_literal = LiteralUtil::CreateR2( {{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, @@ -519,8 +547,9 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) { t(1, 0, 1, 1) = 7; auto input_literal = LiteralUtil::CreateFromArray(t); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, 
¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 4}); @@ -542,8 +571,9 @@ XLA_TEST_P(ReshapeTest, ToScalar) { input_literal.Set(zeros, 83.0f); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &b, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &b, ¶meter)); Reshape(parameter, dimensions, {}); auto expected_literal = LiteralUtil::CreateR0(83.0f); @@ -556,8 +586,9 @@ XLA_TEST_P(ReshapeTest, BadDimensions) { XlaBuilder b(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &b, - ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &b, ¶meter)); Reshape(parameter, {}, {}); EXPECT_THAT( ExecuteToString(&b, {}), @@ -568,8 +599,9 @@ XLA_TEST_P(ReshapeTest, BadNewSizes) { XlaBuilder b(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0f, 2.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &b, - ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &b, ¶meter)); Reshape(parameter, {1}, {}); EXPECT_THAT(ExecuteToString(&b, {}), ::testing::HasSubstr("mismatched element counts")); @@ -604,8 +636,9 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { LayoutUtil::MakeLayout({0, 1, 2, 3})); // clang-format on XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8}); @@ -639,8 +672,9 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { {200, 201, 202, 203, 204, 205, 206, 207}, }); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4}); // clang-format off @@ -666,8 +700,9 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { {200, 201, 202, 203, 204, 205, 206, 207}, }); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4}); // clang-format off @@ -694,8 +729,9 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN(auto input_data, + CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter)); Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1}); Literal expected = LiteralUtil::ReshapeSlice({2, 1}, {1, 0}, input_literal); @@ -713,8 +749,9 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { Literal 
input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN(auto input_data, + CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter)); Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2}); Literal expected = LiteralUtil::ReshapeSlice({4, 2}, {1, 0}, input_literal); @@ -733,8 +770,9 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN(auto input_data, + CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter)); Reshape(parameter, /*dimensions=*/{0, 2, 1, 3}, /*new_sizes=*/{5, 60}); @@ -759,8 +797,9 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({1, 2, 3, 0})); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN(auto input_data, + CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter)); Reshape(parameter, /*dimensions=*/{3, 0, 1, 2}, /*new_sizes=*/{7, 2, 3, 5}); XlaComputation computation = builder.Build().ConsumeValueOrDie(); @@ -793,8 +832,9 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape_Trivial) { {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, literal_1x2x3x4, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, literal_1x2x3x4, "input", + &builder, ¶meter)); Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{1, 2, 3, 4}); @@ -808,8 +848,9 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape) { XlaBuilder builder(TestName()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, literal_1x2x3x4, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, literal_1x2x3x4, "input", + &builder, ¶meter)); Reshape(parameter, /*dimensions=*/{1, 3, 2, 0}, /*new_sizes=*/{2, 4, 3, 1}); @@ -840,8 +881,9 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN(auto input_data, + CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter)); Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); @@ -867,8 +909,9 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN(auto input_data, + CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter)); Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); @@ -894,8 +937,9 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { input, LayoutUtil::MakeLayout({3, 2, 1, 0})); 
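// Note on the layout arguments used throughout these R4 tests:
// LayoutUtil::MakeLayout takes the dimension order from minor to major, so
// {3, 2, 1, 0} is the default row-major layout for a rank-4 array, while
// {0, 1, 2, 3} (used in R4TwoMinorTransposeTrivialR2 below) is fully
// column-major and {1, 2, 3, 0} (used in NoopReshape above) is a mixed layout.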
XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN(auto input_data, + CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter)); Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); @@ -922,8 +966,9 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN(auto input_data, + CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter)); Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); @@ -949,8 +994,9 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { input, LayoutUtil::MakeLayout({0, 1, 2, 3})); XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN(auto input_data, + CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter)); Reshape(parameter, /*dimensions=*/{1, 0, 2, 3}, /*new_sizes=*/new_bounds); diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index fc1ca7d3105..aa02deb7bca 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -1,5 +1,7 @@ # Tools and utilities that aid in XLA development and usage. +load("//tensorflow:tensorflow.bzl", "filegroup") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", "if_cuda_or_rocm", @@ -264,7 +266,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/core:lib", - "//tensorflow/core:regexp_internal", + "//tensorflow/core/platform:regexp", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf_headers", ], diff --git a/tensorflow/compiler/jit/union_find.h b/tensorflow/compiler/xla/union_find.h similarity index 100% rename from tensorflow/compiler/jit/union_find.h rename to tensorflow/compiler/xla/union_find.h diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index 4034e5fdd27..6e7deda13f0 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -374,7 +374,9 @@ std::pair SplitF64ToF32(double x) { // Only values within the range of F32 are supported, unless it is infinity. // Small values with large negative exponents would be rounded to zero. - CHECK(std::isfinite(x_f32)) << x; + if (!std::isfinite(x_f32)) { + LOG(WARNING) << "Out of range F64 constant detected: " << x; + } // The high float is simply the double rounded to the nearest float. Because // we are rounding to nearest with ties to even, the error introduced in diff --git a/tensorflow/compiler/xla/util_test.cc b/tensorflow/compiler/xla/util_test.cc index 69acc59d8a2..5477dfba18d 100644 --- a/tensorflow/compiler/xla/util_test.cc +++ b/tensorflow/compiler/xla/util_test.cc @@ -15,10 +15,12 @@ limitations under the License. 
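The util.cc hunk above softens SplitF64ToF32's finiteness check from a fatal CHECK to a warning, so an F64 constant that falls outside the representable F32 range now logs instead of crashing. For orientation, here is a rough standalone sketch of the kind of high/low split the function performs (an illustration only, not the actual implementation, which lives in util.cc and now takes the warning path shown above):

    #include <cmath>
    #include <utility>

    // Splits a double into a (high, low) float pair with hi + lo ~= x.
    std::pair<float, float> SplitF64ToF32Sketch(double x) {
      // hi is x rounded to the nearest float (ties to even).
      const float hi = static_cast<float>(x);
      // If x overflows the float range, hi becomes +/-inf; keep lo at zero in
      // that case so the pair stays well defined.
      const float lo = std::isfinite(hi) ? static_cast<float>(x - hi) : 0.0f;
      return {hi, lo};
    }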
#include "tensorflow/compiler/xla/util.h" +#include #include #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/bfloat16.h" namespace xla { namespace { @@ -103,5 +105,26 @@ TEST(UtilTest, SanitizeFileName) { EXPECT_EQ(SanitizeFileName("/A\\B[C]"), "_A_B_C_"); } +TEST(UtilTest, RoundTripFpToString) { + EXPECT_EQ(RoundTripFpToString(std::numeric_limits::quiet_NaN()), + "nan"); + EXPECT_EQ(RoundTripFpToString(-std::numeric_limits::quiet_NaN()), + "-nan"); + EXPECT_EQ(RoundTripFpToString( + std::numeric_limits::quiet_NaN()), + "nan"); + EXPECT_EQ(RoundTripFpToString( + -std::numeric_limits::quiet_NaN()), + "-nan"); + EXPECT_EQ(RoundTripFpToString(std::numeric_limits::quiet_NaN()), + "nan"); + EXPECT_EQ(RoundTripFpToString(-std::numeric_limits::quiet_NaN()), + "-nan"); + EXPECT_EQ(RoundTripFpToString(std::numeric_limits::quiet_NaN()), + "nan"); + EXPECT_EQ(RoundTripFpToString(-std::numeric_limits::quiet_NaN()), + "-nan"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index d334f879c3e..7da8d2cb84d 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -245,6 +245,13 @@ message ComputationStats { double transcendental_count = 2; } +// The type optimization profiles in use. +enum ProfileType { + INVALID = 0; + WINDOW = 1; + FLAG = 2; +} + // Symbolization metadata for HLO Instructions. // // This metadata is used for debugging XLA code generation, as well as @@ -268,6 +275,8 @@ message OpMetadata { // e.g. it could be the file and line of user code that generated the op. string source_file = 3; int32 source_line = 4; + + repeated ProfileType profile_type = 5; } // Profile data from the execution of a computation. @@ -691,3 +700,11 @@ message WhileLoopBackendConfig { // unknown-trip-count. KnownTripCount known_trip_count = 1; } + +// Specifies a pair of output/operand buffers for kCustomCall that alias each +// other. 
+message CustomCallOutputOperandAliasing { + repeated int64 output_shape_index = 1; + int64 operand_index = 2; + repeated int64 operand_shape_index = 3; +} diff --git a/tensorflow/compiler/xrt/BUILD b/tensorflow/compiler/xrt/BUILD index 172a970d207..1b699e7d8df 100644 --- a/tensorflow/compiler/xrt/BUILD +++ b/tensorflow/compiler/xrt/BUILD @@ -1,15 +1,15 @@ # Description: Operations defined for XRT +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", "tf_custom_op_py_library", - "tf_gen_op_libs", "tf_gen_op_wrapper_py", ) +load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") load( "//tensorflow/core/platform:build_config.bzl", - "tf_proto_library_cc", - "tf_proto_library_py", + "tf_proto_library", ) package( @@ -20,7 +20,7 @@ package( licenses = ["notice"], # Apache 2.0 ) -tf_proto_library_cc( +tf_proto_library( name = "xrt_proto", srcs = ["xrt.proto"], cc_api_version = 2, @@ -33,12 +33,6 @@ tf_proto_library_cc( visibility = ["//visibility:public"], ) -tf_proto_library_py( - name = "xrt_proto", # bzl adds a _py suffix - srcs = ["xrt.proto"], - visibility = ["//visibility:public"], -) - cc_library( name = "xrt_utils", srcs = [ @@ -80,7 +74,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:regexp_internal", + "//tensorflow/core/platform:regexp", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/stream_executor", "//tensorflow/stream_executor:device_memory_allocator", diff --git a/tensorflow/compiler/xrt/cc/BUILD b/tensorflow/compiler/xrt/cc/BUILD index 99ab50c8a8d..c8932150cb5 100644 --- a/tensorflow/compiler/xrt/cc/BUILD +++ b/tensorflow/compiler/xrt/cc/BUILD @@ -1,7 +1,4 @@ -load( - "//tensorflow:tensorflow.bzl", - "tf_gen_op_wrappers_cc", -) +load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrappers_cc") package( default_visibility = ["//visibility:public"], diff --git a/tensorflow/compiler/xrt/kernels/BUILD b/tensorflow/compiler/xrt/kernels/BUILD index 494ba29e981..68c24f88703 100644 --- a/tensorflow/compiler/xrt/kernels/BUILD +++ b/tensorflow/compiler/xrt/kernels/BUILD @@ -1,3 +1,5 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + package( default_visibility = [ "//learning/brain:__subpackages__", diff --git a/tensorflow/compiler/xrt/tests/BUILD b/tensorflow/compiler/xrt/tests/BUILD index 2f1faf1cdf1..724cfe38d54 100644 --- a/tensorflow/compiler/xrt/tests/BUILD +++ b/tensorflow/compiler/xrt/tests/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_cuda_cc_test") load( "//tensorflow/core/platform:build_config_root.bzl", diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 935d8840831..adc59c67dce 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -46,12 +46,9 @@ # # Public mobile targets, e.g. for Android: # -# filegroup ":android_proto_srcs" - Protos -# filegroup ":android_srcs" - Core sources # cc_library ":portable_tensorflow_lib" - Native library # cc_library ":portable_tensorflow_lib_lite" - Native library, without ops, # supporting SELECTIVE_REGISTRATION feature. 
-# portable_proto_library ":portable_proto_lib" (Google-internal) # # Note that :framework and :lib have incomplete transitive dependencies (they # declare but do not define some symbols) if framework_shared_object=True @@ -65,15 +62,13 @@ load( "//tensorflow:tensorflow.bzl", - "cc_header_only_library", "if_android", "if_chromiumos", "if_cuda_or_rocm", "if_ios", + "if_libtpu", "if_mobile", "if_not_windows", - "if_tpu", - "tf_android_core_proto_headers", "tf_cc_test", "tf_cc_test_mkl", "tf_cc_tests", @@ -81,28 +76,28 @@ load( "tf_cuda_library", "tf_defines_nortti_if_lite_protos", "tf_features_nomodules_if_mobile", - "tf_gen_op_libs", - "tf_genrule_cmd_append_to_srcs", "tf_opts_nortti_if_lite_protos", - "tf_portable_full_lite_protos", "transitive_hdrs", ) +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "cc_header_only_library") + # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "if_nccl") # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tensorflow_opensource_extra_deps") -# buildifier: disable=same-origin-load -load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu") - # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_cc_tests_gpu") # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_monitoring_framework_deps") +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "filegroup") + # For platform specific build config load( "//tensorflow/core/platform:build_config.bzl", @@ -117,6 +112,7 @@ load( "tf_protos_all_impl", "tf_protos_grappler_impl", "tf_protos_profiler_impl", + "tf_tpu_dependencies", ) load( "//tensorflow/core/platform:rules_cc.bzl", @@ -135,7 +131,6 @@ load( "if_mkl", "mkl_deps", ) -# Placeholder for Google-internal load statements. package( default_visibility = [ @@ -159,14 +154,8 @@ package_group( # Export the BUILD file so automated tooling can check licenses exports_files([ "BUILD", - "ops/ops.pbtxt", ]) -package_group( - name = "experimental_access", - packages = ["//tensorflow/core/common_runtime/..."], -) - # Authorized users go here. package_group(name = "friends") @@ -177,7 +166,6 @@ package_group(name = "friends") # # Note that some protos are in neither additional_core_proto_srcs nor this # filegroup; e.g. ones with individual proto_library targets. 
-# LINT.IfChange COMMON_PROTO_SRCS = [ "//tensorflow/core/protobuf:bfc_memory_map.proto", "//tensorflow/core/protobuf:config.proto", @@ -240,7 +228,6 @@ ERROR_CODES_PROTO_SRCS = [ "//tensorflow/core/protobuf:error_codes.proto", "//tensorflow/core/lib/core:error_codes.proto", ] -# LINT.ThenChange(//tensorflow/core/portable_proto_config.asciipb) CORE_PROTO_SRCS = COMMON_PROTO_SRCS + EXAMPLE_PROTO_SRCS + FRAMEWORK_PROTO_SRCS + UTIL_PROTO_SRCS + PROFILER_PROTO_SRCS + ERROR_CODES_PROTO_SRCS @@ -248,6 +235,7 @@ tf_proto_library( name = "protos_all", srcs = [], cc_api_version = 2, + create_go_proto = False, make_default_target_header_only = True, protodeps = [ "//tensorflow/core/example:protos_all", @@ -286,7 +274,6 @@ cc_library( hdrs = ["//tensorflow/core/platform:base_hdrs"], copts = tf_copts(), tags = ["avoid_dep"], - visibility = [":__subpackages__"], deps = [ "//tensorflow/core/platform", "//tensorflow/core/platform:byte_order", @@ -302,17 +289,6 @@ cc_library( ], ) -alias( - name = "framework_bounds_check", - actual = "//tensorflow/core/framework:bounds_check", - visibility = ["//tensorflow/core/kernels:friends"], -) - -alias( - name = "human_readable_json", - actual = "//tensorflow/core/platform:human_readable_json", -) - # Minimal lib so that tools used for mobile compilation # don't have to depend on lib/platformlib. cc_library( @@ -374,22 +350,6 @@ cc_library( ], ) -# APIs defined in lib_experimental are for experimental usage and may be -# subject to change. Its visibility is limited to selected packages. -cc_library( - name = "lib_experimental", - hdrs = [ - "//tensorflow/core/lib/core:legacy_lib_core_threadpool_options_header", - ], - visibility = [ - ":experimental_access", - "//tensorflow/cc:__pkg__", - ], - deps = [ - ":lib", - ], -) - alias( name = "feature_util", actual = "//tensorflow/core/example:feature_util", @@ -458,7 +418,9 @@ tf_cuda_library( "//tensorflow/core/framework:control_flow.h", # TODO(josh11b): Make internal? "//tensorflow/core/framework:dataset.h", "//tensorflow/core/framework:dataset_stateful_op_allowlist.h", + "//tensorflow/core/framework:device.h", "//tensorflow/core/framework:device_base.h", + "//tensorflow/core/framework:device_factory.h", "//tensorflow/core/framework:function.h", "//tensorflow/core/framework:function_handle_cache.h", "//tensorflow/core/framework:graph_def_util.h", @@ -488,6 +450,7 @@ tf_cuda_library( "//tensorflow/core/framework:register_types_traits.h", "//tensorflow/core/framework:resource_mgr.h", "//tensorflow/core/framework:resource_op_kernel.h", + "//tensorflow/core/framework:rng_alg.h", "//tensorflow/core/framework:selective_registration.h", "//tensorflow/core/framework:session_state.h", "//tensorflow/core/framework:shape_inference.h", @@ -522,30 +485,6 @@ tf_cuda_library( ], ) -# TODO(gonnet): Remove this alias once all users have been moved to the actual target. -alias( - name = "allocator", - actual = "//tensorflow/core/framework:allocator", - visibility = ["//visibility:public"], -) - -# TODO(gonnet): Remove this alias once all users have been moved to the actual target. 
-alias( - name = "allocator_registry_impl", - actual = "//tensorflow/core/framework:allocator_registry_impl", - visibility = ["//visibility:public"], -) - -alias( - name = "overflow", - actual = "//tensorflow/core/util:overflow", -) - -alias( - name = "exec_on_stall", - actual = "//tensorflow/core/util:exec_on_stall", -) - alias( name = "ptr_util", actual = "//tensorflow/core/util:ptr_util", @@ -600,156 +539,7 @@ cc_library( ], ) -# Generates library per group of ops. -tf_gen_op_libs( - is_external = False, - op_lib_names = [ - "batch_ops", - "bitwise_ops", - "boosted_trees_ops", - "tensor_forest_ops", - "candidate_sampling_ops", - "checkpoint_ops", - "clustering_ops", - "collective_ops", - "control_flow_ops", - "count_ops", - "ctc_ops", - "data_flow_ops", - "dataset_ops", - "decode_proto_ops", - "encode_proto_ops", - "experimental_dataset_ops", - "function_ops", - "functional_ops", - "image_ops", - "io_ops", - "linalg_ops", - "list_ops", - "map_ops", - "lookup_ops", - "manip_ops", - "math_ops", - "mkl_nn_ops", - "nccl_ops", - "nn_ops", - "no_op", - "parsing_ops", - "random_grad", - "random_ops", - "special_math_ops", - "stateful_random_ops", - "remote_fused_graph_ops", - "rnn_ops", - "rpc_ops", - "scoped_allocator_ops", - "sdca_ops", - "set_ops", - "script_ops", - "sendrecv_ops", - "sparse_csr_matrix_ops", - "sparse_ops", - "spectral_ops", - "state_ops", - "stateless_random_ops", - "summary_ops", - "training_ops", - ], - deps = [ - ":lib", - ":protos_all_cc", - ], -) - -tf_gen_op_libs( - is_external = False, - op_lib_names = [ - "logging_ops", - ], - deps = [ - ":lib", - ":protos_all_cc", - # TODO(b/162630222): remove this dependency. - "//tensorflow/c/kernels:histogram_summary_op_lib", - "//tensorflow/c/kernels:merge_summary_op_lib", - "//tensorflow/c/kernels:summary_op_lib", - ], -) - -tf_gen_op_libs( - op_lib_names = [ - "string_ops", - ], - deps = [ - ":lib_internal", - ":lib_proto_parsing", - "@com_google_absl//absl/strings", - ], -) - -tf_gen_op_libs( - op_lib_names = [ - "array_ops", - ], - deps = [ - ":lib", - ":protos_all_cc", - ], -) - -tf_gen_op_libs( - op_lib_names = [ - "mkl_array_ops", - ], - deps = [":protos_all_cc"], -) - -tf_gen_op_libs( - op_lib_names = [ - "audio_ops", - ], - deps = [":lib"], -) - -tf_gen_op_libs( - op_lib_names = ["debug_ops"], - deps = [":lib"], -) - -tf_gen_op_libs( - is_external = False, - op_lib_names = [ - "resource_variable_ops", - ], - deps = [":lib"], -) - -tf_gen_op_libs( - op_lib_names = [ - "tpu_configuration_ops", - "tpu_cross_replica_ops", - "tpu_embedding_ops", - "tpu_embedding_load_retrieve_ops", - "tpu_functional_ops", - "tpu_heartbeat_ops", - "tpu_host_compute_ops", - "tpu_infeed_ops", - "tpu_outfeed_ops", - "tpu_ordinal_selector_ops", - "tpu_replication_ops", - ], - deps = [ - ":lib", - ":lib_proto_parsing", - ":protos_all_cc", - "//tensorflow/core/protobuf/tpu:optimization_parameters_proto_cc", - "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_cc", - "//tensorflow/core/tpu:tpu_embedding_optimization_parameters_utils", - "//tensorflow/core/tpu:tpu_embedding_output_layout_utils", - ], -) - -# And one for all user ops +# One target for all user ops cc_library( name = "user_ops_op_lib", srcs = glob(["user_ops/**/*.cc"]), @@ -760,212 +550,29 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "word2vec_ops", - srcs = ["ops/word2vec_ops.cc"], - linkstatic = 1, - visibility = ["//tensorflow:internal"], - deps = [":framework"], - alwayslink = 1, -) - -cc_library( - name = "cudnn_rnn_ops", - srcs = [ - 
"ops/cudnn_rnn_ops.cc", - ], - linkstatic = 1, - visibility = ["//tensorflow:internal"], - deps = [ - ":framework", - ":lib", - ":lib_internal", - ":stream_executor", - "//tensorflow/core/kernels:bounds_check_lib", - ], - alwayslink = 1, -) - -tf_gen_op_libs( - op_lib_names = [ - "cudnn_rnn_ops", - ], - deps = [ - ":lib", - ], -) - -cc_library( - name = "ragged_ops", - deps = [ - ":ragged_array_ops_op_lib", - ":ragged_conversion_ops_op_lib", - ":ragged_math_ops_op_lib", - ], -) - -tf_gen_op_libs( - op_lib_names = [ - "ragged_array_ops", - "ragged_conversion_ops", - "ragged_math_ops", - ], - deps = ["//tensorflow/core/util:ragged_to_dense_util"], -) - cc_library( name = "ops", visibility = ["//visibility:public"], deps = [ - ":array_ops_op_lib", - ":audio_ops_op_lib", - ":batch_ops_op_lib", - ":bitwise_ops_op_lib", - ":boosted_trees_ops_op_lib", - ":tensor_forest_ops_op_lib", - ":candidate_sampling_ops_op_lib", - ":checkpoint_ops_op_lib", - ":clustering_ops_op_lib", - ":collective_ops_op_lib", - ":control_flow_ops_op_lib", - ":count_ops_op_lib", - ":ctc_ops_op_lib", - ":cudnn_rnn_ops_op_lib", - ":data_flow_ops_op_lib", - ":dataset_ops_op_lib", - ":debug_ops_op_lib", - ":decode_proto_ops_op_lib", - ":encode_proto_ops_op_lib", - ":experimental_dataset_ops_op_lib", - ":function_ops_op_lib", - ":functional_ops_op_lib", - ":image_ops_op_lib", - ":io_ops_op_lib", - ":linalg_ops_op_lib", - ":list_ops_op_lib", - ":map_ops_op_lib", - ":logging_ops_op_lib", - ":lookup_ops_op_lib", - ":manip_ops_op_lib", - ":math_ops_op_lib", - ":nccl_ops_op_lib", - ":nn_ops_op_lib", - ":no_op_op_lib", - ":parsing_ops_op_lib", - ":ragged_ops", - ":random_ops_op_lib", - ":rnn_ops_op_lib", - ":special_math_ops_op_lib", - ":stateful_random_ops_op_lib", - ":remote_fused_graph_ops_op_lib", - ":resource_variable_ops_op_lib", - ":rpc_ops_op_lib", - ":scoped_allocator_ops_op_lib", - ":script_ops_op_lib", - ":sdca_ops_op_lib", - ":sendrecv_ops_op_lib", - ":set_ops_op_lib", - ":sparse_csr_matrix_ops_op_lib", - ":sparse_ops_op_lib", - ":summary_ops_op_lib", - ":spectral_ops_op_lib", - ":state_ops_op_lib", - ":stateless_random_ops_op_lib", - ":string_ops_op_lib", - ":training_ops_op_lib", ":user_ops_op_lib", - ":word2vec_ops", "//tensorflow/c/kernels:bitcast_op_lib", "//tensorflow/c/kernels:histogram_summary_op_lib", "//tensorflow/c/kernels:merge_summary_op_lib", "//tensorflow/c/kernels:summary_op_lib", - "//tensorflow/compiler/mlir/tensorflow:mlir_passthrough_op", + "//tensorflow/core/ops:ops", ] + if_chromiumos( [], - # Non-tpu platforms don't need tpu dependency. It would be best to guard - # them by if_tpu. But there is no such flag yet. + # Non-tpu platforms don't need tpu dependency. 
[ - ":tpu_configuration_ops_op_lib", - ":tpu_cross_replica_ops_op_lib", - ":tpu_embedding_ops_op_lib", - ":tpu_embedding_load_retrieve_ops_op_lib", - ":tpu_functional_ops_op_lib", - ":tpu_heartbeat_ops_op_lib", - ":tpu_host_compute_ops_op_lib", - ":tpu_infeed_ops_op_lib", - ":tpu_outfeed_ops_op_lib", - ":tpu_ordinal_selector_ops_op_lib", - ":tpu_replication_ops_op_lib", "//tensorflow/core/tpu/ops", ], - ) + if_mkl([ - ":mkl_array_ops_op_lib", - ":mkl_nn_ops_op_lib", - ]) + if_tensorrt([ + ) + if_tensorrt([ "//tensorflow/compiler/tf2tensorrt:trt_engine_resource_ops_op_lib", "//tensorflow/compiler/tf2tensorrt:trt_op_libs", - ]), - alwayslink = 1, -) - -cc_library( - name = "array_grad", - srcs = ["ops/array_grad.cc"], - linkstatic = 1, # Needed since alwayslink is broken in bazel b/27630669 - visibility = ["//visibility:public"], - deps = [ - ":array_ops_op_lib", - ":framework", - ":lib", - "//tensorflow/c/kernels:bitcast_op_lib", - ], - alwayslink = 1, -) - -cc_library( - name = "functional_grad", - srcs = ["ops/functional_grad.cc"], - linkstatic = 1, # Needed since alwayslink is broken in bazel b/27630669 - visibility = ["//visibility:public"], - deps = [ - ":framework", - ":functional_ops_op_lib", - ":lib", - ], - alwayslink = 1, -) - -cc_library( - name = "math_grad", - srcs = [ - "ops/math_grad.cc", - "ops/random_grad.cc", - "ops/stateless_random_grad.cc", - ], - linkstatic = 1, # Needed since alwayslink is broken in bazel b/27630669 - visibility = ["//visibility:public"], - deps = [ - ":framework", - ":lib", - ":math_ops_op_lib", - ":protos_all_cc", - ], - alwayslink = 1, -) - -cc_library( - name = "nn_grad", - srcs = ["ops/nn_grad.cc"], - linkstatic = 1, # Needed since alwayslink is broken in bazel b/27630669 - visibility = ["//visibility:public"], - deps = [ - ":framework", - ":lib", - ":nn_ops_op_lib", - ] + if_mkl([ - ":mkl_nn_ops_op_lib", - ]), - alwayslink = 1, + ]) + if_libtpu( + if_false = ["//tensorflow/compiler/mlir/tensorflow:mlir_passthrough_op"], + if_true = [], + ), ) alias( @@ -1086,9 +693,7 @@ cc_library( ]) + if_tensorrt([ "//tensorflow/compiler/tf2tensorrt:trt_engine_resource_op_kernels", "//tensorflow/compiler/tf2tensorrt:trt_op_kernels", - ]) + if_tpu([ - "//tensorflow/core/tpu/kernels", - ]), + ]) + tf_tpu_dependencies(), ) cc_library( @@ -1166,7 +771,7 @@ cc_library( ) # Test support library needed for higher-level (TensorFlow-specific) tests -cc_library( +tf_cuda_library( name = "testlib", testonly = 1, srcs = [ @@ -1199,10 +804,10 @@ cc_library( ":ops", ":protos_all_cc", ":test", - ":testlib_ops", # TODO(gunan): resolve dependency issues and load these kernels dynamically. ":testlib_kernels_impl", "//tensorflow/cc:scope", + "//tensorflow/core/common_runtime:testlib_ops", "//tensorflow/core/framework:fake_input", "//tensorflow/core/framework:function_testlib", "//tensorflow/core/framework:shape_inference_testutil", @@ -1212,13 +817,6 @@ cc_library( ], ) -alias( - name = "testlib_ops", - testonly = 1, - actual = - "//tensorflow/core/common_runtime:testlib_ops", -) - # This is a link-only library to provide a DirectSession # implementation of the Session interface. 
tf_cuda_library( @@ -1232,23 +830,9 @@ tf_cuda_library( alwayslink = 1, ) -# ----------------------------------------------------------------------------- -# MKL targets -alias( - name = "mkl_graph_util", - actual = "//tensorflow/core/graph:mkl_graph_util", -) - # ----------------------------------------------------------------------------- # Public Android targets -# List of protos we want on android -filegroup( - name = "android_proto_srcs", - srcs = CORE_PROTO_SRCS, - visibility = ["//visibility:public"], -) - # Sources required to build the TensorFlow framework without the runtime on # mobile platforms. This is essentially the sources required to build # tensorflow/core/framework:tensor without using granular targets. @@ -1278,7 +862,7 @@ filegroup( "**/*main.cc", ], ), - visibility = ["//visibility:private"], + visibility = ["//visibility:public"], ) # Sources required to build the TensorFlow framework with runtime on @@ -1297,6 +881,7 @@ filegroup( "//tensorflow/core/graph:mobile_srcs_only_runtime", "//tensorflow/core/kernels:mobile_srcs", "//tensorflow/core/lib/io:mobile_srcs_only_runtime", + "//tensorflow/core/nccl:mobile_srcs", "//tensorflow/core/profiler:mobile_srcs", "//tensorflow/core/public:mobile_srcs_only_runtime", "//tensorflow/core/util/sparse:mobile_srcs_only_runtime", @@ -1337,11 +922,98 @@ filegroup( visibility = ["//visibility:public"], ) -alias( - name = "android_srcs", - actual = ":mobile_srcs", - visibility = ["//visibility:public"], -) +# All the aliases for stuff under ops/ +# Once the dependencies move to the real targets, remove the aliases here! + +[ + alias( + name = "%s" % (name,), + actual = "//tensorflow/core/ops:%s" % (name,), + visibility = ["//visibility:public"], + ) + for name in [ + "array_grad", + "array_ops_op_lib", + "audio_ops_op_lib", + "batch_ops_op_lib", + "bitwise_ops_op_lib", + "boosted_trees_ops_op_lib", + "candidate_sampling_ops_op_lib", + "checkpoint_ops_op_lib", + "clustering_ops_op_lib", + "collective_ops_op_lib", + "control_flow_ops_op_lib", + "count_ops_op_lib", + "ctc_ops_op_lib", + "cudnn_rnn_ops_op_lib", + "data_flow_ops_op_lib", + "dataset_ops_op_lib", + "debug_ops_op_lib", + "decode_proto_ops_op_lib", + "encode_proto_ops_op_lib", + "experimental_dataset_ops_op_lib", + "function_ops_op_lib", + "functional_grad", + "functional_ops_op_lib", + "image_ops_op_lib", + "io_ops_op_lib", + "linalg_ops_op_lib", + "list_ops_op_lib", + "logging_ops_op_lib", + "lookup_ops_op_lib", + "manip_ops_op_lib", + "map_ops_op_lib", + "math_grad", + "math_ops_op_lib", + "mkl_array_ops_op_lib", + "mkl_nn_ops_op_lib", + "nccl_ops_op_lib", + "nn_grad", + "nn_ops_op_lib", + "no_op_op_lib", + "parsing_ops_op_lib", + "portable_op_registrations_and_gradients", + "ragged_array_ops_op_lib", + "ragged_conversion_ops_op_lib", + "ragged_math_ops_op_lib", + "ragged_ops", + "random_grad_op_lib", + "random_ops_op_lib", + "remote_fused_graph_ops_op_lib", + "resource_variable_ops_op_lib", + "rnn_ops_op_lib", + "rpc_ops_op_lib", + "scoped_allocator_ops_op_lib", + "script_ops_op_lib", + "sdca_ops_op_lib", + "sendrecv_ops_op_lib", + "set_ops_op_lib", + "sparse_csr_matrix_ops_op_lib", + "sparse_ops_op_lib", + "special_math_ops_op_lib", + "spectral_ops_op_lib", + "state_ops_op_lib", + "stateful_random_ops_op_lib", + "stateless_random_ops_op_lib", + "stateless_random_ops_v2_op_lib", + "string_ops_op_lib", + "summary_ops_op_lib", + "tensor_forest_ops_op_lib", + "tpu_configuration_ops_op_lib", + "tpu_cross_replica_ops_op_lib", + "tpu_embedding_ops_op_lib", + 
"tpu_embedding_load_retrieve_ops_op_lib", + "tpu_functional_ops_op_lib", + "tpu_heartbeat_ops_op_lib", + "tpu_host_compute_ops_op_lib", + "tpu_infeed_ops_op_lib", + "tpu_outfeed_ops_op_lib", + "tpu_ordinal_selector_ops_op_lib", + "tpu_replication_ops_op_lib", + "training_ops_op_lib", + "word2vec_ops", + ] +] # Native library support for mobile applications. Does not contain # operators, use :portable_tensorflow_lib if you want full operator @@ -1356,7 +1028,7 @@ alias( # Compiles to a trivial library on non-mobile to prevent irrelevant # build errors. If not building this e.g. as part of an android_binary, # a command such as the following must be used: -# bazel build -c opt tensorflow/core:android_tensorflow_lib \ +# bazel build -c opt tensorflow/core:portable_tensorflow_lib \ # --define=TENSORFLOW_PROTOS=lite \ # --crosstool_top=//external:android/crosstool \ # --cpu=armeabi-v7a \ @@ -1379,24 +1051,6 @@ cc_library( alwayslink = 1, ) -alias( - name = "android_tensorflow_lib_lite", - actual = ":portable_tensorflow_lib_lite", - visibility = ["//visibility:public"], -) - -alias( - name = "android_tensorflow_lib_lite_nortti", - actual = ":portable_tensorflow_lib_lite", - visibility = ["//visibility:public"], -) - -alias( - name = "android_tensorflow_lib_lite_nortti_lite_protos", - actual = ":portable_tensorflow_lib_lite", - visibility = ["//visibility:public"], -) - cc_library( name = "mobile_additional_lib_deps", deps = tf_additional_lib_deps() + [ @@ -1407,26 +1061,6 @@ cc_library( ], ) -alias( - name = "ios_tensorflow_lib_lite", - actual = ":portable_tensorflow_lib_lite", - visibility = ["//visibility:public"], -) - -# Full TensorFlow library with operator support. Use this unless reducing -# binary size (by packaging a reduced operator set) is a concern. -alias( - name = "android_tensorflow_lib", - actual = ":portable_tensorflow_lib", - visibility = ["//visibility:public"], -) - -alias( - name = "ios_tensorflow_lib", - actual = ":portable_tensorflow_lib", - visibility = ["//visibility:public"], -) - cc_library( name = "portable_tensorflow_lib", srcs = if_mobile([":portable_op_registrations_and_gradients"]), @@ -1447,52 +1081,6 @@ cc_library( alwayslink = 1, ) -alias( - name = "android_op_registrations_and_gradients", - actual = ":portable_op_registrations_and_gradients", - visibility = ["//visibility:public"], -) - -filegroup( - name = "portable_op_registrations_and_gradients", - srcs = ["//tensorflow/c/kernels:android_all_ops"] + glob( - [ - "ops/**/*.cc", - "ops/**/*.h", - ], - exclude = [ - "**/*test.cc", - "**/*testutil*", - "**/*testlib*", - "**/*main.cc", - "**/tpu_*", - ], - ), - visibility = ["//visibility:public"], -) - -filegroup( - name = "android_test_srcs", - testonly = 1, - # TODO(andrewharp/nhua): - # make more test-related sources portable e.g. "//tensorflow/core/platform:test.cc", - srcs = tf_portable_full_lite_protos( - full = [ - "//tensorflow/core/framework:android_test_hdrs", - "//tensorflow/core/framework:android_test_srcs", - "//tensorflow/core/platform:android_test_srcs", - "//tensorflow/core/util:android_test_srcs", - ], - lite = [ - "//tensorflow/core/framework:android_test_hdrs", - "//tensorflow/core/framework:android_test_srcs_no_core", - "//tensorflow/core/platform:android_test_srcs", - "//tensorflow/core/util:android_test_srcs", - ], - ), - visibility = ["//visibility:public"], -) - # This is like android_test_srcs, minus the things that are already in mobile_srcs. 
filegroup( name = "android_test_srcs_no_core", @@ -1507,18 +1095,6 @@ filegroup( ) # Portable library providing testing functionality for TensorFlow. -alias( - name = "android_tensorflow_test_lib", - actual = ":portable_tensorflow_test_lib", - visibility = ["//visibility:public"], -) - -alias( - name = "ios_tensorflow_test_lib", - actual = ":portable_tensorflow_test_lib", - visibility = ["//visibility:public"], -) - cc_library( name = "portable_tensorflow_test_lib", testonly = 1, @@ -1538,7 +1114,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":portable_tensorflow_lib", - ":protos_all_cc", "//tensorflow/core/kernels:portable_tensorflow_kernels", "//tensorflow/core/platform/default/build_config:gtest", "//third_party/eigen3", @@ -1591,105 +1166,8 @@ alias( ] ] -# The following targets will be moved to core/protobuf. The aliases are only temporary -# since moving existing users will require several CLs over several projects. -[ - [ - alias( - name = "protobuf_%s_pyclif%s" % (proto_name, target_suffix), - actual = "//tensorflow/core/protobuf:%s_pyclif%s" % (proto_name, target_suffix), - visibility = ["//visibility:public"], - ) - for target_suffix in [ - "", - "_pb2", - ] - ] - for proto_name in [ - "config", - "device_properties", - "graph_debug_info", - "meta_graph", - "saved_model", - ] -] - # ----------------------------------------------------------------------------- # Internal targets - -alias( - name = "autotuning_proto", - actual = "//tensorflow/core/protobuf:autotuning_proto", - visibility = [ - "//tensorflow:internal", - ], -) - -alias( - name = "autotuning_proto_cc", - actual = "//tensorflow/core/protobuf:autotuning_proto_cc", - visibility = [ - "//tensorflow:internal", - ], -) - -alias( - name = "conv_autotuning_proto", - actual = "//tensorflow/core/protobuf:conv_autotuning_proto", - visibility = [ - "//tensorflow:internal", - ], -) - -alias( - name = "conv_autotuning_proto_cc", - actual = "//tensorflow/core/protobuf:conv_autotuning_proto_cc", - visibility = [ - "//tensorflow:internal", - ], -) - -alias( - name = "worker_proto_cc", - actual = "//tensorflow/core/protobuf:worker_proto_cc", - visibility = [ - "//tensorflow:internal", - ], -) - -alias( - name = "worker_service_proto_cc", - actual = "//tensorflow/core/protobuf:worker_service_proto_cc", - visibility = [ - "//tensorflow:internal", - ], -) - -alias( - name = "master_proto_cc", - actual = "//tensorflow/core/protobuf:master_proto_cc", - visibility = [ - "//learning/brain/frameworks/uptc:__subpackages__", - "//tensorflow:internal", - ], -) - -alias( - name = "master_service_proto_cc", - actual = "//tensorflow/core/protobuf:master_service_proto_cc", - visibility = [ - "//tensorflow:internal", - ], -) - -alias( - name = "eager_service_proto_cc", - actual = "//tensorflow/core/protobuf:eager_service_proto_cc", - visibility = [ - "//tensorflow:internal", - ], -) - filegroup( name = "lib_internal_private_headers", srcs = [ @@ -1903,6 +1381,7 @@ cc_library( "//tensorflow/core/platform:denormal", "//tensorflow/core/platform:dynamic_annotations", "//tensorflow/core/platform:env", + "//tensorflow/core/platform:env_impl", "//tensorflow/core/platform:error", "//tensorflow/core/platform:errors", "//tensorflow/core/platform:file_statistics", @@ -2007,16 +1486,6 @@ cc_library( ], ) -alias( - name = "png_internal", - actual = "//tensorflow/core/lib/png:png_io", -) - -alias( - name = "android_png_internal", - actual = "//tensorflow/core/lib/png:png_io", -) - cc_library( name = "tflite_portable_logging", hdrs = [ @@ -2040,8 
+1509,8 @@ cc_library( ) cc_library( - name = "android_jpeg_internal", - srcs = if_android([ + name = "portable_jpeg_internal", + srcs = if_mobile([ "lib/jpeg/jpeg_handle.cc", "lib/jpeg/jpeg_mem.cc", "//tensorflow/core/platform:jpeg_hdrs", @@ -2055,7 +1524,7 @@ cc_library( "//tensorflow/core/platform/default:logging.h", ], copts = tf_copts(), - linkopts = ["-ldl"], + linkopts = if_android(["-ldl"]), deps = [ ":core_stringpiece", "//tensorflow/core/platform:dynamic_annotations", @@ -2068,8 +1537,8 @@ cc_library( ) cc_library( - name = "android_gif_internal", - srcs = if_android([ + name = "portable_gif_internal", + srcs = if_mobile([ "lib/gif/gif_io.cc", "//tensorflow/core/platform:gif_hdrs", ]), @@ -2082,7 +1551,7 @@ cc_library( "//tensorflow/core/platform/default:logging.h", ], copts = tf_copts(), - linkopts = ["-ldl"], + linkopts = if_android(["-ldl"]), deps = [ "//tensorflow/core/platform:dynamic_annotations", "//tensorflow/core/platform:gif", @@ -2103,11 +1572,6 @@ alias( actual = "//tensorflow/core/protobuf:error_codes_proto_impl_cc", ) -alias( - name = "error_codes_proto_cc", - actual = "//tensorflow/core/lib/core:error_codes_proto_cc", -) - alias( name = "version_lib", actual = "//tensorflow/core/util:version_info", @@ -2119,6 +1583,7 @@ filegroup( "//tensorflow/core/example:feature_util.h", "//tensorflow/core/framework:framework_internal_private_hdrs", "//tensorflow/core/graph:framework_internal_private_headers", + "//tensorflow/core/public:session_options.h", "//tensorflow/core/util:framework_internal_private_hdrs", "//tensorflow/core/util:memmapped_file_system_hdrs", "//tensorflow/core/util/sparse:framework_internal_private_headers_group", @@ -2177,7 +1642,7 @@ cc_header_only_library( ":lib", ":lib_internal", ":version_lib", - "//tensorflow/core/kernels:bounds_check", + "//tensorflow/core/framework:bounds_check", "//tensorflow/core/platform/default/build_config:platformlib", ], ) @@ -2229,6 +1694,7 @@ tf_cuda_library( "//tensorflow/core/framework:attr_value_proto_text", "//tensorflow/core/framework:attr_value_util", "//tensorflow/core/framework:bfloat16", + "//tensorflow/core/framework:bounds_check", "//tensorflow/core/framework:common_shape_fns", "//tensorflow/core/framework:kernel_shape_util", "//tensorflow/core/framework:node_def_util", @@ -2242,7 +1708,7 @@ tf_cuda_library( "//tensorflow/core/framework:shape_inference", "//tensorflow/core/framework:tensor", "//tensorflow/core/framework:tensor_shape", - "//tensorflow/core/kernels:bounds_check", + "//tensorflow/core/platform:env_impl", "//tensorflow/core/platform/default/build_config:platformlib", "//tensorflow/core/profiler/lib:annotated_traceme", "//tensorflow/core/profiler/lib:traceme", @@ -2284,15 +1750,10 @@ cc_header_only_library( ], visibility = ["//visibility:public"], deps = [ - ":stream_executor", + "//tensorflow/core/platform:stream_executor", ], ) -alias( - name = "stream_executor", - actual = "//tensorflow/core/platform:stream_executor", -) - # Like stream_executor library, but compiles without --config=cuda # and does not include any cuda dependencies. 
alias( @@ -2344,18 +1805,12 @@ tf_cuda_library( ":function_ops_op_lib", ":functional_grad", ":functional_ops_op_lib", - "//tensorflow/core/kernels:bounds_check", + "//tensorflow/core/framework:bounds_check", "//tensorflow/core/kernels:required", ]), alwayslink = 1, ) -alias( - name = "core_cpu_impl", - actual = - "//tensorflow/core/common_runtime:core_cpu_impl", -) - alias( name = "core_cpu_lib", actual = @@ -2368,18 +1823,6 @@ alias( "//tensorflow/core/common_runtime:core_cpu_internal", ) -alias( - name = "regexp_internal", - actual = - "//tensorflow/core/platform:regexp", - visibility = [ - "//tensorflow/compiler:__subpackages__", - "//tensorflow/core/kernels:__subpackages__", - "//tensorflow/core/profiler:__subpackages__", - "//tensorflow/stream_executor:__subpackages__", - ], -) - alias( name = "direct_session_internal", actual = @@ -2392,14 +1835,6 @@ alias( visibility = ["//visibility:public"], ) -alias( - name = "replay_log_proto_cc", - actual = "//tensorflow/core/protobuf:replay_log_proto_cc", - visibility = [ - "//tensorflow:internal", - ], -) - alias( name = "gpu_runtime", actual = @@ -2423,18 +1858,6 @@ cc_library( ], ) -# TODO(gonnet): Remove this alias once all users have been moved to the actual target. -alias( - name = "tensor_testutil", - actual = "//tensorflow/core/framework:tensor_testutil", -) - -# TODO(gonnet): Remove this alias once all users have been moved to the actual target. -alias( - name = "shape_inference_testutil", - actual = "//tensorflow/core/framework:shape_inference_testutil", -) - # Main program for tests alias( name = "test_main", @@ -2442,14 +1865,6 @@ alias( visibility = ["//tensorflow:internal"], ) -test_suite( - name = "low_level_tests", - tests = [ - ":low_level_library_tests", - "//tensorflow/core/platform:low_level_library_tests", - ], -) - tf_cc_tests( name = "low_level_library_tests", size = "small", @@ -2470,7 +1885,6 @@ tf_cc_tests( "//tensorflow/core/lib/random:legacy_lib_random_tests", "//tensorflow/core/lib/strings:legacy_low_level_library_tests", ], - create_named_test_suite = True, deps = [ ":lib", ":lib_internal", @@ -2506,22 +1920,6 @@ tf_cc_test( ], ) -test_suite( - name = "platform_tests", - tests = [ - "//tensorflow/core/platform:abi_test", - "//tensorflow/core/platform:env_test", - "//tensorflow/core/platform:fake_python_env_test", - "//tensorflow/core/platform:file_system_test", - "//tensorflow/core/platform:numa_test", - "//tensorflow/core/platform:platform_strings_test", - "//tensorflow/core/platform:rocm_rocdl_path_test", - "//tensorflow/core/platform:setround_test", - "//tensorflow/core/platform:unbounded_work_queue_test", - "//tensorflow/core/platform:vmodule_test", - ], -) - tf_cc_test( name = "lib_jpeg_jpeg_mem_unittest", srcs = ["lib/jpeg/jpeg_mem_unittest.cc"], @@ -2574,27 +1972,6 @@ tf_cc_test( ], ) -tf_cc_test( - name = "framework_op_gen_lib_test", - size = "small", - srcs = ["//tensorflow/core/framework:op_gen_lib_test.cc"], - deps = [ - ":protos_all_cc", - ":test", - ":test_main", - "//tensorflow/core/framework:op_gen_lib", - ], -) - -test_suite( - name = "higher_level_tests", - tests = [ - ":core_higher_level_tests", - "//tensorflow/core/framework:higher_level_tests", - "//tensorflow/core/util:higher_level_tests", - ], -) - tf_cc_tests( name = "core_higher_level_tests", size = "small", @@ -2613,7 +1990,6 @@ tf_cc_tests( "//tensorflow/core/graph:validate_test.cc", "//tensorflow/core/util/sparse:higher_level_tests_group", ], - create_named_test_suite = True, linkopts = select({ "//tensorflow:macos": 
["-headerpad_max_install_names"], "//conditions:default": [], @@ -2650,22 +2026,6 @@ tf_cc_tests( ], ) -tf_cc_test( - name = "cudnn_rnn_ops_test_cc", - size = "small", - srcs = [ - "ops/cudnn_rnn_ops_test.cc", - ], - deps = [ - ":core", - ":framework", - ":lib", - ":test", - ":test_main", - ":testlib", - ], -) - tf_cc_test_mkl( name = "mkl_related_tests", size = "small", @@ -2742,223 +2102,6 @@ tf_cc_tests_gpu( ], ) -tf_cc_test_gpu( - name = "variant_op_copy_test", - size = "small", - srcs = ["//tensorflow/core/framework:variant_op_copy_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - tags = tf_cuda_tests_tags(), - deps = [ - ":core", - ":core_cpu", - ":core_cpu_internal", - ":direct_session", - ":framework", - ":framework_internal", - ":gpu_runtime", - ":lib", - ":lib_internal", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:client_session", - "//tensorflow/cc:ops", - "//tensorflow/cc:scope", - "//tensorflow/core/kernels:array", - "//third_party/eigen3", - ], -) - -tf_cc_test( - name = "framework_run_handler_util_test", - size = "small", - srcs = ["//tensorflow/core/framework:run_handler_util_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - deps = [ - ":framework_internal", - ":lib", - ":test", - ":test_main", - ], -) - -tf_cc_test( - name = "framework_run_handler_test", - size = "small", - srcs = ["//tensorflow/core/framework:run_handler_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - deps = [ - ":core_cpu", - ":direct_session_internal", - ":framework_internal", - ":lib", - ":lib_internal", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/core/framework:tensor_testutil", - "//tensorflow/core/kernels:cwise_op", - "//tensorflow/core/kernels:matmul_op", - "//third_party/eigen3", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/synchronization", - ], -) - -tf_cc_test( - name = "framework_op_segment_test", - size = "small", - srcs = ["//tensorflow/core/framework:op_segment_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - deps = [ - ":core", - ":core_cpu", - ":core_cpu_internal", - ":direct_session_internal", - ":framework", - ":framework_internal", - ":lib", - ":lib_internal", - ":ops", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/cc:cc_ops", - "//tensorflow/core/kernels:cwise_op", - "//tensorflow/core/kernels:ops_util", - "//third_party/eigen3", - ], -) - -tf_cc_test( - name = "ops_array_grad_test", - size = "small", - srcs = ["ops/array_grad_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - deps = [ - ":core", - ":core_cpu", - ":core_cpu_internal", - ":direct_session_internal", - ":framework", - ":framework_internal", - ":lib", - ":lib_internal", - ":ops", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/cc:cc_ops", - "//tensorflow/core/kernels:array", - "//tensorflow/core/kernels:cwise_op", - "//tensorflow/core/kernels:function_ops", - "//tensorflow/core/kernels:math", - "//third_party/eigen3", - ], -) - -tf_cc_test( - name = "ops_math_grad_test", - size = "small", - srcs = ["ops/math_grad_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - tags = ["no_gpu"], - deps = [ - ":core", - ":core_cpu", - ":core_cpu_internal", - ":direct_session_internal", - ":framework", - ":framework_internal", - ":lib", - ":lib_internal", - ":ops", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/cc:cc_ops", - "//tensorflow/core/kernels:array", - 
"//tensorflow/core/kernels:data_flow", - "//tensorflow/core/kernels:function_ops", - "//tensorflow/core/kernels:math", - "//third_party/eigen3", - ], -) - -tf_cc_test( - name = "ops_remote_fused_graph_ops_test", - size = "small", - srcs = ["ops/remote_fused_graph_ops_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - deps = [ - ":core", - ":core_cpu", - ":core_cpu_internal", - ":framework", - ":framework_internal", - ":lib", - ":lib_internal", - ":ops", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/core/kernels:remote_fused_graph_ops", - ], -) - -tf_cc_test( - name = "ops_tests", - size = "small", - srcs = [ - "ops/array_ops_test.cc", - "ops/candidate_sampling_ops_test.cc", - "ops/control_flow_ops_test.cc", - "ops/ctc_ops_test.cc", - "ops/data_flow_ops_test.cc", - "ops/functional_ops_test.cc", - "ops/image_ops_test.cc", - "ops/io_ops_test.cc", - "ops/linalg_ops_test.cc", - "ops/math_ops_test.cc", - "ops/nn_ops_test.cc", - "ops/parsing_ops_test.cc", - "ops/random_ops_test.cc", - "ops/rnn_ops_test.cc", - "ops/set_ops_test.cc", - "ops/shape_function_test.cc", - "ops/sparse_csr_matrix_ops_test.cc", - "ops/sparse_ops_test.cc", - "ops/spectral_ops_test.cc", - "ops/state_ops_test.cc", - "ops/string_ops_test.cc", - "ops/training_ops_test.cc", - ], - linkstatic = tf_kernel_tests_linkstatic(), - deps = [ - ":core", - ":core_cpu", - ":core_cpu_internal", - ":framework", - ":framework_internal", - ":lib", - ":lib_internal", - ":ops", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/cc:cc_ops", - "//third_party/eigen3", - ], -) - # Test data filegroup( name = "image_testdata", @@ -2988,50 +2131,11 @@ filegroup( # GIF data with optimization "lib/gif/testdata/optimized.gif", # BMP data - "lib/bmp/testdata/lena.bmp", - "lib/bmp/testdata/rgb_small.bmp", - "lib/bmp/testdata/rgb_small_255.bmp", - "lib/bmp/testdata/rgba_small.bmp", - "lib/bmp/testdata/rgba_small_255.bmp", - "lib/bmp/testdata/grayscale_small.bmp", - "lib/bmp/testdata/grayscale_small_3channels.bmp", - "lib/bmp/testdata/grayscale_small_4channels.bmp", + "//tensorflow/core/lib/bmp:bmp_testdata", ], visibility = ["//visibility:public"], ) -filegroup( - name = "lmdb_testdata", - testonly = 1, - srcs = [ - # A simple key-value store: - # 0 : 'b' - # 1 : 'b' - # ... - # 9 : 'b' - # Which is then overwritten with: - # 0 : 'a' - # 1 : 'b' - # ... - # 9 : 'j' - "lib/lmdb/testdata/data.mdb", - # LMDB, being a memory-mapped database, uses a different file format on - # big-endian systems. - "lib/lmdb/testdata/data_bigendian.mdb", - ], - visibility = ["//visibility:public"], -) - -alias( - name = "cuda_libdevice_path", - actual = "//tensorflow/core/platform:cuda_libdevice_path", -) - -# Normalize CORE_PROTO_SRCS to generate valid output file names. 
-PORTABLE_PROTO_HEADERS_OUT = tf_android_core_proto_headers(CORE_PROTO_SRCS) + [ - "//google/protobuf/any.proto.h", -] - transitive_hdrs( name = "headers", visibility = ["//tensorflow:__subpackages__"], @@ -3040,7 +2144,7 @@ transitive_hdrs( ":framework", ":lib", ":protos_all_cc", - ":stream_executor", "//tensorflow/core/platform:platform_strings", + "//tensorflow/core/platform:stream_executor", ], ) diff --git a/tensorflow/core/api_def/BUILD b/tensorflow/core/api_def/BUILD index dfa0b78cb17..f9e2adaec6b 100644 --- a/tensorflow/core/api_def/BUILD +++ b/tensorflow/core/api_def/BUILD @@ -6,6 +6,8 @@ # :python_api_def # :java_api_def +load("//tensorflow:tensorflow.bzl", "filegroup") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", "tf_cc_binary", @@ -37,9 +39,9 @@ filegroup( visibility = ["//tensorflow:internal"], ) -filegroup( +alias( name = "java_api_def", - srcs = glob(["java_api/*"]), + actual = "//tensorflow/core/api_def/java_api:java_api_def", visibility = ["//tensorflow:internal"], ) diff --git a/tensorflow/core/api_def/README.md b/tensorflow/core/api_def/README.md new file mode 100644 index 00000000000..76232442e8d --- /dev/null +++ b/tensorflow/core/api_def/README.md @@ -0,0 +1,4 @@ +This folder contains the ApiDef proto definitions of TensorFlow operations. + +The canonical source of documentation for these operations can be found in +the base_api/ directory. diff --git a/tensorflow/core/api_def/base_api/api_def_Acos.pbtxt b/tensorflow/core/api_def/base_api/api_def_Acos.pbtxt index 2184b644b23..dc018aec4aa 100644 --- a/tensorflow/core/api_def/base_api/api_def_Acos.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_Acos.pbtxt @@ -1,4 +1,11 @@ op { graph_op_name: "Acos" summary: "Computes acos of x element-wise." + description: <

See also: config.proto + * * @param value a serialized config proto - * @see */ public Options config(byte[] value) { config = value; diff --git a/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java b/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java index 3c9a678cf56..97de99cb75e 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java +++ b/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java @@ -51,7 +51,7 @@ final class NativeLibrary { // (1) The native library has already been statically loaded, OR // (2) The required native code has been statically linked (through a custom launcher), OR // (3) The native code is part of another library (such as an application-level library) - // that has already been loaded. For example, tensorflow/examples/android and + // that has already been loaded. For example, tensorflow/tools/android/test and // tensorflow/tools/android/inference_interface include the required native code in // differently named libraries. // diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index e80e32fe6cf..597f81194cd 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -1,6 +1,7 @@ load("//tensorflow:tensorflow.bzl", "if_not_windows", "tf_cc_test") -load("//tensorflow/lite:build_def.bzl", "if_tflite_experimental_runtime", "tflite_cc_shared_object", "tflite_copts", "tflite_experimental_runtime_linkopts") +load("//tensorflow/lite:build_def.bzl", "tflite_cc_shared_object", "tflite_copts") load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite") +load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable") package( default_visibility = ["//visibility:public"], @@ -15,13 +16,6 @@ exports_files(glob([ "models/testdata/*", ])) -config_setting( - name = "enable_default_profiler", - values = { - "copt": "-DTFLITE_ENABLE_DEFAULT_PROFILER", - }, -) - config_setting( name = "gemmlowp_profiling", values = { @@ -43,18 +37,6 @@ config_setting( }, ) -config_setting( - name = "tflite_experimental_runtime_eager", - values = {"define": "tflite_experimental_runtime=eager"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "tflite_experimental_runtime_non_eager", - values = {"define": "tflite_experimental_runtime=non-eager"}, - visibility = ["//visibility:public"], -) - config_setting( name = "tf_lite_static_memory", values = { @@ -90,6 +72,7 @@ FRAMEWORK_LIB_HDRS = [ cc_library( name = "version", hdrs = ["version.h"], + compatible_with = get_compatible_with_portable(), copts = TFLITE_DEFAULT_COPTS, # Note that we only use the header defines from :version_lib. 
deps = ["//tensorflow/core:version_lib"], @@ -107,6 +90,7 @@ cc_library( name = "arena_planner", srcs = ["arena_planner.cc"], hdrs = ["arena_planner.h"], + compatible_with = get_compatible_with_portable(), copts = TFLITE_DEFAULT_COPTS, deps = [ ":graph_info", @@ -137,6 +121,7 @@ cc_test( cc_library( name = "context", hdrs = ["context.h"], + compatible_with = get_compatible_with_portable(), copts = TFLITE_DEFAULT_COPTS, deps = ["//tensorflow/lite/c:common"], ) @@ -145,6 +130,7 @@ cc_library( name = "external_cpu_backend_context", srcs = ["external_cpu_backend_context.cc"], hdrs = ["external_cpu_backend_context.h"], + compatible_with = get_compatible_with_portable(), copts = TFLITE_DEFAULT_COPTS, deps = [ "//tensorflow/lite/c:common", @@ -154,6 +140,7 @@ cc_library( cc_library( name = "graph_info", hdrs = ["graph_info.h"], + compatible_with = get_compatible_with_portable(), copts = TFLITE_DEFAULT_COPTS, deps = ["//tensorflow/lite/c:common"], ) @@ -161,6 +148,7 @@ cc_library( cc_library( name = "memory_planner", hdrs = ["memory_planner.h"], + compatible_with = get_compatible_with_portable(), copts = TFLITE_DEFAULT_COPTS, deps = ["//tensorflow/lite/c:common"], ) @@ -169,6 +157,7 @@ cc_library( name = "simple_memory_arena", srcs = ["simple_memory_arena.cc"], hdrs = ["simple_memory_arena.h"], + compatible_with = get_compatible_with_portable(), copts = TFLITE_DEFAULT_COPTS, deps = ["//tensorflow/lite/c:common"], ) @@ -188,9 +177,16 @@ cc_library( "builtin_ops.h", "context_util.h", ], + compatible_with = get_compatible_with_portable(), deps = ["//tensorflow/lite/c:common"], ) +cc_library( + name = "builtin_ops", + hdrs = ["builtin_ops.h"], + compatible_with = get_compatible_with_portable(), +) + exports_files(["builtin_ops.h"]) cc_library( @@ -198,6 +194,7 @@ cc_library( hdrs = [ "string_type.h", ], + compatible_with = get_compatible_with_portable(), copts = TFLITE_DEFAULT_COPTS, ) @@ -219,6 +216,7 @@ cc_library( hdrs = [ "allocation.h", ], + compatible_with = get_compatible_with_portable(), copts = TFLITE_DEFAULT_COPTS, deps = [ ":string", @@ -235,11 +233,10 @@ cc_library( "interpreter.cc", "interpreter_builder.cc", "model_builder.cc", - "mutable_op_resolver.cc", "optional_debug_tools.cc", - "stderr_reporter.cc", ], hdrs = FRAMEWORK_LIB_HDRS, + compatible_with = get_compatible_with_portable(), copts = tflite_copts() + TFLITE_DEFAULT_COPTS, visibility = [ "//tensorflow/lite:__subpackages__", @@ -249,28 +246,29 @@ cc_library( ":arena_planner", ":external_cpu_backend_context", ":graph_info", + ":kernel_api", ":memory_planner", ":minimal_logging", + ":mutable_op_resolver", ":shared_library", ":simple_memory_arena", + ":stderr_reporter", ":string", ":type_to_tflitetype", ":util", ":version", "//tensorflow/lite/c:common", "//tensorflow/lite/core/api", + "//tensorflow/lite/core/api:verifier", "//tensorflow/lite/delegates:status", "//tensorflow/lite/delegates/nnapi:nnapi_delegate", "//tensorflow/lite/experimental/resource", "//tensorflow/lite/kernels/internal:compatibility", "//tensorflow/lite/nnapi:nnapi_implementation", + "//tensorflow/lite/profiling:platform_profiler", "//tensorflow/lite/schema:schema_fbs", - ] + select({ - ":enable_default_profiler": [ - "//tensorflow/lite/profiling:platform_profiler", - ], - "//conditions:default": [], - }), + "//tensorflow/lite/schema:schema_utils", + ], alwayslink = 1, ) @@ -280,17 +278,13 @@ cc_library( srcs = [ ], hdrs = FRAMEWORK_LIB_HDRS, + compatible_with = get_compatible_with_portable(), copts = tflite_copts() + TFLITE_DEFAULT_COPTS, - defines = 
if_tflite_experimental_runtime( - if_eager = ["TFLITE_EXPERIMENTAL_RUNTIME_EAGER"], - if_non_eager = ["TFLITE_EXPERIMENTAL_RUNTIME_NON_EAGER"], - if_none = [], - ), deps = [ - ":framework_lib", ":allocation", ":arena_planner", ":external_cpu_backend_context", + ":framework_lib", ":graph_info", ":memory_planner", ":minimal_logging", @@ -301,17 +295,79 @@ cc_library( ":version", "//tensorflow/lite/c:common", "//tensorflow/lite/core/api", + "//tensorflow/lite/core/api:verifier", "//tensorflow/lite/delegates/nnapi:nnapi_delegate", "//tensorflow/lite/experimental/resource", "//tensorflow/lite/nnapi:nnapi_implementation", "//tensorflow/lite/schema:schema_fbs", - ] + tflite_experimental_runtime_linkopts(), + ], +) + +cc_library( + name = "error_reporter", + hdrs = ["error_reporter.h"], + compatible_with = get_compatible_with_portable(), + copts = tflite_copts() + TFLITE_DEFAULT_COPTS, + visibility = [ + "//visibility:public", + ], + deps = [ + "//tensorflow/lite:stderr_reporter", + "//tensorflow/lite/core/api:error_reporter", + ], +) + +cc_library( + name = "stderr_reporter", + srcs = ["stderr_reporter.cc"], + hdrs = ["stderr_reporter.h"], + compatible_with = get_compatible_with_portable(), + copts = tflite_copts() + TFLITE_DEFAULT_COPTS, + visibility = [ + "//visibility:public", + ], + deps = [ + ":minimal_logging", + "//tensorflow/lite/c:common", + "//tensorflow/lite/core/api:error_reporter", + ], +) + +cc_library( + name = "op_resolver", + hdrs = ["op_resolver.h"], + compatible_with = get_compatible_with_portable(), + copts = tflite_copts() + TFLITE_DEFAULT_COPTS, + visibility = [ + "//visibility:public", + ], + deps = [ + "//tensorflow/lite:mutable_op_resolver", + "//tensorflow/lite/core/api:op_resolver", + ], +) + +cc_library( + name = "mutable_op_resolver", + srcs = ["mutable_op_resolver.cc"], + hdrs = ["mutable_op_resolver.h"], + compatible_with = get_compatible_with_portable(), + copts = tflite_copts() + TFLITE_DEFAULT_COPTS, + visibility = [ + "//visibility:public", + ], + deps = [ + ":util", + "//tensorflow/lite/core/api:op_resolver", + "//tensorflow/lite/schema:schema_fbs", + ], ) cc_library( name = "string_util", srcs = ["string_util.cc"], hdrs = ["string_util.h"], + compatible_with = get_compatible_with_portable(), copts = TFLITE_DEFAULT_COPTS, deps = [ ":string", @@ -356,6 +412,7 @@ cc_library( cc_library( name = "tflite_with_xnnpack_default", + compatible_with = get_compatible_with_portable(), visibility = ["//visibility:private"], # TODO(b/151246885): put ":tflite_with_xnnpack_enabled" to macos/windows # once we have a good testing coverage on these two platforms. 
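The hunks above split error_reporter, stderr_reporter, op_resolver, and mutable_op_resolver out of the monolithic framework library into publicly visible targets, so clients can depend on only the pieces they need. Below is a minimal, hedged C++ sketch of such a client; the "add.bin" model path and the choice of ADD as the only registered builtin are illustrative assumptions, not part of this change.

// Hedged sketch: a client that uses only the fine-grained targets added
// above (mutable_op_resolver, stderr_reporter) plus the model/interpreter
// libraries. "add.bin" and the single ADD registration are assumptions.
#include <memory>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/mutable_op_resolver.h"
#include "tensorflow/lite/stderr_reporter.h"

int main() {
  // Errors go through the stderr reporter that now lives in its own target.
  tflite::ErrorReporter* reporter = tflite::DefaultErrorReporter();

  // Register only the builtins the model actually uses.
  tflite::MutableOpResolver resolver;
  resolver.AddBuiltin(tflite::BuiltinOperator_ADD,
                      tflite::ops::builtin::Register_ADD());

  auto model = tflite::FlatBufferModel::BuildFromFile("add.bin", reporter);
  if (!model) return 1;

  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);
  if (!interpreter || interpreter->AllocateTensors() != kTfLiteOk) return 1;
  return interpreter->Invoke() == kTfLiteOk ? 0 : 1;
}

In Bazel terms this corresponds to depending on //tensorflow/lite:mutable_op_resolver and //tensorflow/lite:stderr_reporter rather than the full framework target, which appears to be the motivation for giving these targets public visibility.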
@@ -373,6 +430,7 @@ cc_library( "core/macros.h", "tflite_with_xnnpack_optional.h", ], + compatible_with = get_compatible_with_portable(), copts = tflite_copts() + TFLITE_DEFAULT_COPTS, deps = [ "//tensorflow/lite/c:common", @@ -478,8 +536,10 @@ cc_test( data = [ "testdata/0_subgraphs.bin", "testdata/2_subgraphs.bin", + "testdata/add_shared_tensors.bin", "testdata/empty_model.bin", "testdata/multi_add_flex.bin", + "testdata/segment_sum_invalid_buffer.bin", "testdata/sparse_tensor.bin", "testdata/test_min_runtime.bin", "testdata/test_model.bin", @@ -564,10 +624,20 @@ cc_test( ], ) +cc_test( + name = "stderr_reporter_test", + srcs = ["stderr_reporter_test.cc"], + deps = [ + ":stderr_reporter", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "util", srcs = ["util.cc"], hdrs = ["util.h"], + compatible_with = get_compatible_with_portable(), copts = TFLITE_DEFAULT_COPTS + tflite_copts(), deps = [ ":kernel_api", @@ -611,6 +681,7 @@ cc_library( ], }), hdrs = ["minimal_logging.h"], + compatible_with = get_compatible_with_portable(), copts = TFLITE_DEFAULT_COPTS + tflite_copts(), linkopts = select({ "//tensorflow:android": ["-llog"], @@ -631,6 +702,7 @@ cc_library( "type_to_tflitetype.h", ], }), + compatible_with = get_compatible_with_portable(), deps = ["//tensorflow/lite/c:common"], ) @@ -660,6 +732,7 @@ cc_test( cc_library( name = "shared_library", hdrs = ["shared_library.h"], + compatible_with = get_compatible_with_portable(), linkopts = if_not_windows(["-ldl"]), ) @@ -668,6 +741,13 @@ cc_library( hdrs = ["core/macros.h"], ) +cc_library( + name = "stateful_error_reporter", + hdrs = ["stateful_error_reporter.h"], + compatible_with = get_compatible_with_portable(), + deps = ["//tensorflow/lite/core/api"], +) + # Shared lib target for convenience, pulls in the core runtime and builtin ops. # Note: This target is not yet finalized, and the exact set of exported (C/C++) # APIs is subject to change. The output library name is platform dependent: diff --git a/tensorflow/lite/CMakeLists.txt b/tensorflow/lite/CMakeLists.txt index cfd8ebfc141..a75728e8a9d 100644 --- a/tensorflow/lite/CMakeLists.txt +++ b/tensorflow/lite/CMakeLists.txt @@ -20,8 +20,6 @@ # This has only been tested on Windows, Linux and macOS. # # The following are not currently supported: -# - GPU acceleration -# - Android # - iOS # - Micro backend # - Tests @@ -38,25 +36,35 @@ set(TENSORFLOW_SOURCE_DIR "" CACHE PATH "Directory that contains the TensorFlow project" ) if(NOT TENSORFLOW_SOURCE_DIR) - set(TENSORFLOW_SOURCE_DIR "${CMAKE_SOURCE_DIR}/../../") + get_filename_component(TENSORFLOW_SOURCE_DIR + "${CMAKE_CURRENT_LIST_DIR}/../../" + ABSOLUTE + ) endif() set(TF_SOURCE_DIR "${TENSORFLOW_SOURCE_DIR}/tensorflow") -set(TFLITE_SOURCE_DIR "${CMAKE_SOURCE_DIR}") -set(CMAKE_MODULE_PATH "${TFLITE_SOURCE_DIR}/tools/cmake/modules" ${CMAKE_MODULE_PATH}) -set(CMAKE_PREFIX_PATH "${TFLITE_SOURCE_DIR}/tools/cmake/modules" ${CMAKE_PREFIX_PATH}) +set(TFLITE_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}") +set(CMAKE_MODULE_PATH + "${TFLITE_SOURCE_DIR}/tools/cmake/modules" + ${CMAKE_MODULE_PATH} +) +set(CMAKE_PREFIX_PATH + "${TFLITE_SOURCE_DIR}/tools/cmake/modules" + ${CMAKE_PREFIX_PATH} +) +# b/168750039: To workaround absl module not found error on Android build. +set(absl_DIR ${CMAKE_MODULE_PATH}) option(TFLITE_ENABLE_RUY "Enable experimental RUY integration" OFF) option(TFLITE_ENABLE_RESOURCE "Enable experimental support for resources" ON) option(TFLITE_ENABLE_NNAPI "Enable NNAPI (Android only)." 
ON) option(TFLITE_ENABLE_MMAP "Enable MMAP (unsupported on Windows)" ON) -option(TFLITE_ENABLE_GPU "Enable GPU (not supported)" OFF) +option(TFLITE_ENABLE_GPU "Enable GPU" OFF) # This must be enabled when converting from TF models with SELECT_TF_OPS # enabled. # https://www.tensorflow.org/lite/guide/ops_select#converting_the_model # This is currently not supported. option(TFLITE_ENABLE_FLEX "Enable SELECT_TF_OPS" OFF) # TODO: Add support -option(TFLITE_ENABLE_XNNPACK "Enable XNNPACK backend" OFF) # TODO: Add XNNPACK -option(TFLITE_ENABLE_PROFILING "Enable profiling" OFF) +option(TFLITE_ENABLE_XNNPACK "Enable XNNPACK backend" ON) set(CMAKE_CXX_STANDARD 14) # Some components require C++14. set(CMAKE_CXX_STANDARD_REQUIRED ON) set(_TFLITE_ENABLE_NNAPI "${TFLITE_ENABLE_NNAPI}") @@ -120,40 +128,9 @@ find_package(gemmlowp REQUIRED) find_package(neon2sse REQUIRED) find_package(ruy REQUIRED) # Generate TensorFlow Lite FlatBuffer code. -# This is not currently neccessary since the generated code is checked into -# the repository but it would likely be preferable to do this in future. -# NOTE: This will not work for cross compilation (e.g for iOS, Android etc.) -# as flatc needs to be compiled with the host toolchain and this currently -# builds with the target toolchain. Instead this should recursively call -# cmake with the default host toolchain to build flatc. -set(TFLITE_FLATBUFFERS_SCHEMAS "${TFLITE_SOURCE_DIR}/schema/schema.fbs") -set(TFLITE_FLATBUFFERS_GEN_DIR - "${CMAKE_BINARY_DIR}/flatbuffers_generated/" -) -set(TFLITE_FLATBUFFERS_HDRS "") -foreach(INPUT_SCHEMA ${TFLITE_FLATBUFFERS_SCHEMAS}) - file(RELATIVE_PATH FILENAME "${TENSORFLOW_SOURCE_DIR}" "${INPUT_SCHEMA}") - get_filename_component(OUTPUT_DIR - "${TFLITE_FLATBUFFERS_GEN_DIR}/${FILENAME}" DIRECTORY - ) - get_filename_component(OUTPUT_BASENAME - "${FILENAME}" NAME_WE - ) - set(OUTPUT_FILENAME "${OUTPUT_DIR}/${OUTPUT_BASENAME}_generated.h") - list(APPEND TFLITE_FLATBUFFERS_HDRS "${OUTPUT_FILENAME}") - add_custom_command( - OUTPUT "${OUTPUT_FILENAME}" - COMMAND flatc - --cpp - --gen-mutable - --gen-object-api - --reflect-names - -I "${TENSORFLOW_SOURCE_DIR}" - -o "${OUTPUT_DIR}" - "${INPUT_SCHEMA}" - DEPENDS - "${INPUT_SCHEMA}") -endforeach() +# We used to have an actual compilation logic with flatc but decided to use +# schema_generated.h since flatc doesn't work with cross compilation. +set(TFLITE_FLATBUFFERS_SCHEMA_DIR "${TFLITE_SOURCE_DIR}/schema") set(TF_TARGET_PRIVATE_OPTIONS "") if(CMAKE_CXX_COMPILER_ID MATCHES "Clang$") # TensorFlow uses a heap of deprecated proto fields so surpress these @@ -175,6 +152,16 @@ if(CMAKE_SYSTEM_NAME MATCHES "Windows") # use of std::min std::max. # Use NOGDI to ERROR macro which breaks TensorFlow logging. list(APPEND TFLITE_TARGET_PRIVATE_OPTIONS "-DNOMINMAX" "-DNOGDI") + # lite/kernels/conv.cc has more than 64k sections so enable /bigobj to + # support compilation with MSVC2015. + if(MSVC) + list(APPEND TFLITE_TARGET_PRIVATE_OPTIONS "/bigobj") + elseif(CMAKE_COMPILER_IS_GNUCXX) + list(APPEND TFLITE_TARGET_PRIVATE_OPTIONS "-Wa,-mbig-obj") + endif() +endif() +if(CMAKE_SYSTEM_NAME MATCHES "Android") + find_library(ANDROID_LOG_LIB log) endif() # Build a list of source files to compile into the TF Lite library. populate_tflite_source_vars("." TFLITE_SRCS) @@ -203,9 +190,60 @@ if(TFLITE_ENABLE_FLEX) ) endif() if(TFLITE_ENABLE_GPU) - # Implementation is under delegates/gpu. 
- message(FATAL_ERROR - "GPU acceleration is not currently supported in CMake builds" + find_package(opencl_headers REQUIRED) + find_package(vulkan_headers REQUIRED) + populate_tflite_source_vars( + "delegates/gpu/cl" TFLITE_DELEGATES_GPU_CL_SRCS + FILTER "(_test|gl_interop|egl_sync)\\.(cc|h)$" + ) + populate_tflite_source_vars( + "delegates/gpu/cl/kernels" TFLITE_DELEGATES_GPU_CL_KERNELS_SRCS + FILTER "(_test)\\.(cc|h)$" + ) + populate_tflite_source_vars( + "delegates/gpu/cl/kernels/special" + TFLITE_DELEGATES_GPU_CL_KERNELS_SPECIAL_SRCS + FILTER "(_test)\\.(cc|h)$" + ) + populate_tflite_source_vars( + "delegates/gpu/cl/selectors" TFLITE_DELEGATES_GPU_CL_SELECTORS_SRCS + FILTER "(_test)\\.(cc|h)$" + ) + populate_tflite_source_vars( + "delegates/gpu/common" TFLITE_DELEGATES_GPU_COMMON_SRCS + FILTER "(_test)\\.(cc|h)$" + ) + populate_tflite_source_vars( + "delegates/gpu/common/default" TFLITE_DELEGATES_GPU_COMMON_DEFAULT_SRCS + FILTER "(_test)\\.(cc|h)$" + ) + populate_tflite_source_vars( + "delegates/gpu/common/memory_management" + TFLITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_SRCS + FILTER "(_test)\\.(cc|h)$" + ) + populate_tflite_source_vars( + "delegates/gpu/common/transformations" + TFLITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_SRCS + FILTER "(_test)\\.(cc|h)$" + ) + list(APPEND TFLITE_DELEGATES_GPU_SRCS + ${TFLITE_SOURCE_DIR}/delegates/gpu/api.cc + ${TFLITE_SOURCE_DIR}/delegates/gpu/delegate.cc + ${TFLITE_DELEGATES_GPU_CL_SRCS} + ${TFLITE_DELEGATES_GPU_CL_KERNELS_SRCS} + ${TFLITE_DELEGATES_GPU_CL_KERNELS_SPECIAL_SRCS} + ${TFLITE_DELEGATES_GPU_CL_SELECTORS_SRCS} + ${TFLITE_SOURCE_DIR}/delegates/gpu/cl/selectors/default/default_selector.cc + ${TFLITE_DELEGATES_GPU_COMMON_SRCS} + ${TFLITE_DELEGATES_GPU_COMMON_DEFAULT_SRCS} + ${TFLITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_SRCS} + ${TFLITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_SRCS} + ) + list(APPEND TFLITE_TARGET_PUBLIC_OPTIONS "-DCL_DELEGATE_NO_GL" "-DEGL_NO_X11") + list(APPEND TFLITE_TARGET_DEPENDENCIES + absl::any + absl::flat_hash_map ) endif() if(_TFLITE_ENABLE_NNAPI) @@ -225,8 +263,13 @@ else() ) endif() if(TFLITE_ENABLE_XNNPACK) + find_package(xnnpack REQUIRED) populate_tflite_source_vars("delegates/xnnpack" TFLITE_DELEGATES_XNNPACK_SRCS + FILTER ".*(_test|_tester)\\.(cc|h)" + ) + list(APPEND TFLITE_TARGET_DEPENDENCIES + XNNPACK ) endif() if (TFLITE_ENABLE_RESOURCE) @@ -270,31 +313,30 @@ populate_tflite_source_vars("kernels/internal/reference/integer_ops" populate_tflite_source_vars("kernels/internal/reference/sparse_ops" TFLITE_KERNEL_INTERNAL_REF_SPARSE_OPS_SRCS ) -if(TFLITE_ENABLE_PROFILING) - populate_tflite_source_vars("profiling" TFLITE_KERNEL_PROFILING_SRCS) -endif() -populate_tflite_source_vars("tools/optimize" TFLITE_TOOLS_OPTIMIZE_SRCS) -populate_tflite_source_vars("tools/optimize/calibration" - TFLITE_TOOLS_OPTIMIZE_CALIBRATION_SRCS + +# Common include directories +set(TFLITE_INCLUDE_DIRS + "${TENSORFLOW_SOURCE_DIR}" + "${TFLITE_FLATBUFFERS_SCHEMA_DIR}" ) -populate_tflite_source_vars("tools/optimize/calibration/builtin_logging_ops" - TFLITE_TOOLS_OPTIMIZE_CALIBRATION_OPS_SRCS +include_directories( + BEFORE + ${TFLITE_INCLUDE_DIRS} ) -populate_tflite_source_vars("tools/optimize/sparsity" - TFLITE_TOOLS_OPTIMIZE_SPARSITY_SRCS -) -add_library(tensorflowlite + +# TFLite library +add_library(tensorflow-lite ${TFLITE_CORE_API_SRCS} ${TFLITE_CORE_SRCS} ${TFLITE_C_SRCS} ${TFLITE_DELEGATES_FLEX_SRCS} + ${TFLITE_DELEGATES_GPU_SRCS} ${TFLITE_DELEGATES_NNAPI_SRCS} ${TFLITE_DELEGATES_SRCS} ${TFLITE_DELEGATES_XNNPACK_SRCS} 
${TFLITE_EXPERIMENTAL_RESOURCE_SRCS} ${TFLITE_EXPERIMENTAL_RUY_PROFILER_SRCS} ${TFLITE_EXPERIMENTAL_RUY_SRCS} - ${TFLITE_FLATBUFFERS_HDRS} ${TFLITE_KERNEL_INTERNAL_OPT_INTEGER_OPS_SRCS} ${TFLITE_KERNEL_INTERNAL_OPT_SPARSE_OPS_SRCS} ${TFLITE_KERNEL_INTERNAL_OPT_SRCS} @@ -302,16 +344,18 @@ add_library(tensorflowlite ${TFLITE_KERNEL_INTERNAL_REF_SPARSE_OPS_SRCS} ${TFLITE_KERNEL_INTERNAL_REF_SRCS} ${TFLITE_KERNEL_INTERNAL_SRCS} - ${TFLITE_KERNEL_PROFILING_SRCS} ${TFLITE_KERNEL_SRCS} ${TFLITE_NNAPI_SRCS} ${TFLITE_SRCS} - ${TFLITE_TOOLS_OPTIMIZE_CALIBRATION_OPS_SRCS} - ${TFLITE_TOOLS_OPTIMIZE_CALIBRATION_SRCS} - ${TFLITE_TOOLS_OPTIMIZE_SPARSITY_SRCS} - ${TFLITE_TOOLS_OPTIMIZE_SRCS} + ${TFLITE_SOURCE_DIR}/profiling/platform_profiler.cc + ${TFLITE_SOURCE_DIR}/schema/schema_utils.cc + ${TFLITE_SOURCE_DIR}/tools/optimize/sparsity/format_converter.cc ) -target_link_libraries(tensorflowlite +target_include_directories(tensorflow-lite + PUBLIC + ${TFLITE_INCLUDE_DIRS} +) +target_link_libraries(tensorflow-lite PUBLIC Eigen3::Eigen NEON_2_SSE @@ -328,14 +372,80 @@ target_link_libraries(tensorflowlite ruy ${TFLITE_TARGET_DEPENDENCIES} ) -target_include_directories(tensorflowlite - PUBLIC - "${TENSORFLOW_SOURCE_DIR}" - PRIVATE - "${TFLITE_FLATBUFFERS_GEN_DIR}" -) -target_compile_options(tensorflowlite +target_compile_options(tensorflow-lite PUBLIC ${TFLITE_TARGET_PUBLIC_OPTIONS} PRIVATE ${TFLITE_TARGET_PRIVATE_OPTIONS} ) -add_library(tensorflow::tensorflowlite ALIAS tensorflowlite) +add_library(tensorflow::tensorflowlite ALIAS tensorflow-lite) + +# Benchmark Tool +populate_source_vars("${TFLITE_SOURCE_DIR}/tools/benchmark" + TFLITE_BENCHMARK_SRCS + FILTER "(_test|_plus_flex_main|_performance_options.*)\\.cc$" +) +list(APPEND TFLITE_BENCHMARK_SRCS + ${TF_SOURCE_DIR}/core/util/stats_calculator.cc + ${TFLITE_SOURCE_DIR}/profiling/memory_info.cc + ${TFLITE_SOURCE_DIR}/profiling/platform_profiler.cc + ${TFLITE_SOURCE_DIR}/profiling/profile_summarizer.cc + ${TFLITE_SOURCE_DIR}/profiling/profile_summary_formatter.cc + ${TFLITE_SOURCE_DIR}/profiling/time.cc + ${TFLITE_SOURCE_DIR}/tools/command_line_flags.cc + ${TFLITE_SOURCE_DIR}/tools/delegates/default_execution_provider.cc + ${TFLITE_SOURCE_DIR}/tools/evaluation/utils.cc + ${TFLITE_SOURCE_DIR}/tools/optimize/sparsity/format_converter.cc + ${TFLITE_SOURCE_DIR}/tools/tool_params.cc +) + +list(APPEND TFLITE_BENCHMARK_LIBS + tensorflow-lite + ${CMAKE_DL_LIBS} +) + +# TODO(b/171007016): Enable performance options on Windows. 
+if(NOT "${CMAKE_SYSTEM_NAME}" STREQUAL "Windows") + list(APPEND TFLITE_BENCHMARK_SRCS + ${TFLITE_SOURCE_DIR}/tools/benchmark/benchmark_performance_options.cc + ) +endif() + +if(TFLITE_ENABLE_XNNPACK) + list(APPEND TFLITE_BENCHMARK_SRCS + ${TFLITE_SOURCE_DIR}/tools/delegates/xnnpack_delegate_provider.cc + ) +else() + set(TFLITE_BENCHMARK_CC_OPTIONS "-DTFLITE_WITHOUT_XNNPACK") +endif() # TFLITE_ENABLE_XNNPACK + +if(CMAKE_SYSTEM_NAME MATCHES "Android") + list(APPEND TFLITE_BENCHMARK_SRCS + ${TFLITE_SOURCE_DIR}/profiling/atrace_profiler.cc + ) + if(_TFLITE_ENABLE_NNAPI) + list(APPEND TFLITE_BENCHMARK_SRCS + ${TFLITE_SOURCE_DIR}/tools/delegates/nnapi_delegate_provider.cc + ) + endif() # _TFLITE_ENABLE_NNAPI + list(APPEND TFLITE_BENCHMARK_LIBS + ${ANDROID_LOG_LIB} + absl::strings + ) +endif() # Android + +if(TFLITE_ENABLE_GPU) + list(APPEND TFLITE_BENCHMARK_SRCS + ${TFLITE_SOURCE_DIR}/tools/delegates/gpu_delegate_provider.cc + ) +endif() # TFLITE_ENABLE_GPU + +add_executable(benchmark_model + EXCLUDE_FROM_ALL + ${TFLITE_BENCHMARK_SRCS} +) +target_compile_options(benchmark_model + PRIVATE + ${TFLITE_BENCHMARK_CC_OPTIONS} +) +target_link_libraries(benchmark_model + ${TFLITE_BENCHMARK_LIBS} +) diff --git a/tensorflow/lite/arena_planner.cc b/tensorflow/lite/arena_planner.cc index dd5e3777fc1..b134a5de044 100644 --- a/tensorflow/lite/arena_planner.cc +++ b/tensorflow/lite/arena_planner.cc @@ -140,7 +140,7 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { } // Count references to node input tensors. - for (size_t i = 0; i < graph_info_->num_nodes(); ++i) { + for (size_t i = 0; i < graph_info_->num_execution_nodes(); ++i) { const TfLiteNode& node = graph_info_->node(i); TfLiteIntArray* node_inputs = node.inputs; for (int j = 0; j < node_inputs->size; ++j) { @@ -158,7 +158,7 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { } } // Go through the graph in execution order. - for (size_t i = 0; i < graph_info_->num_nodes(); ++i) { + for (size_t i = 0; i < graph_info_->num_execution_nodes(); ++i) { const TfLiteNode& node = graph_info_->node(i); // First queue output tensors for allocation. @@ -197,8 +197,8 @@ TfLiteStatus ArenaPlanner::ExecuteAllocations(int first_node, int last_node) { dealloc_node_.resize(graph_info_->num_tensors(), kNodeNotAssigned); allocs_.resize(graph_info_->num_tensors()); // Set allocation and deallocation for temporary tensors. 
- for (size_t i = first_node; - i <= static_cast(last_node) && i < graph_info_->num_nodes(); + for (size_t i = first_node; i <= static_cast(last_node) && + i < graph_info_->num_execution_nodes(); ++i) { const TfLiteNode& node = graph_info_->node(i); TfLiteIntArray* node_temporaries = node.temporaries; diff --git a/tensorflow/lite/arena_planner_test.cc b/tensorflow/lite/arena_planner_test.cc index 813e10082a3..47ecc68cf40 100644 --- a/tensorflow/lite/arena_planner_test.cc +++ b/tensorflow/lite/arena_planner_test.cc @@ -134,7 +134,8 @@ class TestGraphInfo : public GraphInfo { TfLiteTensor* tensor(size_t index) override { return &graph_->tensors()->at(index); } - size_t num_nodes() const override { return graph_->nodes().size(); } + size_t num_execution_nodes() const override { return graph_->nodes().size(); } + size_t num_total_nodes() const override { return graph_->nodes().size(); } const TfLiteNode& node(size_t index) const override { return graph_->nodes()[index]; } diff --git a/tensorflow/lite/build_def.bzl b/tensorflow/lite/build_def.bzl index bdddac82d5b..0f731c43577 100644 --- a/tensorflow/lite/build_def.bzl +++ b/tensorflow/lite/build_def.bzl @@ -169,6 +169,33 @@ def tflite_cc_shared_object( def tf_to_tflite(name, src, options, out): """Convert a frozen tensorflow graphdef to TF Lite's flatbuffer. + Args: + name: Name of rule. + src: name of the input graphdef file. + options: options passed to TFLite Converter. + out: name of the output flatbuffer file. + """ + + toco_cmdline = " ".join([ + "$(location //tensorflow/lite/python:tflite_convert)", + "--experimental_new_converter", + ("--graph_def_file=$(location %s)" % src), + ("--output_file=$(location %s)" % out), + ] + options) + native.genrule( + name = name, + srcs = [src], + outs = [out], + cmd = toco_cmdline, + tools = ["//tensorflow/lite/python:tflite_convert"] + tf_binary_additional_srcs(), + ) + +def DEPRECATED_tf_to_tflite(name, src, options, out): + """DEPRECATED Convert a frozen tensorflow graphdef to TF Lite's flatbuffer, using toco. + + Please use tf_to_tflite instead. + TODO(b/138396996): Migrate away from this deprecated rule. + Args: name: Name of rule. src: name of the input graphdef file. @@ -742,27 +769,6 @@ def gen_model_coverage_test(src, model_name, data, failure_type, tags, size = "m ] + flex_dep(target_op_sets), ) -def if_tflite_experimental_runtime(if_eager, if_non_eager, if_none = []): - return select({ - "//tensorflow/lite:tflite_experimental_runtime_eager": if_eager, - "//tensorflow/lite:tflite_experimental_runtime_non_eager": if_non_eager, - "//conditions:default": if_none, - }) - -def tflite_experimental_runtime_linkopts(if_eager = [], if_non_eager = [], if_none = []): - return if_tflite_experimental_runtime( - if_eager = [ - # "//tensorflow/lite/experimental/tf_runtime:eager_interpreter", - # "//tensorflow/lite/experimental/tf_runtime:eager_model", - # "//tensorflow/lite/experimental/tf_runtime:subgraph", - ] + if_eager, - if_non_eager = [ - # "//tensorflow/lite/experimental/tf_runtime:interpreter", - # "//tensorflow/lite/experimental/tf_runtime:model", - ] + if_non_eager, - if_none = [] + if_none, - ) - def tflite_custom_cc_library( name, models = [], diff --git a/tensorflow/lite/builtin_ops.h b/tensorflow/lite/builtin_ops.h index 85140289ac1..a37607f6260 100644 --- a/tensorflow/lite/builtin_ops.h +++ b/tensorflow/lite/builtin_ops.h @@ -24,7 +24,8 @@ extern "C" { #endif // __cplusplus // The enum for builtin operators. -// Note: CUSTOM and DELEGATE are 2 special ops which are not real built-in ops. 
+// Note: CUSTOM, DELEGATE, and PLACEHOLDER_FOR_GREATER_OP_CODES are 3 special +// ops which are not real built-in ops. typedef enum { kTfLiteBuiltinAdd = 0, kTfLiteBuiltinAveragePool2d = 1, @@ -153,6 +154,7 @@ typedef enum { kTfLiteBuiltinDensify = 124, kTfLiteBuiltinSegmentSum = 125, kTfLiteBuiltinBatchMatmul = 126, + kTfLiteBuiltinPlaceholderForGreaterOpCodes = 127, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/lite/c/BUILD b/tensorflow/lite/c/BUILD index 5ac6d7881ac..e8db0dcf440 100644 --- a/tensorflow/lite/c/BUILD +++ b/tensorflow/lite/c/BUILD @@ -3,6 +3,7 @@ load( "tflite_cc_shared_object", "tflite_copts", ) +load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable") package( default_visibility = ["//visibility:public"], @@ -46,6 +47,7 @@ cc_library( visibility = ["//visibility:private"], deps = [ ":common", + "//tensorflow/lite:builtin_ops", "//tensorflow/lite:framework", "//tensorflow/lite/core/api", ], @@ -59,11 +61,12 @@ cc_library( deps = [ ":c_api_internal", ":common", + "//tensorflow/lite:builtin_ops", "//tensorflow/lite:framework", "//tensorflow/lite:version", - "//tensorflow/lite/core/api", "//tensorflow/lite/delegates/nnapi:nnapi_delegate", "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/kernels/internal:compatibility", ], alwayslink = 1, ) @@ -123,6 +126,8 @@ cc_library( "builtin_op_data.h", "common.h", ], + compatible_with = get_compatible_with_portable(), + deps = ["//tensorflow/lite:builtin_ops"], alwayslink = 1, ) diff --git a/tensorflow/lite/c/c_api.cc b/tensorflow/lite/c/c_api.cc index 4afd413ba9c..205c665d08b 100644 --- a/tensorflow/lite/c/c_api.cc +++ b/tensorflow/lite/c/c_api.cc @@ -16,10 +16,12 @@ limitations under the License. #include +#include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" #include "tensorflow/lite/error_reporter.h" #include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/version.h" @@ -31,21 +33,55 @@ extern "C" { namespace { class CallbackErrorReporter : public tflite::ErrorReporter { public: - using ErrorCallback = void (*)(void* user_data, const char* format, - va_list args); - - CallbackErrorReporter(ErrorCallback callback, void* user_data) - : callback_(callback), user_data_(user_data) {} + explicit CallbackErrorReporter(TfLiteErrorReporterCallback callback) + : callback_(callback) {} int Report(const char* format, va_list args) override { - callback_(user_data_, format, args); + callback_.error_reporter(callback_.user_data, format, args); return 0; } private: - ErrorCallback callback_; - void* user_data_; + TfLiteErrorReporterCallback callback_; }; + +/// `CallbackOpResolver` is a (C++) `tflite::OpResolver` that forwards the +/// methods to (C ABI) callback functions from a `TfLiteOpResolverCallbacks` +/// struct. +/// +/// The SetCallbacks method must be called before calling any of the FindOp +/// methods. 
+class CallbackOpResolver : public ::tflite::OpResolver { + public: + CallbackOpResolver() {} + void SetCallbacks( + const struct TfLiteOpResolverCallbacks& op_resolver_callbacks) { + op_resolver_callbacks_ = op_resolver_callbacks; + } + const TfLiteRegistration* FindOp(tflite::BuiltinOperator op, + int version) const override { + if (op_resolver_callbacks_.find_builtin_op == nullptr) { + return nullptr; + } + return op_resolver_callbacks_.find_builtin_op( + op_resolver_callbacks_.user_data, + static_cast(op), version); + } + const TfLiteRegistration* FindOp(const char* op, int version) const override { + if (op_resolver_callbacks_.find_custom_op == nullptr) { + return nullptr; + } + return op_resolver_callbacks_.find_custom_op( + op_resolver_callbacks_.user_data, op, version); + } + + private: + CallbackOpResolver(const CallbackOpResolver&) = delete; + CallbackOpResolver& operator=(const CallbackOpResolver&) = delete; + + struct TfLiteOpResolverCallbacks op_resolver_callbacks_ = {}; +}; + } // namespace // LINT.IfChange @@ -89,62 +125,16 @@ void TfLiteInterpreterOptionsSetErrorReporter( TfLiteInterpreterOptions* options, void (*reporter)(void* user_data, const char* format, va_list args), void* user_data) { - options->error_reporter = reporter; - options->error_reporter_user_data = user_data; + options->error_reporter_callback.error_reporter = reporter; + options->error_reporter_callback.user_data = user_data; } TfLiteInterpreter* TfLiteInterpreterCreate( const TfLiteModel* model, const TfLiteInterpreterOptions* optional_options) { - if (!model || !model->impl) { - return nullptr; - } - - std::unique_ptr optional_error_reporter; - if (optional_options && optional_options->error_reporter != nullptr) { - optional_error_reporter.reset( - new CallbackErrorReporter(optional_options->error_reporter, - optional_options->error_reporter_user_data)); - } - - // TODO(b/111881878): Allow use of C API without pulling in all builtin ops. tflite::ops::builtin::BuiltinOpResolver resolver; - if (optional_options) { - resolver.AddAll(optional_options->op_resolver); - } - tflite::ErrorReporter* error_reporter = optional_error_reporter - ? 
optional_error_reporter.get() - : tflite::DefaultErrorReporter(); - tflite::InterpreterBuilder builder(model->impl->GetModel(), resolver, - error_reporter); - - std::unique_ptr interpreter; - if (builder(&interpreter) != kTfLiteOk) { - return nullptr; - } - - if (optional_options) { - if (optional_options->num_threads != - TfLiteInterpreterOptions::kDefaultNumThreads) { - interpreter->SetNumThreads(optional_options->num_threads); - } - - if (optional_options->use_nnapi) { - if (interpreter->ModifyGraphWithDelegate(tflite::NnApiDelegate()) != - kTfLiteOk) { - return nullptr; - } - } - - for (auto* delegate : optional_options->delegates) { - if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) { - return nullptr; - } - } - } - - return new TfLiteInterpreter{model->impl, std::move(optional_error_reporter), - std::move(interpreter)}; + return tflite::internal::InterpreterCreateWithOpResolver( + model, optional_options, &resolver); } void TfLiteInterpreterDelete(TfLiteInterpreter* interpreter) { @@ -240,3 +230,77 @@ TfLiteStatus TfLiteTensorCopyToBuffer(const TfLiteTensor* tensor, #ifdef __cplusplus } // extern "C" #endif // __cplusplus + +namespace tflite { +namespace internal { + +TfLiteInterpreter* InterpreterCreateWithOpResolver( + const TfLiteModel* model, const TfLiteInterpreterOptions* optional_options, + tflite::MutableOpResolver* mutable_resolver) { + TFLITE_DCHECK_NE(mutable_resolver, nullptr); + if (!model || !model->impl) { + return nullptr; + } + + std::unique_ptr optional_error_reporter; + if (optional_options && + optional_options->error_reporter_callback.error_reporter != nullptr) { + optional_error_reporter.reset( + new CallbackErrorReporter(optional_options->error_reporter_callback)); + } + + // By default, we use the provided mutable_op_resolver, adding any builtin or + // custom ops registered with `TfLiteInterpreterOptionsAddBuiltinOp` and/or + // `TfLiteInterpreterOptionsAddCustomOp`. + tflite::OpResolver* op_resolver = mutable_resolver; + if (optional_options) { + mutable_resolver->AddAll(optional_options->mutable_op_resolver); + } + // However, if `TfLiteInterpreterOptionsSetOpResolver` has been called with + // a non-null callback parameter, then we instead use a + // `CallbackOpResolver` that will forward to the callbacks provided there. + CallbackOpResolver callback_op_resolver; + if (optional_options && + (optional_options->op_resolver_callbacks.find_builtin_op != nullptr || + optional_options->op_resolver_callbacks.find_custom_op != nullptr)) { + callback_op_resolver.SetCallbacks(optional_options->op_resolver_callbacks); + op_resolver = &callback_op_resolver; + } + + tflite::ErrorReporter* error_reporter = optional_error_reporter + ? 
optional_error_reporter.get() + : tflite::DefaultErrorReporter(); + tflite::InterpreterBuilder builder(model->impl->GetModel(), *op_resolver, + error_reporter); + + std::unique_ptr interpreter; + if (builder(&interpreter) != kTfLiteOk) { + return nullptr; + } + + if (optional_options) { + if (optional_options->num_threads != + TfLiteInterpreterOptions::kDefaultNumThreads) { + interpreter->SetNumThreads(optional_options->num_threads); + } + + if (optional_options->use_nnapi) { + if (interpreter->ModifyGraphWithDelegate(tflite::NnApiDelegate()) != + kTfLiteOk) { + return nullptr; + } + } + + for (auto* delegate : optional_options->delegates) { + if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) { + return nullptr; + } + } + } + + return new TfLiteInterpreter{model->impl, std::move(optional_error_reporter), + std::move(interpreter)}; +} + +} // namespace internal +} // namespace tflite diff --git a/tensorflow/lite/c/c_api.h b/tensorflow/lite/c/c_api.h index 880b80e69b4..152bcf986fe 100644 --- a/tensorflow/lite/c/c_api.h +++ b/tensorflow/lite/c/c_api.h @@ -188,7 +188,7 @@ TFL_CAPI_EXPORT extern int32_t TfLiteInterpreterGetOutputTensorCount( const TfLiteInterpreter* interpreter); // Returns the tensor associated with the output index. -// REQUIRES: 0 <= input_index < TfLiteInterpreterGetOutputTensorCount(tensor) +// REQUIRES: 0 <= output_index < TfLiteInterpreterGetOutputTensorCount(tensor) // // NOTE: The shape and underlying data buffer for output tensors may be not // be available until after the output tensor has been both sized and allocated. diff --git a/tensorflow/lite/c/c_api_experimental.cc b/tensorflow/lite/c/c_api_experimental.cc index cff1b3d1530..23a5ca7a275 100644 --- a/tensorflow/lite/c/c_api_experimental.cc +++ b/tensorflow/lite/c/c_api_experimental.cc @@ -23,7 +23,6 @@ limitations under the License. 
#include "tensorflow/lite/c/c_api.h" #include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/mutable_op_resolver.h" #ifdef __cplusplus extern "C" { @@ -38,8 +37,17 @@ void TfLiteInterpreterOptionsAddBuiltinOp( TfLiteInterpreterOptions* options, TfLiteBuiltinOperator op, const TfLiteRegistration* registration, int32_t min_version, int32_t max_version) { - options->op_resolver.AddBuiltin(static_cast(op), - registration, min_version, max_version); + options->mutable_op_resolver.AddBuiltin( + static_cast(op), registration, min_version, + max_version); +} + +TfLiteInterpreter* TfLiteInterpreterCreateWithSelectedOps( + const TfLiteModel* model, + const TfLiteInterpreterOptions* optional_options) { + tflite::MutableOpResolver resolver; + return tflite::internal::InterpreterCreateWithOpResolver( + model, optional_options, &resolver); } void TfLiteInterpreterOptionsAddCustomOp(TfLiteInterpreterOptions* options, @@ -47,7 +55,21 @@ void TfLiteInterpreterOptionsAddCustomOp(TfLiteInterpreterOptions* options, const TfLiteRegistration* registration, int32_t min_version, int32_t max_version) { - options->op_resolver.AddCustom(name, registration, min_version, max_version); + options->mutable_op_resolver.AddCustom(name, registration, min_version, + max_version); +} + +void TfLiteInterpreterOptionsSetOpResolver( + TfLiteInterpreterOptions* options, + const TfLiteRegistration* (*find_builtin_op)(void* user_data, + TfLiteBuiltinOperator op, + int version), + const TfLiteRegistration* (*find_custom_op)(void* user_data, const char* op, + int version), + void* op_resolver_user_data) { + options->op_resolver_callbacks.find_builtin_op = find_builtin_op; + options->op_resolver_callbacks.find_custom_op = find_custom_op; + options->op_resolver_callbacks.user_data = op_resolver_user_data; } void TfLiteInterpreterOptionsSetUseNNAPI(TfLiteInterpreterOptions* options, diff --git a/tensorflow/lite/c/c_api_experimental.h b/tensorflow/lite/c/c_api_experimental.h index 0398c385874..bfbdd9c8fdd 100644 --- a/tensorflow/lite/c/c_api_experimental.h +++ b/tensorflow/lite/c/c_api_experimental.h @@ -23,33 +23,99 @@ limitations under the License. extern "C" { #endif // __cplusplus -// Resets all variable tensors to zero. +/// Resets all variable tensors to zero. +/// +/// WARNING: This is an experimental API and subject to change. TFL_CAPI_EXPORT extern TfLiteStatus TfLiteInterpreterResetVariableTensors( TfLiteInterpreter* interpreter); -// Adds an op registration for a builtin operator. -// -// NOTE: The interpreter will make a copy of `registration` internally, so the -// caller should ensure that its contents (function pointers, etc...) remain -// valid for the duration of the interpreter's lifetime. A common practice is -// making the provided TfLiteRegistration instance static. +/// Adds an op registration for a builtin operator. +/// +/// Op registrations are used to map ops referenced in the flatbuffer model +/// to executable function pointers (`TfLiteRegistration`s). +/// +/// NOTE: The interpreter will make a shallow copy of `registration` internally, +/// so the caller should ensure that its contents (function pointers, etc...) +/// remain valid for the duration of the interpreter's lifetime. A common +/// practice is making the provided `TfLiteRegistration` instance static. +/// +/// Code that uses this function should NOT call +/// `TfLiteInterpreterOptionsSetOpResolver' on the same options object. 
+/// +/// WARNING: This is an experimental API and subject to change. TFL_CAPI_EXPORT void TfLiteInterpreterOptionsAddBuiltinOp( TfLiteInterpreterOptions* options, TfLiteBuiltinOperator op, const TfLiteRegistration* registration, int32_t min_version, int32_t max_version); -// Adds an op registration for a custom operator. -// -// NOTE: The interpreter will make a copy of `registration` internally, so the -// caller should ensure that its contents (function pointers, etc...) remain -// valid for the duration of any created interpreter's lifetime. A common -// practice is making the provided TfLiteRegistration instance static. +/// Adds an op registration for a custom operator. +/// +/// Op registrations are used to map ops referenced in the flatbuffer model +/// to executable function pointers (`TfLiteRegistration`s). +/// +/// NOTE: The interpreter will make a shallow copy of `registration` internally, +/// so the caller should ensure that its contents (function pointers, etc...) +/// remain valid for the duration of any created interpreter's lifetime. A +/// common practice is making the provided `TfLiteRegistration` instance static. +/// +/// Code that uses this function should NOT call +/// `TfLiteInterpreterOptionsSetOpResolver' on the same options object. +/// +/// WARNING: This is an experimental API and subject to change. TFL_CAPI_EXPORT void TfLiteInterpreterOptionsAddCustomOp( TfLiteInterpreterOptions* options, const char* name, const TfLiteRegistration* registration, int32_t min_version, int32_t max_version); -// Enable or disable the NN API for the interpreter (true to enable). +/// Registers callbacks for resolving builtin or custom operators. +/// +/// The `TfLiteInterpreterOptionsSetOpResolver` function provides an alternative +/// method for registering builtin ops and/or custom ops, by providing operator +/// resolver callbacks. Unlike using `TfLiteInterpreterOptionsAddBuiltinOp` +/// and/or `TfLiteInterpreterOptionsAddAddCustomOp`, these let you register all +/// the operators in a single call. +/// +/// Code that uses this function should NOT call +/// `TfLiteInterpreterOptionsAddBuiltin' or +/// `TfLiteInterpreterOptionsAddCustomOp' on the same options object. +/// +/// WARNING: This is an experimental API and subject to change. +void TfLiteInterpreterOptionsSetOpResolver( + TfLiteInterpreterOptions* options, + const TfLiteRegistration* (*find_builtin_op)(void* user_data, + TfLiteBuiltinOperator op, + int version), + const TfLiteRegistration* (*find_custom_op)(void* user_data, + const char* custom_op, + int version), + void* op_resolver_user_data); + +/// Returns a new interpreter using the provided model and options, or null on +/// failure, where the model uses only the operators explicitly added to the +/// options. This is the same as `TFLiteInterpreterCreate` from `c_api.h`, +/// except that the only operators that are supported are the ones registered +/// in `options` via calls to `TfLiteInterpreterOptionsSetOpResolver`, +/// `TfLiteInterpreterOptionsAddBuiltinOp`, and/or +/// `TfLiteInterpreterOptionsAddCustomOp`. +/// +/// * `model` must be a valid model instance. The caller retains ownership of +/// the object, and can destroy it immediately after creating the interpreter; +/// the interpreter will maintain its own reference to the underlying model +/// data. +/// * `options` should not be null. The caller retains ownership of the object, +/// and can safely destroy it immediately after creating the interpreter. 
+/// +/// NOTE: The client *must* explicitly allocate tensors before attempting to +/// access input tensor data or invoke the interpreter. +/// +/// WARNING: This is an experimental API and subject to change. +TFL_CAPI_EXPORT extern TfLiteInterpreter* +TfLiteInterpreterCreateWithSelectedOps(const TfLiteModel* model, + const TfLiteInterpreterOptions* options); + +/// Enable or disable the NN API for the interpreter (true to enable). +/// +/// WARNING: This is an experimental API and subject to change. TFL_CAPI_EXPORT extern void TfLiteInterpreterOptionsSetUseNNAPI( TfLiteInterpreterOptions* options, bool enable); diff --git a/tensorflow/lite/c/c_api_experimental_test.cc b/tensorflow/lite/c/c_api_experimental_test.cc index 18bc7bb0397..4de137ec0e6 100644 --- a/tensorflow/lite/c/c_api_experimental_test.cc +++ b/tensorflow/lite/c/c_api_experimental_test.cc @@ -23,8 +23,8 @@ limitations under the License. namespace { -TfLiteRegistration* GetDummyRegistration() { - static TfLiteRegistration registration = { +const TfLiteRegistration* GetDummyRegistration() { + static const TfLiteRegistration registration = { /*init=*/nullptr, /*free=*/nullptr, /*prepare=*/nullptr, @@ -53,6 +53,112 @@ TEST(CApiExperimentalTest, Smoke) { TfLiteModelDelete(model); } +// Test using TfLiteInterpreterCreateWithSelectedOps. +TEST(CApiExperimentalTest, SelectedBuiltins) { + TfLiteModel* model = + TfLiteModelCreateFromFile("tensorflow/lite/testdata/add.bin"); + ASSERT_NE(model, nullptr); + + TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate(); + TfLiteInterpreterOptionsAddBuiltinOp(options, kTfLiteBuiltinAdd, + GetDummyRegistration(), 1, 1); + + TfLiteInterpreter* interpreter = + TfLiteInterpreterCreateWithSelectedOps(model, options); + ASSERT_NE(interpreter, nullptr); + ASSERT_EQ(TfLiteInterpreterAllocateTensors(interpreter), kTfLiteOk); + EXPECT_EQ(TfLiteInterpreterResetVariableTensors(interpreter), kTfLiteOk); + EXPECT_EQ(TfLiteInterpreterInvoke(interpreter), kTfLiteOk); + + TfLiteInterpreterDelete(interpreter); + TfLiteInterpreterOptionsDelete(options); + TfLiteModelDelete(model); +} + +// Test that when using TfLiteInterpreterCreateWithSelectedOps, +// we do NOT get the standard builtin operators by default. +TEST(CApiExperimentalTest, MissingBuiltin) { + TfLiteModel* model = + TfLiteModelCreateFromFile("tensorflow/lite/testdata/add.bin"); + ASSERT_NE(model, nullptr); + + // Install a custom error reporter into the interpreter by way of options. + tflite::TestErrorReporter reporter; + TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate(); + TfLiteInterpreterOptionsSetErrorReporter( + options, + [](void* user_data, const char* format, va_list args) { + reinterpret_cast(user_data)->Report(format, + args); + }, + &reporter); + + // Create an interpreter with no builtins at all. + TfLiteInterpreter* interpreter = + TfLiteInterpreterCreateWithSelectedOps(model, options); + + // Check that interpreter creation failed, because the model contain a buitin + // op that wasn't supported, and that we got the expected error messages. 
+ ASSERT_EQ(interpreter, nullptr); + EXPECT_EQ(reporter.error_messages(), + "Didn't find op for builtin opcode 'ADD' version '1'\n" + "Registration failed.\n"); + EXPECT_EQ(reporter.num_calls(), 2); + + TfLiteInterpreterDelete(interpreter); + TfLiteInterpreterOptionsDelete(options); + TfLiteModelDelete(model); +} + +struct OpResolverData { + bool called_for_add = false; +}; + +const TfLiteRegistration* MyFindBuiltinOp(void* user_data, + TfLiteBuiltinOperator op, + int version) { + OpResolverData* my_data = static_cast(user_data); + if (op == kTfLiteBuiltinAdd && version == 1) { + my_data->called_for_add = true; + return GetDummyRegistration(); + } + return nullptr; +} + +const TfLiteRegistration* MyFindCustomOp(void*, const char* custom_op, + int version) { + if (absl::string_view(custom_op) == "foo" && version == 1) { + return GetDummyRegistration(); + } + return nullptr; +} + +// Test using TfLiteInterpreterCreateWithSelectedOps. +TEST(CApiExperimentalTest, SetOpResolver) { + TfLiteModel* model = + TfLiteModelCreateFromFile("tensorflow/lite/testdata/add.bin"); + ASSERT_NE(model, nullptr); + + TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate(); + + OpResolverData my_data; + TfLiteInterpreterOptionsSetOpResolver(options, MyFindBuiltinOp, + MyFindCustomOp, &my_data); + EXPECT_FALSE(my_data.called_for_add); + + TfLiteInterpreter* interpreter = + TfLiteInterpreterCreateWithSelectedOps(model, options); + ASSERT_NE(interpreter, nullptr); + ASSERT_EQ(TfLiteInterpreterAllocateTensors(interpreter), kTfLiteOk); + EXPECT_EQ(TfLiteInterpreterResetVariableTensors(interpreter), kTfLiteOk); + EXPECT_EQ(TfLiteInterpreterInvoke(interpreter), kTfLiteOk); + EXPECT_TRUE(my_data.called_for_add); + + TfLiteInterpreterDelete(interpreter); + TfLiteInterpreterOptionsDelete(options); + TfLiteModelDelete(model); +} + } // namespace int main(int argc, char** argv) { diff --git a/tensorflow/lite/c/c_api_internal.h b/tensorflow/lite/c/c_api_internal.h index f13712362a6..ee07e3e06a5 100644 --- a/tensorflow/lite/c/c_api_internal.h +++ b/tensorflow/lite/c/c_api_internal.h @@ -20,13 +20,15 @@ limitations under the License. #include #include +#include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/mutable_op_resolver.h" -// Internal structures used by the C API. These are likely to change and should -// not be depended on directly by any C API clients. +// Internal structures and subroutines used by the C API. These are likely to +// change and should not be depended on directly by any C API clients. // // NOTE: This header does not follow C conventions and does not define a C API. // It is effectively an (internal) implementation detail of the C API. @@ -36,20 +38,54 @@ struct TfLiteModel { std::shared_ptr impl; }; +// The `TfLiteOpResolver` struct is an abstract callback interface that +// contains function pointers for callbacks that return a +// `TfLiteRegistration` given an op code or custom op name. This mechanism is +// used to map ops referenced in the flatbuffer model to executable function +// pointers (`TfLiteRegistration`s). +// This struct mirrors the tflite::OpResolver C++ abstract base class. +struct TfLiteOpResolverCallbacks { + // Opaque data that gets passed down to the callback functions. 
+ void* user_data = nullptr; + + // Callback that finds the op registration for a builtin operator by enum + // code. The `user_data` parameter will be set to the + // `op_resolver_user_data` value that was passed to + // `TfLiteInterpreterOptionsSetOpResolver`. + const TfLiteRegistration* (*find_builtin_op)(void* user_data, + TfLiteBuiltinOperator op, + int version); + // Callback that finds the op registration of a custom operator by op name. + // The `user_data` parameter will be set to the `op_resolver_user_data` value + // that was passed to `TfLiteInterpreterOptionsSetOpResolver`. + const TfLiteRegistration* (*find_custom_op)(void* user_data, const char* op, + int version); +}; + +// This struct mirrors the tflite::ErrorResolver C++ abstract base class. +struct TfLiteErrorReporterCallback { + // Opaque data that gets passed down to the callback function. + void* user_data = nullptr; + + // Callback function that reports an error. + void (*error_reporter)(void* user_data, const char* format, + va_list args) = nullptr; +}; + struct TfLiteInterpreterOptions { enum { kDefaultNumThreads = -1, }; int num_threads = kDefaultNumThreads; - tflite::MutableOpResolver op_resolver; + tflite::MutableOpResolver mutable_op_resolver; - void (*error_reporter)(void* user_data, const char* format, - va_list args) = nullptr; - void* error_reporter_user_data = nullptr; + TfLiteOpResolverCallbacks op_resolver_callbacks = {}; std::vector delegates; + TfLiteErrorReporterCallback error_reporter_callback; + bool use_nnapi = false; }; @@ -60,10 +96,38 @@ struct TfLiteInterpreter { // The interpreter does not take ownership of the provided ErrorReporter // instance, so we ensure its validity here. Note that the interpreter may use - // the reporter in its destructor, so it should be declared first. + // the reporter in its destructor, so the reporter should be declared first. std::unique_ptr optional_error_reporter; std::unique_ptr impl; }; +namespace tflite { +namespace internal { + +// This adds the builtin and/or custom operators specified in options in +// `optional_options` (if any) to `mutable_resolver`, and then returns a newly +// created TfLiteInterpreter using `mutable_op_resolver` as the default +// OpResolver, and using any other options in `optional_options`, and using +// the provided `model`. +// +// * `model` must be a valid model instance. The caller retains ownership of the +// object, and can destroy it immediately after creating the interpreter; the +// interpreter will maintain its own reference to the underlying model data. +// * `optional_options` may be null. The caller retains ownership of the object, +// and can safely destroy it immediately after creating the interpreter. +// * `mutable_resolver` must not be null. The caller retains ownership of the +// MutableOpResolver object, and can safely destroy it immediately after +// creating the interpreter. +// +// NOTE: The client *must* explicitly allocate tensors before attempting to +// access input tensor data or invoke the interpreter. 
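The callbacks struct above mirrors tflite::OpResolver, so a thin adapter is enough to plug the C callbacks into the C++ interpreter machinery. A minimal sketch of such an adapter follows; the class name and wiring are illustrative assumptions rather than the implementation in this change (the C builtin-operator enum and the schema enum are generated from the same source, so the cast below is value-preserving).

#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/core/api/op_resolver.h"

// Illustrative adapter: exposes TfLiteOpResolverCallbacks through the
// tflite::OpResolver abstract interface consumed by InterpreterBuilder.
class CallbackOpResolverSketch : public ::tflite::OpResolver {
 public:
  explicit CallbackOpResolverSketch(const TfLiteOpResolverCallbacks& callbacks)
      : callbacks_(callbacks) {}

  const TfLiteRegistration* FindOp(::tflite::BuiltinOperator op,
                                   int version) const override {
    if (callbacks_.find_builtin_op == nullptr) return nullptr;
    return callbacks_.find_builtin_op(
        callbacks_.user_data, static_cast<TfLiteBuiltinOperator>(op), version);
  }

  const TfLiteRegistration* FindOp(const char* op, int version) const override {
    if (callbacks_.find_custom_op == nullptr) return nullptr;
    return callbacks_.find_custom_op(callbacks_.user_data, op, version);
  }

 private:
  TfLiteOpResolverCallbacks callbacks_;
};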
+ +TfLiteInterpreter* InterpreterCreateWithOpResolver( + const TfLiteModel* model, const TfLiteInterpreterOptions* optional_options, + tflite::MutableOpResolver* mutable_resolver); + +} // namespace internal +} // namespace tflite + #endif // TENSORFLOW_LITE_C_C_API_INTERNAL_H_ diff --git a/tensorflow/lite/c/common.h b/tensorflow/lite/c/common.h index d320a90d005..8917c254825 100644 --- a/tensorflow/lite/c/common.h +++ b/tensorflow/lite/c/common.h @@ -226,6 +226,17 @@ void TfLiteFloatArrayFree(TfLiteFloatArray* a); } \ } while (0) +#define TF_LITE_ENSURE_NEAR(context, a, b, epsilon) \ + do { \ + auto delta = ((a) > (b)) ? ((a) - (b)) : ((b) - (a)); \ + if (delta > epsilon) { \ + TF_LITE_KERNEL_LOG((context), "%s:%d %s not near %s (%f != %f)", \ + __FILE__, __LINE__, #a, #b, static_cast(a), \ + static_cast(b)); \ + return kTfLiteError; \ + } \ + } while (0) + #define TF_LITE_ENSURE_OK(context, status) \ do { \ const TfLiteStatus s = (status); \ @@ -410,7 +421,7 @@ typedef struct TfLiteCustomAllocation { size_t bytes; } TfLiteCustomAllocation; -// An tensor in the interpreter system which is a wrapper around a buffer of +// A tensor in the interpreter system which is a wrapper around a buffer of // data including a dimensionality (or NULL if not currently defined). #ifndef TF_LITE_STATIC_MEMORY typedef struct TfLiteTensor { diff --git a/tensorflow/lite/core/api/BUILD b/tensorflow/lite/core/api/BUILD index a1e6fc41cd9..38b2e295da2 100644 --- a/tensorflow/lite/core/api/BUILD +++ b/tensorflow/lite/core/api/BUILD @@ -1,17 +1,16 @@ load("//tensorflow/lite:build_def.bzl", "tflite_copts") load("//tensorflow/lite/micro:build_def.bzl", "micro_copts") +load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable") package( - default_visibility = ["//visibility:public"], + default_visibility = ["//visibility:private"], licenses = ["notice"], # Apache 2.0 ) cc_library( name = "api", srcs = [ - "error_reporter.cc", "flatbuffer_conversions.cc", - "op_resolver.cc", "tensor_utils.cc", ], hdrs = [ @@ -21,17 +20,67 @@ cc_library( "profiler.h", "tensor_utils.h", ], + compatible_with = get_compatible_with_portable(), copts = tflite_copts() + micro_copts(), + visibility = ["//visibility:public"], deps = [ + ":error_reporter", + ":op_resolver", "@flatbuffers//:runtime_cc", "//tensorflow/lite/c:common", # TODO(b/158301698): consider moving internal:compatibility to a more # central location. "//tensorflow/lite/kernels/internal:compatibility", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", ], ) +# We define separate targets for "op_resolver" and "error_reporter", +# even though those headers are also exported by the "api" target, +# so that targets which only want to depend on these small abstract base +# class modules can express more fine-grained dependencies without +# pulling in tensor_utils and flatbuffer_conversions. 
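TF_LITE_ENSURE_NEAR, added to common.h above, logs through TF_LITE_KERNEL_LOG and returns kTfLiteError from the calling kernel when two values differ by more than the given epsilon. A minimal sketch of a kernel Prepare using it; the pass-through op and the scale check are illustrative assumptions, not code from this patch.

#include "tensorflow/lite/c/common.h"

// Hypothetical Prepare for a pass-through kernel: require the input and output
// quantization scales to agree within a small tolerance, then mirror the shape.
TfLiteStatus PassthroughPrepare(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
  TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
  TF_LITE_ENSURE_NEAR(context, input->params.scale, output->params.scale,
                      1e-5f);
  return context->ResizeTensor(context, output,
                               TfLiteIntArrayCopy(input->dims));
}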
+ +cc_library( + name = "op_resolver", + srcs = ["op_resolver.cc"], + hdrs = ["op_resolver.h"], + compatible_with = get_compatible_with_portable(), + copts = tflite_copts() + micro_copts(), + visibility = [ + "//visibility:public", + ], + deps = [ + ":error_reporter", + "//tensorflow/lite/c:common", + "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", + "@flatbuffers//:runtime_cc", + ], +) + +cc_library( + name = "error_reporter", + srcs = ["error_reporter.cc"], + hdrs = ["error_reporter.h"], + compatible_with = get_compatible_with_portable(), + copts = tflite_copts() + micro_copts(), + visibility = [ + "//visibility:public", + ], + deps = [], +) + +cc_library( + name = "verifier", + hdrs = ["verifier.h"], + compatible_with = get_compatible_with_portable(), + copts = tflite_copts() + micro_copts(), + visibility = ["//visibility:public"], + deps = [":error_reporter"], +) + cc_test( name = "error_reporter_test", size = "small", @@ -48,6 +97,7 @@ cc_test( srcs = ["op_resolver_test.cc"], deps = [ ":api", + "//tensorflow/lite/schema:schema_utils", "@com_google_googletest//:gtest", ], ) diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index 5d2936f3636..77621c3f2fd 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -333,6 +333,10 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type, return ParseReshape(op, error_reporter, allocator, builtin_data); } + case BuiltinOperator_RESIZE_BILINEAR: { + return ParseResizeBilinear(op, error_reporter, allocator, builtin_data); + } + case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR: { return ParseResizeNearestNeighbor(op, error_reporter, allocator, builtin_data); @@ -346,6 +350,10 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type, return ParseRsqrt(op, error_reporter, allocator, builtin_data); } + case BuiltinOperator_SHAPE: { + return ParseShape(op, error_reporter, allocator, builtin_data); + } + case BuiltinOperator_SIN: { return ParseSin(op, error_reporter, allocator, builtin_data); } @@ -358,6 +366,10 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type, return ParseSplit(op, error_reporter, allocator, builtin_data); } + case BuiltinOperator_SPLIT_V: { + return ParseSplitV(op, error_reporter, allocator, builtin_data); + } + case BuiltinOperator_SQRT: { return ParseSqrt(op, error_reporter, allocator, builtin_data); } @@ -560,22 +572,6 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type, *builtin_data = params.release(); return kTfLiteOk; } - case BuiltinOperator_RESIZE_BILINEAR: { - auto params = safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - if (const auto* schema_params = - op->builtin_options_as_ResizeBilinearOptions()) { - params->align_corners = schema_params->align_corners(); - params->half_pixel_centers = schema_params->half_pixel_centers(); - } else { - // Some older models did not populate the ResizeBilinearOptions field in - // the flatbuffer, so ensure it's set to a sensible default. 
- params->align_corners = false; - params->half_pixel_centers = false; - } - *builtin_data = params.release(); - return kTfLiteOk; - } case BuiltinOperator_SKIP_GRAM: { auto params = safe_allocator.Allocate(); TF_LITE_ENSURE(error_reporter, params != nullptr); @@ -619,15 +615,7 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type, *builtin_data = params.release(); return kTfLiteOk; } - case BuiltinOperator_SPLIT_V: { - auto params = safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - if (const auto* schema_params = op->builtin_options_as_SplitVOptions()) { - params->num_splits = schema_params->num_splits(); - } - *builtin_data = params.release(); - return kTfLiteOk; - } + case BuiltinOperator_SQUEEZE: { auto params = safe_allocator.Allocate(); TF_LITE_ENSURE(error_reporter, params != nullptr); @@ -667,16 +655,6 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type, *builtin_data = params.release(); return kTfLiteOk; } - case BuiltinOperator_SHAPE: { - auto params = safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - if (const auto* schema_params = op->builtin_options_as_ShapeOptions()) { - TF_LITE_ENSURE_STATUS(ConvertTensorType( - schema_params->out_type(), ¶ms->out_type, error_reporter)); - } - *builtin_data = params.release(); - return kTfLiteOk; - } case BuiltinOperator_DELEGATE: { // TODO(ycling): Revisit when supporting saving delegated models. TF_LITE_REPORT_ERROR(error_reporter, @@ -825,6 +803,8 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_DENSIFY: case BuiltinOperator_SEGMENT_SUM: return kTfLiteOk; + case BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES: + return kTfLiteError; } return kTfLiteError; } // NOLINT[readability/fn_size] @@ -1475,6 +1455,33 @@ TfLiteStatus ParseReshape(const Operator* op, ErrorReporter* error_reporter, return kTfLiteOk; } +TfLiteStatus ParseResizeBilinear(const Operator* op, + ErrorReporter* error_reporter, + BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, error_reporter, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + std::unique_ptr + params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + + const ResizeBilinearOptions* schema_params = + op->builtin_options_as_ResizeBilinearOptions(); + + if (schema_params != nullptr) { + params->align_corners = schema_params->align_corners(); + params->half_pixel_centers = schema_params->half_pixel_centers(); + } else { + params->align_corners = false; + params->half_pixel_centers = false; + } + + *builtin_data = params.release(); + return kTfLiteOk; +} + TfLiteStatus ParseResizeNearestNeighbor(const Operator* op, ErrorReporter* error_reporter, BuiltinDataAllocator* allocator, @@ -1518,6 +1525,29 @@ TfLiteStatus ParseRsqrt(const Operator*, ErrorReporter*, BuiltinDataAllocator*, return kTfLiteOk; } +TfLiteStatus ParseShape(const Operator* op, ErrorReporter* error_reporter, + BuiltinDataAllocator* allocator, void** builtin_data) { + SafeBuiltinDataAllocator safe_allocator(allocator); + std::unique_ptr + params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + + const ShapeOptions* schema_params = op->builtin_options_as_ShapeOptions(); + + if (schema_params != nullptr) { + TF_LITE_ENSURE_STATUS(ConvertTensorType(schema_params->out_type(), + ¶ms->out_type, error_reporter)); + } else { + // TODO(b/157480169): We should 
either return kTfLiteError or fill in some
+    // reasonable defaults in the params struct. We are not doing so until we
+    // better understand the ramifications of changing the legacy behavior.
+  }
+
+  *builtin_data = params.release();
+  return kTfLiteOk;
+}
+
 // We have this parse function instead of directly returning kTfLiteOk from the
 // switch-case in ParseOpData because this function is used as part of the
 // selective registration for the OpResolver implementation in micro.
@@ -1575,6 +1605,30 @@ TfLiteStatus ParseSplit(const Operator* op, ErrorReporter* error_reporter,
   return kTfLiteOk;
 }
 
+TfLiteStatus ParseSplitV(const Operator* op, ErrorReporter* error_reporter,
+                         BuiltinDataAllocator* allocator, void** builtin_data) {
+  CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
+  SafeBuiltinDataAllocator safe_allocator(allocator);
+
+  std::unique_ptr<TfLiteSplitVParams, SafeBuiltinDataAllocator::BuiltinDataDeleter>
+      params = safe_allocator.Allocate<TfLiteSplitVParams>();
+  TF_LITE_ENSURE(error_reporter, params != nullptr);
+
+  const SplitVOptions* schema_params = op->builtin_options_as_SplitVOptions();
+
+  if (schema_params != nullptr) {
+    params->num_splits = schema_params->num_splits();
+  } else {
+    // TODO(b/157480169): We should either return kTfLiteError or fill in some
+    // reasonable defaults in the params struct. We are not doing so until we
+    // better understand the ramifications of changing the legacy behavior.
+  }
+
+  *builtin_data = params.release();
+  return kTfLiteOk;
+}
+
 // We have this parse function instead of directly returning kTfLiteOk from the
 // switch-case in ParseOpData because this function is used as part of the
 // selective registration for the OpResolver implementation in micro.
diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.h b/tensorflow/lite/core/api/flatbuffer_conversions.h
index aaeb98c0a2e..136809977c9 100644
--- a/tensorflow/lite/core/api/flatbuffer_conversions.h
+++ b/tensorflow/lite/core/api/flatbuffer_conversions.h
@@ -45,7 +45,7 @@ class BuiltinDataAllocator {
   // platform targets support that properly.
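Factoring ParseResizeBilinear, ParseShape and ParseSplitV out of the big switch lets callers (notably the selective-registration OpResolver in micro, as the comment notes) parse a single op's builtin data without pulling in the full ParseOpData dispatch. A rough usage sketch; the MallocDataAllocator below is an illustrative stand-in for whatever BuiltinDataAllocator the caller supplies, with the Allocate/Deallocate signatures inferred from this header.

#include <cstdlib>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"

// Illustrative allocator: hands out plain heap memory for parsed param structs.
class MallocDataAllocator : public tflite::BuiltinDataAllocator {
 public:
  void* Allocate(size_t size, size_t alignment_hint) override {
    return std::malloc(size);
  }
  void Deallocate(void* data) override { std::free(data); }
};

// Parses the SplitV options of `op` (a flatbuffer Operator) into a heap-owned
// TfLiteSplitVParams. The caller takes ownership of *out.
TfLiteStatus GetSplitVParams(const tflite::Operator* op,
                             tflite::ErrorReporter* reporter,
                             TfLiteSplitVParams** out) {
  MallocDataAllocator allocator;
  void* builtin_data = nullptr;
  if (tflite::ParseSplitV(op, reporter, &allocator, &builtin_data) !=
      kTfLiteOk) {
    return kTfLiteError;
  }
  *out = reinterpret_cast<TfLiteSplitVParams*>(builtin_data);
  return kTfLiteOk;
}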
static_assert(std::is_pod::value, "Builtin data structure must be POD."); void* allocated_memory = this->Allocate(sizeof(T), alignof(T)); - return new (allocated_memory) T; + return new (allocated_memory) T(); } virtual ~BuiltinDataAllocator() {} @@ -205,6 +205,11 @@ TfLiteStatus ParseRelu6(const Operator* op, ErrorReporter* error_reporter, TfLiteStatus ParseReshape(const Operator* op, ErrorReporter* error_reporter, BuiltinDataAllocator* allocator, void** builtin_data); +TfLiteStatus ParseResizeBilinear(const Operator* op, + ErrorReporter* error_reporter, + BuiltinDataAllocator* allocator, + void** builtin_data); + TfLiteStatus ParseResizeNearestNeighbor(const Operator* op, ErrorReporter* error_reporter, BuiltinDataAllocator* allocator, @@ -216,6 +221,9 @@ TfLiteStatus ParseRound(const Operator* op, ErrorReporter* error_reporter, TfLiteStatus ParseRsqrt(const Operator* op, ErrorReporter* error_reporter, BuiltinDataAllocator* allocator, void** builtin_data); +TfLiteStatus ParseShape(const Operator* op, ErrorReporter* error_reporter, + BuiltinDataAllocator* allocator, void** builtin_data); + TfLiteStatus ParseSin(const Operator* op, ErrorReporter* error_reporter, BuiltinDataAllocator* allocator, void** builtin_data); @@ -225,6 +233,9 @@ TfLiteStatus ParseSoftmax(const Operator* op, ErrorReporter* error_reporter, TfLiteStatus ParseSplit(const Operator* op, ErrorReporter* error_reporter, BuiltinDataAllocator* allocator, void** builtin_data); +TfLiteStatus ParseSplitV(const Operator* op, ErrorReporter* error_reporter, + BuiltinDataAllocator* allocator, void** builtin_data); + TfLiteStatus ParseSqrt(const Operator* op, ErrorReporter* error_reporter, BuiltinDataAllocator* allocator, void** builtin_data); diff --git a/tensorflow/lite/core/api/op_resolver.cc b/tensorflow/lite/core/api/op_resolver.cc index c239d9ed23e..c5dffb63549 100644 --- a/tensorflow/lite/core/api/op_resolver.cc +++ b/tensorflow/lite/core/api/op_resolver.cc @@ -18,6 +18,7 @@ limitations under the License. #include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/schema/schema_utils.h" namespace tflite { @@ -26,7 +27,7 @@ TfLiteStatus GetRegistrationFromOpCode( ErrorReporter* error_reporter, const TfLiteRegistration** registration) { TfLiteStatus status = kTfLiteOk; *registration = nullptr; - auto builtin_code = opcode->builtin_code(); + auto builtin_code = GetBuiltinCode(opcode); int version = opcode->version(); if (builtin_code > BuiltinOperator_MAX || diff --git a/tensorflow/lite/core/api/op_resolver_test.cc b/tensorflow/lite/core/api/op_resolver_test.cc index 4dfca5c971a..44acc92ba8c 100644 --- a/tensorflow/lite/core/api/op_resolver_test.cc +++ b/tensorflow/lite/core/api/op_resolver_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/lite/schema/schema_utils.h" namespace tflite { namespace { diff --git a/tensorflow/lite/core/api/verifier.h b/tensorflow/lite/core/api/verifier.h new file mode 100644 index 00000000000..ca1cfb044bd --- /dev/null +++ b/tensorflow/lite/core/api/verifier.h @@ -0,0 +1,38 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/// \file +/// Abstract interface for verifying a model. +#ifndef TENSORFLOW_LITE_CORE_API_VERIFIER_H_ +#define TENSORFLOW_LITE_CORE_API_VERIFIER_H_ + +#include "tensorflow/lite/core/api/error_reporter.h" + +namespace tflite { + +/// Abstract interface that verifies whether a given model is legit. +/// It facilitates the use-case to verify and build a model without loading it +/// twice. +/// (See also "tensorflow/lite/tools/verifier.h".) +class TfLiteVerifier { + public: + /// Returns true if the model is legit. + virtual bool Verify(const char* data, int length, + ErrorReporter* reporter) = 0; + virtual ~TfLiteVerifier() {} +}; + +} // namespace tflite + +#endif // TENSORFLOW_LITE_CORE_API_VERIFIER_H_ diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc index ecdb04c8b3c..2b9246a1100 100644 --- a/tensorflow/lite/core/subgraph.cc +++ b/tensorflow/lite/core/subgraph.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/lite/arena_planner.h" +#include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/context_util.h" #include "tensorflow/lite/core/api/tensor_utils.h" @@ -30,8 +31,6 @@ limitations under the License. namespace tflite { -namespace impl { - namespace { struct TfLiteQuantizationDeleter { @@ -87,6 +86,7 @@ template bool HasDynamicTensorImpl(const TfLiteContext& context, const TensorIntArray& int_array) { for (int i : int_array) { + if (i == kTfLiteOptionalTensor) continue; const TfLiteTensor& tensor = context.tensors[i]; if (tensor.allocation_type == kTfLiteDynamic) { return true; @@ -167,9 +167,10 @@ class InterpreterInfo : public GraphInfo { TfLiteTensor* tensor(size_t index) override { return &subgraph_->tensors()[index]; } - size_t num_nodes() const override { + size_t num_execution_nodes() const override { return subgraph_->execution_plan().size(); } + size_t num_total_nodes() const override { return subgraph_->nodes_size(); } const TfLiteNode& node(size_t index) const override { int node_index = subgraph_->execution_plan()[index]; return subgraph_->nodes_and_registration()[node_index].first; @@ -582,6 +583,33 @@ TfLiteStatus Subgraph::CheckTensorIndices(const char* label, const int* indices, return kTfLiteOk; } +// We have two arrays and we need to check that elements from one array don't +// show up in the other. We could sort both arrays and then iterate with two +// pointers from start to finish always increasing the smaller one but since +// these arrays are usually short (<25 elements for inputs, usually <3 for +// outputs), this might be slower than the naive approach (if arrays have size n +// and m, with n >> m ~ O(1), first approach is O(nlogn) whereas the other is +// O(n)). Plus, sorting the input and output arrays might not be something we +// want as it destroys ordering of elements. +// +// If it turns out that this is an issue, we can switch to the other algorithm. 
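For reference, the sort-based alternative mentioned above would look roughly like the sketch below (sorting copies so the caller's ordering is preserved); it only pays off if the input/output lists ever grow well beyond their current handful of elements, and it is illustrative rather than part of this change.

#include <algorithm>
#include <vector>

// O(n log n + m log m) overlap check between two index arrays.
bool HaveOverlappingIndices(const int* input_indices, int num_inputs,
                            const int* output_indices, int num_outputs) {
  std::vector<int> inputs(input_indices, input_indices + num_inputs);
  std::vector<int> outputs(output_indices, output_indices + num_outputs);
  std::sort(inputs.begin(), inputs.end());
  std::sort(outputs.begin(), outputs.end());
  size_t i = 0, j = 0;
  while (i < inputs.size() && j < outputs.size()) {
    if (inputs[i] == outputs[j]) return true;
    if (inputs[i] < outputs[j]) {
      ++i;
    } else {
      ++j;
    }
  }
  return false;
}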
+TfLiteStatus Subgraph::CheckInputAndOutputForOverlap(const int* input_indices, + int num_inputs, + const int* output_indices, + int num_outputs) { + for (int i = 0; i < num_inputs; i++) { + for (int j = 0; j < num_outputs; j++) { + if (input_indices[i] == output_indices[j]) { + ReportError("Tensor %d is both input %d and output %d\n", + input_indices[i], i, j); + consistent_ = false; + return kTfLiteError; + } + } + } + return kTfLiteOk; +} + namespace { // Multiply two sizes and return true if overflow occurred; // This is based off tensorflow/overflow.h but is simpler as we already @@ -674,13 +702,17 @@ TfLiteStatus Subgraph::ResetVariableTensors() { continue; } - // Variable tensors have to be `kTfLiteArenaRwPersistent`, and must be - // allocated after the initial `PrepareOpsAndTensors()` is called. - TF_LITE_ENSURE_EQ(&context_, tensor.allocation_type, - kTfLiteArenaRwPersistent); - TF_LITE_ENSURE(&context_, tensor.data.raw != nullptr); - - tflite::ResetVariableTensor(&tensor); + if (tensor.allocation_type == kTfLiteArenaRwPersistent) { + // If variable tensors allocation type is `kTfLiteArenaRwPersistent`, then + // they must be allocated after the initial `PrepareOpsAndTensors()` is + // called. + TF_LITE_ENSURE(&context_, tensor.data.raw != nullptr); + tflite::ResetVariableTensor(&tensor); + } else { + // If variable tensors allocation type is not `kTfLiteArenaRwPersistent`, + // then it can only be `kTfLiteCustom` in which case, we do not reset it. + TF_LITE_ENSURE_EQ(&context_, tensor.allocation_type, kTfLiteCustom); + } } return kTfLiteOk; } @@ -704,6 +736,16 @@ TfLiteStatus Subgraph::AddNodeWithParameters( &context_, CheckTensorIndices("node outputs", outputs.data(), outputs.size())); + // For builtin ops, inputs and outputs must not overlap. Custom ops must do + // this check by themselves if they don't support overlapping tensors. This + // distinction is to allow custom ops to just forward a tensor, reusing it as + // both input and output. + if (builtin_data != nullptr) { + TF_LITE_ENSURE_OK(&context_, CheckInputAndOutputForOverlap( + inputs.data(), inputs.size(), + outputs.data(), outputs.size())); + } + int new_node_index = nodes_and_registration_.size(); if (node_index) *node_index = new_node_index; nodes_and_registration_.resize(nodes_and_registration_.size() + 1); @@ -990,6 +1032,19 @@ TfLiteStatus Subgraph::Invoke() { tensor->data_is_stale) { TF_LITE_ENSURE_STATUS(EnsureTensorDataIsReadable(tensor_index)); } + if (tensor->data.raw == nullptr && tensor->bytes > 0) { + if (registration.builtin_code == kTfLiteBuiltinReshape && i == 1) { + // In general, having a tensor here with no buffer will be an error. + // However, for the reshape operator, the second input tensor is only + // used for the shape, not for the data. Thus, null buffer is ok. + continue; + } else { + // In all other cases, we need to return an error as otherwise we will + // trigger a null pointer dereference (likely). + ReportError("Input tensor %d lacks data", tensor_index); + return kTfLiteError; + } + } } if (check_cancelled_func_ != nullptr && @@ -1341,6 +1396,48 @@ TfLiteStatus Subgraph::UndoAllDelegates() { execution_plan_ = pre_delegation_execution_plan_; pre_delegation_execution_plan_.clear(); + // Handling FP16 delegation (if applies). + // + // First pass through execution plan to remember mapping of FP16 + // dequantizations in the graph. 
+ // This is required because delegates that support FP16 could remap supported + // nodes' inputs to point to their fp16 versions (if delegate supports fp16 + // acceleration). This remapping is performed in FP16GraphPartitionHelper in + // delegates/utils. We need to undo this remapping to ensure CPU kernels work. + std::vector fp16_to_fp32(tensors_size(), -1); + for (int execution_plan_index = 0; + execution_plan_index < execution_plan_.size(); ++execution_plan_index) { + int node_index = execution_plan_[execution_plan_index]; + auto& node_and_reg = nodes_and_registration_[node_index]; + const TfLiteNode& node = node_and_reg.first; + const TfLiteRegistration& reg = node_and_reg.second; + if (reg.builtin_code == kTfLiteBuiltinDequantize && + node.inputs->size == 1 && node.outputs->size == 1) { + const int input_idx = node.inputs->data[0]; + if (tensors_[input_idx].type == kTfLiteFloat16) { + fp16_to_fp32[input_idx] = node.outputs->data[0]; + } + } + } + // Second pass through the execution plan to remap applicable nodes' fp16 + // inputs to their original fp32 versions. Note that if a CPU kernel does + // support fp16, the model will not contain a DEQUANTIZE for its constant + // input. + for (int execution_plan_index = 0; + execution_plan_index < execution_plan_.size(); ++execution_plan_index) { + int node_index = execution_plan_[execution_plan_index]; + auto& node_and_reg = nodes_and_registration_[node_index]; + const TfLiteNode& node = node_and_reg.first; + const TfLiteRegistration& reg = node_and_reg.second; + if (reg.builtin_code == kTfLiteBuiltinDequantize) continue; + for (int i = 0; i < node.inputs->size; ++i) { + const int original_input_idx = node.inputs->data[i]; + if (tensors_[original_input_idx].type == kTfLiteFloat16) { + node.inputs->data[i] = fp16_to_fp32[original_input_idx]; + } + } + } + // Delegate nodes are appended to nodes_and_registration_. Therefore, // cleanup nodes_and_registration_ to only contain nodes from // pre_delegation_execution_plan_. @@ -1486,8 +1583,10 @@ TfLiteStatus Subgraph::ModifyGraphWithDelegate(TfLiteDelegate* delegate) { TfLiteStatus Subgraph::SetCustomAllocationForTensor( int tensor_index, const TfLiteCustomAllocation& allocation) { TfLiteTensor* tensor = &context_.tensors[tensor_index]; - TF_LITE_ENSURE(context(), tensor->allocation_type == kTfLiteArenaRw || - tensor->allocation_type == kTfLiteCustom); + TF_LITE_ENSURE(context(), + (tensor->allocation_type == kTfLiteArenaRw || + tensor->allocation_type == kTfLiteArenaRwPersistent || + tensor->allocation_type == kTfLiteCustom)); TF_LITE_ENSURE_STATUS( ValidateCustomAllocationForTensor(context(), tensor, allocation)); @@ -1510,6 +1609,4 @@ TfLiteStatus Subgraph::SetCustomAllocationForTensor( return kTfLiteOk; } -} // namespace impl - } // namespace tflite diff --git a/tensorflow/lite/core/subgraph.h b/tensorflow/lite/core/subgraph.h index 3a28b4cb99c..b94d1a0b2bc 100644 --- a/tensorflow/lite/core/subgraph.h +++ b/tensorflow/lite/core/subgraph.h @@ -30,14 +30,8 @@ limitations under the License. #include "tensorflow/lite/memory_planner.h" #include "tensorflow/lite/util.h" -#if TFLITE_EXPERIMENTAL_RUNTIME_EAGER -#include "tensorflow/lite/experimental/tf_runtime/public/subgraph.h" -#endif - namespace tflite { -namespace impl { - // Forward declare since NNAPIDelegate uses Interpreter. class NNAPIDelegate; @@ -342,8 +336,8 @@ class Subgraph { // for the tensor, it can no longer be reset to the TFLite arena memory. // // Parameters should satisfy the following conditions: - // 1. 
tensor->allocation_type == kTfLiteArenaRw - // In general, this is true for all non-constants such as I/O tensors. + // 1. tensor->allocation_type == kTfLiteArenaRw or kTfLiteArenaRwPersistent + // In general, this is true for I/O tensors & variable tensors. // 2. allocation->data has the appropriate permissions for runtime access // (Read-only for inputs, Read-Write for others), and outlives Interpreter. // 3. allocation->bytes >= tensor->bytes. @@ -457,6 +451,15 @@ class Subgraph { TfLiteStatus CheckTensorIndices(const char* label, const int* indices, int length); + // Check that the input indices and the output indices don't overlap. + // This is needed because same tensor must not be used both as input and + // output for an operator. + // NOTE: this changes consistent_ to be false if indices are out of bounds. + TfLiteStatus CheckInputAndOutputForOverlap(const int* input_indices, + int num_inputs, + const int* output_indices, + int num_outputs); + // Compute the number of bytes required to represent a tensor with dimensions // specified by the array dims (of length dims_size). Returns the status code // and bytes. @@ -739,13 +742,5 @@ class Subgraph { resource::ResourceMap* resources_ = nullptr; }; -} // namespace impl - -#if TFLITE_EXPERIMENTAL_RUNTIME_EAGER -using Subgraph = tflrt::Subgraph; -#else -using Subgraph = impl::Subgraph; -#endif - } // namespace tflite #endif // TENSORFLOW_LITE_CORE_SUBGRAPH_H_ diff --git a/tensorflow/lite/delegates/BUILD b/tensorflow/lite/delegates/BUILD index e1f91f32c34..d106ae4a738 100644 --- a/tensorflow/lite/delegates/BUILD +++ b/tensorflow/lite/delegates/BUILD @@ -14,6 +14,7 @@ # ============================================================================== load("//tensorflow/lite:build_def.bzl", "tflite_copts", "tflite_linkopts") +load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable") package( default_visibility = ["//visibility:public"], @@ -23,6 +24,7 @@ package( cc_library( name = "status", hdrs = ["status.h"], + compatible_with = get_compatible_with_portable(), copts = tflite_copts(), deps = [ "//tensorflow/lite/c:common", @@ -33,6 +35,7 @@ cc_library( name = "utils", srcs = ["utils.cc"], hdrs = ["utils.h"], + compatible_with = get_compatible_with_portable(), copts = tflite_copts(), deps = [ "//tensorflow/lite:kernel_api", @@ -73,14 +76,19 @@ cc_test( ], deps = [ ":interpreter_utils", + ":utils", "//tensorflow/lite:framework", + "//tensorflow/lite:kernel_api", + "//tensorflow/lite:util", "//tensorflow/lite:version", "//tensorflow/lite/core/api", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/kernels:kernel_util", "//tensorflow/lite/kernels/internal:compatibility", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "//tensorflow/lite/testing:util", + "//third_party/eigen3", "@com_google_googletest//:gtest", ], ) diff --git a/tensorflow/lite/delegates/delegate_test.cc b/tensorflow/lite/delegates/delegate_test.cc index aed4400ed99..b70ebdcc3aa 100644 --- a/tensorflow/lite/delegates/delegate_test.cc +++ b/tensorflow/lite/delegates/delegate_test.cc @@ -19,13 +19,21 @@ limitations under the License. 
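With the relaxation above, a custom allocation can also back arena-persistent (variable) tensors rather than only plain arena I/O tensors. A caller-side sketch, assuming the usual Interpreter-level wrapper around Subgraph::SetCustomAllocationForTensor; buffer size (and, depending on the build, alignment) must still satisfy ValidateCustomAllocationForTensor, and the buffer has to outlive its use by the interpreter.

#include <cstdint>
#include <vector>

#include "tensorflow/lite/interpreter.h"

// Hypothetical helper: point a tensor at caller-owned memory instead of the
// TFLite arena, then re-run planning so the allocation takes effect.
TfLiteStatus UseCallerBuffer(tflite::Interpreter* interpreter, int tensor_index,
                             std::vector<uint8_t>* buffer) {
  TfLiteCustomAllocation allocation;
  allocation.data = buffer->data();
  allocation.bytes = buffer->size();
  if (interpreter->SetCustomAllocationForTensor(tensor_index, allocation) !=
      kTfLiteOk) {
    return kTfLiteError;
  }
  return interpreter->AllocateTensors();
}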
#include #include +#include "third_party/eigen3/Eigen/Core" +#include "tensorflow/lite/builtin_op_data.h" +#include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/delegates/interpreter_utils.h" +#include "tensorflow/lite/delegates/utils.h" #include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/interpreter_builder.h" +#include "tensorflow/lite/kernels/builtin_op_kernels.h" #include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/testing/util.h" +#include "tensorflow/lite/util.h" #include "tensorflow/lite/version.h" namespace tflite { @@ -41,9 +49,12 @@ TfLiteRegistration AddOpRegistration() { reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { // Set output size to input size - const TfLiteTensor* input1 = GetInput(context, node, 0); - const TfLiteTensor* input2 = GetInput(context, node, 1); - TfLiteTensor* output = GetOutput(context, node, 0); + const TfLiteTensor* input1; + TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input1)); + const TfLiteTensor* input2; + TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &input2)); + TfLiteTensor* output; + TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); TF_LITE_ENSURE_EQ(context, input1->dims->size, input2->dims->size); for (int i = 0; i < input1->dims->size; ++i) { @@ -57,13 +68,16 @@ TfLiteRegistration AddOpRegistration() { reg.invoke = [](TfLiteContext* context, TfLiteNode* node) { // Copy input data to output data. - const TfLiteTensor* a0 = GetInput(context, node, 0); + const TfLiteTensor* a0; + TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &a0)); TF_LITE_ENSURE(context, a0); TF_LITE_ENSURE(context, a0->data.f); - const TfLiteTensor* a1 = GetInput(context, node, 1); + const TfLiteTensor* a1; + TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &a1)); TF_LITE_ENSURE(context, a1); TF_LITE_ENSURE(context, a1->data.f); - TfLiteTensor* out = GetOutput(context, node, 0); + TfLiteTensor* out; + TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &out)); TF_LITE_ENSURE(context, out); TF_LITE_ENSURE(context, out->data.f); int num = a0->dims->data[0]; @@ -266,7 +280,8 @@ class TestDelegate : public ::testing::Test { a0 = GetInput(context, node, 0); a1 = a0; } - TfLiteTensor* out = GetOutput(context, node, 0); + TfLiteTensor* out; + TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &out)); int num = 1; for (int i = 0; i < a0->dims->size; ++i) { num *= a0->dims->data[i]; @@ -288,8 +303,10 @@ class TestDelegate : public ::testing::Test { reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { // Shapes should already by propagated by the runtime, just need to // check. 
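The mechanical change running through this test (and the kernels generally) is replacing the raw GetInput/GetOutput accessors with the status-returning GetInputSafe/GetOutputSafe and bailing out via TF_LITE_ENSURE_OK. Consolidated into one place, the pattern looks roughly like the skeleton below; the copy op itself is an illustrative assumption.

#include <cstring>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace {

TfLiteStatus CopyPrepare(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, tflite::GetInputSafe(context, node, 0, &input));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, tflite::GetOutputSafe(context, node, 0, &output));
  return context->ResizeTensor(context, output,
                               TfLiteIntArrayCopy(input->dims));
}

TfLiteStatus CopyInvoke(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, tflite::GetInputSafe(context, node, 0, &input));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, tflite::GetOutputSafe(context, node, 0, &output));
  std::memcpy(output->data.raw, input->data.raw, input->bytes);
  return kTfLiteOk;
}

}  // namespace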
- const TfLiteTensor* input1 = GetInput(context, node, 0); - TfLiteTensor* output = GetOutput(context, node, 0); + const TfLiteTensor* input1; + TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input1)); + TfLiteTensor* output; + TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); const int input_dims_size = input1->dims->size; TF_LITE_ENSURE(context, output->dims->size == input_dims_size); for (int i = 0; i < input_dims_size; ++i) { @@ -314,7 +331,8 @@ class TestDelegate : public ::testing::Test { input1 = GetInput(context, node, 0); input2 = input1; } - TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteTensor* output; + TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); TF_LITE_ENSURE_STATUS(context->ResizeTensor( context, output, TfLiteIntArrayCopy(input1->dims))); @@ -526,6 +544,35 @@ TEST_F(TestDelegate, SecondDelegationInvokeFailure) { } } +// This test ensures that node indices in multi-delegate application are handled +// correctly by the TFLite partitioning algorithm. +TEST_F(TestDelegate, TwoDelegates_ExecutionPlanIndicesDifferent) { + // First delegate supports nodes 0, 1. + // After this delegation, the execution plan size is 2. + delegate_ = std::unique_ptr( + new SimpleDelegate({0, 1}, kTfLiteDelegateFlagsAllowDynamicTensors)); + // Second delegate supports (original) node index 2. + // The execution plan has 2 nodes, so this verifies that the partitioning + // algorithm correctly refers to (original) node indices instead of execution + // plan indices. + delegate2_ = std::unique_ptr( + new SimpleDelegate({2}, kTfLiteDelegateFlagsNone)); + + ASSERT_EQ(interpreter_->execution_plan().size(), 3); + ASSERT_EQ( + interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()), + kTfLiteOk); + ASSERT_EQ(interpreter_->execution_plan().size(), 2); + + ASSERT_EQ( + interpreter_->ModifyGraphWithDelegate(delegate2_->get_tf_lite_delegate()), + kTfLiteOk); + ASSERT_EQ(interpreter_->execution_plan().size(), 2); + + // Verify Invoke works. + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); +} + TEST_F(TestDelegate, StaticDelegateMakesGraphImmutable) { delegate_ = std::unique_ptr(new SimpleDelegate({0, 1, 2})); ASSERT_EQ( @@ -1139,11 +1186,14 @@ class TestDelegateWithDynamicTensors : public ::testing::Test { reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { // Output 0 is dynamic - TfLiteTensor* output0 = GetOutput(context, node, 0); + TfLiteTensor* output0; + TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output0)); SetTensorToDynamic(output0); // Output 1 has the same shape as input. - const TfLiteTensor* input = GetInput(context, node, 0); - TfLiteTensor* output1 = GetOutput(context, node, 1); + const TfLiteTensor* input; + TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); + TfLiteTensor* output1; + TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 1, &output1)); TF_LITE_ENSURE_STATUS(context->ResizeTensor( context, output1, TfLiteIntArrayCopy(input->dims))); return kTfLiteOk; @@ -1163,11 +1213,14 @@ class TestDelegateWithDynamicTensors : public ::testing::Test { // If tensors are resized, the runtime should propagate shapes // automatically if correct flag is set. Ensure values are correct. // Output 0 should be dynamic. - TfLiteTensor* output0 = GetOutput(context, node, 0); + TfLiteTensor* output0; + TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output0)); TF_LITE_ENSURE(context, IsDynamicTensor(output0)); // Output 1 has the same shape as input. 
- const TfLiteTensor* input = GetInput(context, node, 0); - TfLiteTensor* output1 = GetOutput(context, node, 1); + const TfLiteTensor* input; + TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); + TfLiteTensor* output1; + TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 1, &output1)); TF_LITE_ENSURE(context, input->dims->size == output1->dims->size); TF_LITE_ENSURE(context, input->dims->data[0] == output1->dims->data[0]); return kTfLiteOk; @@ -1240,6 +1293,294 @@ TEST_F(TestDelegateWithDynamicTensors, ShapePropagation_FlagNotSet) { ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteError); } +// Tests for FP16 graphs +// ===================== + +// Tests delegate functionality related to FP16 graphs. +// Model architecture: +// 1->DEQ->2 4->DEQ->5 7->DEQ->8 10->DEQ->11 +// | | | | +// 0----->ADD->3----->ADD->6----->MUL->9------>ADD-->12 +// Input: 0, Output:12. +// All constants are 2, so the function is: (x + 2 + 2) * 2 + 2 = 2x + 10 +// +// Delegate only supports ADD, so can have upto two delegated partitions. +// TODO(b/156707497): Add more cases here once we have landed CPU kernels +// supporting FP16. +class TestFP16Delegation : public ::testing::TestWithParam { + protected: + void SetUp() override { + interpreter_.reset(new Interpreter); + interpreter_->AddTensors(13); + interpreter_->SetInputs({0}); + interpreter_->SetOutputs({12}); + + float16_const_ = Eigen::half_impl::float_to_half_rtne(2.f); + + // TENSORS. + TfLiteQuantizationParams quant; + // Input. + interpreter_->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {1}, + quant); + // fp16 constant, dequantize output, Add0 output. + interpreter_->SetTensorParametersReadOnly( + 1, kTfLiteFloat16, "", {1}, quant, + reinterpret_cast(&float16_const_), sizeof(TfLiteFloat16)); + interpreter_->SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {1}, + quant); + interpreter_->SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {1}, + quant); + // fp16 constant, dequantize output, Add1 output. + interpreter_->SetTensorParametersReadOnly( + 4, kTfLiteFloat16, "", {1}, quant, + reinterpret_cast(&float16_const_), sizeof(TfLiteFloat16)); + interpreter_->SetTensorParametersReadWrite(5, kTfLiteFloat32, "", {1}, + quant); + interpreter_->SetTensorParametersReadWrite(6, kTfLiteFloat32, "", {1}, + quant); + // fp16 constant, dequantize output, Mul0 output. + interpreter_->SetTensorParametersReadOnly( + 7, kTfLiteFloat16, "", {1}, quant, + reinterpret_cast(&float16_const_), sizeof(TfLiteFloat16)); + interpreter_->SetTensorParametersReadWrite(8, kTfLiteFloat32, "", {1}, + quant); + interpreter_->SetTensorParametersReadWrite(9, kTfLiteFloat32, "", {1}, + quant); + // fp16 constant, dequantize output, Add2 output. + interpreter_->SetTensorParametersReadOnly( + 10, kTfLiteFloat16, "", {1}, quant, + reinterpret_cast(&float16_const_), sizeof(TfLiteFloat16)); + interpreter_->SetTensorParametersReadWrite(11, kTfLiteFloat32, "", {1}, + quant); + interpreter_->SetTensorParametersReadWrite(12, kTfLiteFloat32, "", {1}, + quant); + + // NODES. 
+ auto* add_reg = ops::builtin::Register_ADD(); + auto* mul_reg = ops::builtin::Register_MUL(); + auto* deq_reg = ops::builtin::Register_DEQUANTIZE(); + add_reg->builtin_code = kTfLiteBuiltinAdd; + deq_reg->builtin_code = kTfLiteBuiltinDequantize; + mul_reg->builtin_code = kTfLiteBuiltinMul; + TfLiteAddParams* builtin_data0 = + reinterpret_cast(malloc(sizeof(TfLiteAddParams))); + TfLiteAddParams* builtin_data1 = + reinterpret_cast(malloc(sizeof(TfLiteAddParams))); + TfLiteMulParams* builtin_data2 = + reinterpret_cast(malloc(sizeof(TfLiteMulParams))); + TfLiteAddParams* builtin_data3 = + reinterpret_cast(malloc(sizeof(TfLiteAddParams))); + builtin_data0->activation = kTfLiteActNone; + builtin_data1->activation = kTfLiteActNone; + builtin_data2->activation = kTfLiteActNone; + builtin_data3->activation = kTfLiteActNone; + interpreter_->AddNodeWithParameters({1}, {2}, nullptr, 0, nullptr, deq_reg); + interpreter_->AddNodeWithParameters({0, 2}, {3}, nullptr, 0, builtin_data0, + add_reg); + interpreter_->AddNodeWithParameters({4}, {5}, nullptr, 0, nullptr, deq_reg); + interpreter_->AddNodeWithParameters({3, 5}, {6}, nullptr, 0, builtin_data1, + add_reg); + interpreter_->AddNodeWithParameters({7}, {8}, nullptr, 0, nullptr, deq_reg); + interpreter_->AddNodeWithParameters({6, 8}, {9}, nullptr, 0, builtin_data2, + mul_reg); + interpreter_->AddNodeWithParameters({10}, {11}, nullptr, 0, nullptr, + deq_reg); + interpreter_->AddNodeWithParameters({9, 11}, {12}, nullptr, 0, + builtin_data3, add_reg); + } + + void VerifyInvoke() { + std::vector input = {3.0f}; + std::vector expected_output = {16.0f}; + + const int input_tensor_idx = interpreter_->inputs()[0]; + const int output_tensor_idx = interpreter_->outputs()[0]; + + memcpy(interpreter_->typed_tensor(input_tensor_idx), input.data(), + sizeof(float)); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + TfLiteTensor* output_tensor = interpreter_->tensor(output_tensor_idx); + for (int i = 0; i < 1; ++i) { + EXPECT_EQ(output_tensor->data.f[i], expected_output[i]) << i; + } + } + + void TearDown() override { interpreter_.reset(); } + + protected: + class FP16Delegate { + public: + // Uses FP16GraphPartitionHelper to accept ADD nodes with fp16 input. + explicit FP16Delegate(int num_delegated_subsets, + bool fail_node_prepare = false, + bool fail_node_invoke = false) + : num_delegated_subsets_(num_delegated_subsets), + fail_delegate_node_prepare_(fail_node_prepare), + fail_delegate_node_invoke_(fail_node_invoke) { + delegate_.Prepare = [](TfLiteContext* context, + TfLiteDelegate* delegate) -> TfLiteStatus { + auto* fp16_delegate = static_cast(delegate->data_); + // FP16 graph partitioning. 
+ delegates::IsNodeSupportedFn node_supported_fn = + [=](TfLiteContext* context, TfLiteNode* node, + TfLiteRegistration* registration, + std::string* unsupported_details) -> bool { + return registration->builtin_code == kTfLiteBuiltinAdd; + }; + delegates::FP16GraphPartitionHelper partition_helper(context, + node_supported_fn); + TfLiteIntArray* nodes_to_separate = nullptr; + if (partition_helper.Partition(nullptr) != kTfLiteOk) { + nodes_to_separate = TfLiteIntArrayCreate(0); + } else { + std::vector ops_to_replace = + partition_helper.GetNodesOfFirstNLargestPartitions( + fp16_delegate->num_delegated_subsets()); + nodes_to_separate = ConvertVectorToTfLiteIntArray(ops_to_replace); + } + + context->ReplaceNodeSubsetsWithDelegateKernels( + context, fp16_delegate->FakeFusedRegistration(), nodes_to_separate, + delegate); + TfLiteIntArrayFree(nodes_to_separate); + return kTfLiteOk; + }; + delegate_.CopyFromBufferHandle = + [](TfLiteContext* context, TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, + TfLiteTensor* output) -> TfLiteStatus { return kTfLiteOk; }; + delegate_.FreeBufferHandle = nullptr; + delegate_.CopyToBufferHandle = nullptr; + // Store type-punned data SimpleDelegate structure. + delegate_.data_ = static_cast(this); + delegate_.flags = kTfLiteDelegateFlagsNone; + } + + TfLiteRegistration FakeFusedRegistration() { + TfLiteRegistration reg = {nullptr}; + reg.custom_name = "fake_fp16_add_op"; + + // Different flavors of the delegate kernel's Invoke(), dependent on + // testing parameters. + if (fail_delegate_node_invoke_) { + reg.invoke = [](TfLiteContext* context, + TfLiteNode* node) -> TfLiteStatus { + return kTfLiteError; + }; + } else { + reg.invoke = [](TfLiteContext* context, + TfLiteNode* node) -> TfLiteStatus { + float output = 0; + for (int i = 0; i < node->inputs->size; ++i) { + const TfLiteTensor* input_tensor = GetInput(context, node, i); + if (input_tensor->type == kTfLiteFloat32) { + output += input_tensor->data.f[0]; + } else { + // All constants are 2. + output += 2; + } + } + TfLiteTensor* out = GetOutput(context, node, 0); + out->data.f[0] = output; + return kTfLiteOk; + }; + } + + // Different flavors of the delegate kernel's Prepare(), dependent on + // testing parameters. 
+ if (fail_delegate_node_prepare_) { + reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { + return kTfLiteError; + }; + } else { + reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { + // Set output size to input size + const TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE_STATUS(context->ResizeTensor( + context, output, TfLiteIntArrayCopy(input->dims))); + return kTfLiteOk; + }; + } + + return reg; + } + + TfLiteDelegate* get_tf_lite_delegate() { return &delegate_; } + + int num_delegated_subsets() { return num_delegated_subsets_; } + + private: + TfLiteDelegate delegate_; + int num_delegated_subsets_; + bool fail_delegate_node_prepare_ = false; + bool fail_delegate_node_invoke_ = false; + }; + + std::unique_ptr interpreter_; + std::unique_ptr delegate_; + Eigen::half float16_const_; +}; + +TEST_P(TestFP16Delegation, NonDelegatedInterpreterWorks) { + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + VerifyInvoke(); +} + +TEST_P(TestFP16Delegation, DelegationWorks) { + delegate_ = std::unique_ptr( + new FP16Delegate(/**num_delegated_subsets**/ GetParam())); + ASSERT_EQ( + interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()), + kTfLiteOk); + // Should have 5 nodes: delegate, mul, add2 & 2 dequantize (one for mul & + // add2). + ASSERT_EQ(interpreter_->execution_plan().size(), 5); + VerifyInvoke(); +} + +TEST_P(TestFP16Delegation, DelegatePrepareFails) { + delegate_ = std::unique_ptr(new FP16Delegate( + /**num_delegated_subsets**/ GetParam(), /**fail_node_prepare**/ true)); + ASSERT_EQ( + interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()), + kTfLiteDelegateError); + // Delegation failed, but runtime should go back to correct previous state. 
+ ASSERT_EQ(interpreter_->execution_plan().size(), 8); + VerifyInvoke(); +} + +TEST_P(TestFP16Delegation, DelegateInvokeWithCPUFallback) { + delegate_ = std::unique_ptr(new FP16Delegate( + /**num_delegated_subsets**/ GetParam(), /**fail_node_prepare**/ false, + /**fail_node_invoke**/ true)); + ASSERT_EQ( + interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()), + kTfLiteOk); + + std::vector input = {3.0f}; + std::vector expected_output = {16.0f}; + + const int input_tensor_idx = interpreter_->inputs()[0]; + const int output_tensor_idx = interpreter_->outputs()[0]; + + memcpy(interpreter_->typed_tensor(input_tensor_idx), input.data(), + sizeof(float)); + EXPECT_EQ( + delegates::InterpreterUtils::InvokeWithCPUFallback(interpreter_.get()), + kTfLiteDelegateError); + TfLiteTensor* output_tensor = interpreter_->tensor(output_tensor_idx); + for (int i = 0; i < 1; ++i) { + EXPECT_EQ(output_tensor->data.f[i], expected_output[i]) << i; + } + + ASSERT_EQ(interpreter_->execution_plan().size(), 8); + VerifyInvoke(); +} + +INSTANTIATE_TEST_SUITE_P(TestFP16Delegation, TestFP16Delegation, + ::testing::Values(1, 2)); + } // namespace } // namespace tflite diff --git a/tensorflow/lite/delegates/flex/BUILD b/tensorflow/lite/delegates/flex/BUILD index 6210007361a..098159d9d26 100644 --- a/tensorflow/lite/delegates/flex/BUILD +++ b/tensorflow/lite/delegates/flex/BUILD @@ -1,5 +1,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_opts_nortti_if_lite_protos") +load("//tensorflow/lite:build_def.bzl", "tflite_copts") load("//tensorflow/lite/delegates/flex:build_def.bzl", "tflite_flex_cc_library") +load("//tensorflow:tensorflow.bzl", "get_compatible_with_cloud") # # This is a TF Lite delegate that is powered by TensorFlow's Eager. @@ -84,6 +86,7 @@ cc_library( hdrs = [ "delegate.h", ], + copts = tflite_copts(), visibility = ["//visibility:public"], deps = [ ":buffer_map", @@ -124,10 +127,13 @@ tf_cc_test( name = "delegate_test", size = "small", srcs = ["delegate_test.cc"], - tags = ["no_gpu"], # GPU + flex is not officially supported. + tags = [ + "no_gpu", # GPU + flex is not officially supported. + ], deps = [ ":delegate", ":test_util", + "//tensorflow/lite:shared_library", "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], @@ -177,8 +183,8 @@ tf_cc_test( srcs = ["kernel_test.cc"], tags = ["no_gpu"], # GPU + flex is not officially supported. deps = [ + ":delegate", ":delegate_data", - ":delegate_only_runtime", ":test_util", "@com_google_googletest//:gtest", ], @@ -241,6 +247,7 @@ cc_library( "allowlisted_flex_ops.h", "allowlisted_flex_ops_internal.h", ], + compatible_with = get_compatible_with_cloud(), deps = select({ "//tensorflow:android": [ "//tensorflow/core:portable_tensorflow_lib_lite", @@ -276,3 +283,15 @@ tf_cc_test( ], }), ) + +# Alias to support selective build of image ops. +# TODO(b/163285312): Remove after tensorflow/core refactoring completed. 
+cc_library( + name = "portable_images_lib", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:portable_gif_internal", + "//tensorflow/core:portable_jpeg_internal", + "//tensorflow/core/lib/png:png_io", + ], +) diff --git a/tensorflow/lite/delegates/flex/allowlisted_flex_ops.cc b/tensorflow/lite/delegates/flex/allowlisted_flex_ops.cc index eefbeb72b15..eee1c99ed58 100644 --- a/tensorflow/lite/delegates/flex/allowlisted_flex_ops.cc +++ b/tensorflow/lite/delegates/flex/allowlisted_flex_ops.cc @@ -68,6 +68,10 @@ const std::set& GetFlexAllowlist() { "AvgPoolGrad", "BatchMatMul", "BatchMatMulV2", + "BatchMatrixDiag", + "BatchMatrixDiagPart", + "BatchMatrixInverse", + "BatchMatrixSetDiag", "BatchNormWithGlobalNormalization", "BatchNormWithGlobalNormalizationGrad", "BatchToSpace", @@ -75,7 +79,20 @@ const std::set& GetFlexAllowlist() { "BiasAdd", "BiasAddGrad", "BiasAddV1", + "Bincount", + "Bitcast", + "BitwiseAnd", + "BitwiseOr", + "BitwiseXor", "BoostedTreesBucketize", + "BoostedTreesCreateQuantileStreamResource", + "BoostedTreesFlushQuantileSummaries", + "BoostedTreesMakeQuantileSummaries", + "BoostedTreesQuantileStreamResourceAddSummaries", + "BoostedTreesQuantileStreamResourceDeserialize", + "BoostedTreesQuantileStreamResourceFlush", + "BoostedTreesQuantileStreamResourceGetBucketBoundaries", + "BoostedTreesQuantileStreamResourceHandleOp", "BroadcastArgs", "BroadcastGradientArgs", "BroadcastTo", @@ -85,6 +102,7 @@ const std::set& GetFlexAllowlist() { "Cast", "Ceil", "CheckNumerics", + "CheckNumericsV2", "CombinedNonMaxSuppression", "Complex", "ComplexAbs", @@ -99,6 +117,9 @@ const std::set& GetFlexAllowlist() { "Conv2DBackpropFilter", "Conv2DBackpropInput", "Conv3D", + "Conv3DBackpropFilter", + "Conv3DBackpropFilterV2", + "Conv3DBackpropInput", "Conv3DBackpropInputV2", "Cos", "Cosh", @@ -107,21 +128,32 @@ const std::set& GetFlexAllowlist() { "CropAndResizeGradImage", "Cumprod", "Cumsum", + "CumulativeLogsumexp", "DataFormatDimMap", "DataFormatVecPermute", "DebugGradientIdentity", "DebugGradientRefIdentity", + "DecodeAndCropJpeg", "DecodeBase64", + "DecodeBmp", + "DecodeGif", + "DecodeImage", + "DecodeJpeg", + "DecodePng", "DecodeRaw", "DecodeWav", "DeepCopy", "DeleteSessionTensor", + "DenseBincount", "DepthToSpace", "DepthwiseConv2dNative", "Dequantize", "DestroyTemporaryVariable", "Diag", + "DiagPart", "Dilation2D", + "Dilation2DBackpropFilter", + "Dilation2DBackpropInput", "Div", "DivNoNan", "DynamicPartition", @@ -130,7 +162,11 @@ const std::set& GetFlexAllowlist() { "Elu", "EluGrad", "Empty", + "EmptyTensorList", "EncodeBase64", + "EncodeJpeg", + "EncodeJpegVariableQuality", + "EncodePng", "EncodeWav", "EnsureShape", "Enter", @@ -172,6 +208,7 @@ const std::set& GetFlexAllowlist() { "GetSessionTensor", "Greater", "GreaterEqual", + "HistogramSummary", "IFFT", "IFFT2D", "IFFT3D", @@ -182,6 +219,7 @@ const std::set& GetFlexAllowlist() { "IdentityN", "Imag", "ImageProjectiveTransformV2", + "ImageProjectiveTransformV3", "ImmutableConst", "InTopK", "InTopKV2", @@ -190,13 +228,16 @@ const std::set& GetFlexAllowlist() { "InplaceUpdate", "Inv", "InvGrad", + "Invert", "InvertPermutation", + "IsBoostedTreesQuantileStreamResourceInitialized", "IsFinite", "IsNan", "IsVariableInitialized", "LRN", "LeakyRelu", "LeakyReluGrad", + "LeftShift", "Less", "LessEqual", "LinSpace", @@ -209,6 +250,9 @@ const std::set& GetFlexAllowlist() { "LoopCond", "MatMul", "MatrixDiag", + "MatrixDiagPart", + "MatrixDiagPartV2", + "MatrixDiagPartV3", "MatrixDiagV2", "MatrixDiagV3", "MatrixInverse", @@ 
-218,6 +262,8 @@ const std::set& GetFlexAllowlist() { "Max", "MaxPool", "MaxPool3D", + "MaxPool3DGrad", + "MaxPool3DGradGrad", "MaxPoolGrad", "MaxPoolGradGrad", "MaxPoolGradGradV2", @@ -228,6 +274,7 @@ const std::set& GetFlexAllowlist() { "Maximum", "Mean", "Merge", + "MergeSummary", "MergeV2Checkpoints", "Mfcc", "Min", @@ -244,6 +291,7 @@ const std::set& GetFlexAllowlist() { "NonMaxSuppressionV2", "NonMaxSuppressionV3", "NonMaxSuppressionV4", + "NonMaxSuppressionV5", "NonMaxSuppressionWithOverlaps", "NotEqual", "OneHot", @@ -253,15 +301,18 @@ const std::set& GetFlexAllowlist() { "PadV2", "PaddingFIFOQueue", "PaddingFIFOQueueV2", + "ParallelConcat", "ParallelDynamicStitch", "ParseExample", "ParseExampleV2", "ParseSequenceExample", + "ParseSequenceExampleV2", "ParseSingleExample", "ParseSingleSequenceExample", "Placeholder", "PlaceholderV2", "PlaceholderWithDefault", + "PopulationCount", "Pow", "PreventGradient", "Print", @@ -302,10 +353,14 @@ const std::set& GetFlexAllowlist() { "RFFT", "RFFT2D", "RFFT3D", + "RaggedBincount", + "RaggedGather", "RaggedRange", "RaggedTensorToSparse", "RaggedTensorToTensor", "RandomGamma", + "RandomPoisson", + "RandomPoissonV2", "RandomStandardNormal", "RandomUniform", "RandomUniformInt", @@ -315,6 +370,7 @@ const std::set& GetFlexAllowlist() { "RealDiv", "Reciprocal", "ReciprocalGrad", + "Recv", "ReduceJoin", "RefEnter", "RefExit", @@ -342,22 +398,31 @@ const std::set& GetFlexAllowlist() { "ResourceApplyAdagradDA", "ResourceApplyAdagradV2", "ResourceApplyAdam", + "ResourceApplyAdamWithAmsgrad", "ResourceApplyAddSign", "ResourceApplyCenteredRMSProp", "ResourceApplyFtrl", "ResourceApplyFtrlV2", "ResourceApplyGradientDescent", + "ResourceApplyKerasMomentum", "ResourceApplyMomentum", "ResourceApplyPowerSign", "ResourceApplyProximalAdagrad", "ResourceApplyProximalGradientDescent", "ResourceApplyRMSProp", + "ResourceScatterNdAdd", + "ResourceScatterNdMax", + "ResourceScatterNdMin", + "ResourceScatterNdSub", + "ResourceScatterNdUpdate", "ResourceSparseApplyAdadelta", "ResourceSparseApplyAdagrad", "ResourceSparseApplyAdagradDA", + "ResourceSparseApplyAdagradV2", "ResourceSparseApplyCenteredRMSProp", "ResourceSparseApplyFtrl", "ResourceSparseApplyFtrlV2", + "ResourceSparseApplyKerasMomentum", "ResourceSparseApplyMomentum", "ResourceSparseApplyProximalAdagrad", "ResourceSparseApplyProximalGradientDescent", @@ -369,14 +434,23 @@ const std::set& GetFlexAllowlist() { "Reverse", "ReverseSequence", "ReverseV2", + "RightShift", "Round", "Rsqrt", "RsqrtGrad", + "SampleDistortedBoundingBox", "SampleDistortedBoundingBoxV2", "Save", "SaveSlices", "SaveV2", + "ScalarSummary", "ScatterNd", + "ScatterNdAdd", + "ScatterNdMax", + "ScatterNdMin", + "ScatterNdNonAliasingAdd", + "ScatterNdSub", + "ScatterNdUpdate", "SegmentMax", "SegmentMean", "SegmentMin", @@ -386,6 +460,7 @@ const std::set& GetFlexAllowlist() { "SelectV2", "Selu", "SeluGrad", + "Send", "Shape", "ShapeN", "ShardedFilename", @@ -409,6 +484,7 @@ const std::set& GetFlexAllowlist() { "SparseApplyAdadelta", "SparseApplyAdagrad", "SparseApplyAdagradDA", + "SparseApplyAdagradV2", "SparseApplyCenteredRMSProp", "SparseApplyFtrl", "SparseApplyFtrlV2", @@ -416,6 +492,7 @@ const std::set& GetFlexAllowlist() { "SparseApplyProximalAdagrad", "SparseApplyProximalGradientDescent", "SparseApplyRMSProp", + "SparseBincount", "SparseCross", "SparseCrossHashed", "SparseCrossV2", @@ -446,12 +523,14 @@ const std::set& GetFlexAllowlist() { "StackPush", "StackPushV2", "StackV2", + "StatelessMultinomial", "StatelessRandomGammaV2", 
"StatelessRandomNormal", "StatelessRandomPoisson", "StatelessRandomUniform", "StatelessRandomUniformFullInt", "StatelessRandomUniformInt", + "StatelessSampleDistortedBoundingBox", "StatelessTruncatedNormal", "StaticRegexReplace", "StopGradient", @@ -459,8 +538,10 @@ const std::set& GetFlexAllowlist() { "StridedSliceAssign", "StridedSliceGrad", "StringJoin", + "StringLower", "StringSplit", "StringSplitV2", + "StringStrip", "StringToHashBucket", "StringToHashBucketFast", "StringToHashBucketStrong", @@ -506,6 +587,31 @@ const std::set& GetFlexAllowlist() { "TensorArrayWrite", "TensorArrayWriteV2", "TensorArrayWriteV3", + "TensorListConcat", + "TensorListConcatLists", + "TensorListConcatV2", + "TensorListElementShape", + "TensorListFromTensor", + "TensorListGather", + "TensorListGetItem", + "TensorListLength", + "TensorListPopBack", + "TensorListPushBack", + "TensorListPushBackBatch", + "TensorListReserve", + "TensorListResize", + "TensorListScatter", + "TensorListScatterIntoExistingList", + "TensorListScatterV2", + "TensorListSetItem", + "TensorListSplit", + "TensorListStack", + "TensorScatterAdd", + "TensorScatterMax", + "TensorScatterMin", + "TensorScatterSub", + "TensorScatterUpdate", + "TensorStridedSliceUpdate", "Tile", "TileGrad", "Timestamp", @@ -527,21 +633,30 @@ const std::set& GetFlexAllowlist() { "UnsortedSegmentMin", "UnsortedSegmentProd", "UnsortedSegmentSum", + "UnwrapDatasetVariant", "Variable", "VariableV2", "Where", + "WrapDatasetVariant", "Xdivy", + "Xlog1py", "Xlogy", "ZerosLike", "_Arg", "_ArrayToList", + "_DeviceArg", + "_DeviceRetval", + "_FusedConv2D", "_HostCast", "_HostRecv", "_HostSend", "_ListToArray", + "_ParallelConcatStart", + "_ParallelConcatUpdate", "_Recv", "_Retval", "_Send", + "_SwitchN", // go/keep-sorted end }); return *allowlisted_flex_ops; diff --git a/tensorflow/lite/delegates/flex/buffer_map.cc b/tensorflow/lite/delegates/flex/buffer_map.cc index c2611290c1b..86ea4b849ea 100644 --- a/tensorflow/lite/delegates/flex/buffer_map.cc +++ b/tensorflow/lite/delegates/flex/buffer_map.cc @@ -149,6 +149,11 @@ tensorflow::Tensor BufferMap::GetTensor(int tensor_index) const { return id_to_tensor_.at(tensor_index); } +const tensorflow::Tensor* BufferMap::GetTensorPtr(int tensor_index) const { + auto& tensor = id_to_tensor_.at(tensor_index); + return &tensor; +} + void BufferMap::SetFromTfLite(int tensor_index, const TfLiteTensor* tensor) { tensorflow::TensorShape shape; int num_dims = tensor->dims->size; diff --git a/tensorflow/lite/delegates/flex/buffer_map.h b/tensorflow/lite/delegates/flex/buffer_map.h index 6c35895c249..6a29c7f80dc 100644 --- a/tensorflow/lite/delegates/flex/buffer_map.h +++ b/tensorflow/lite/delegates/flex/buffer_map.h @@ -47,6 +47,11 @@ class BufferMap { // Precondition: HasTensor() is true. tensorflow::Tensor GetTensor(int tensor_index) const; + // Returns the const pointer to tensorflow::Tensor associated with the given + // 'tensor_index'. + // Precondition: HasTensor() is true. + const tensorflow::Tensor* GetTensorPtr(int tensor_index) const; + // Associates the given tensorflow::Tensor with the given 'tensor_index'. // Note that TensorFlow Tensors share data buffers, so this method is only a // shallow copy. 
diff --git a/tensorflow/lite/delegates/flex/build_def.bzl b/tensorflow/lite/delegates/flex/build_def.bzl index 9b9f1b2c4cb..5826e1f83cd 100644 --- a/tensorflow/lite/delegates/flex/build_def.bzl +++ b/tensorflow/lite/delegates/flex/build_def.bzl @@ -2,6 +2,7 @@ load( "//tensorflow:tensorflow.bzl", + "clean_dep", "if_android", "if_ios", "if_mobile", @@ -46,12 +47,12 @@ def generate_flex_kernel_header( ["$(location %s)" % f for f in models], ) list_ops_output = include_path + "/list_flex_ops" - list_ops_tool = "//tensorflow/lite/tools:list_flex_ops_main" + list_ops_tool = clean_dep("//tensorflow/lite/tools:list_flex_ops_main") if additional_deps: tf_cc_binary( name = "%s_list_flex_ops_main" % name, deps = [ - "//tensorflow/lite/tools:list_flex_ops_main_lib", + clean_dep("//tensorflow/lite/tools:list_flex_ops_main_lib"), ] + additional_deps, ) list_ops_tool = ":%s_list_flex_ops_main" % name @@ -66,12 +67,12 @@ def generate_flex_kernel_header( ) # Generate the kernel registration header file from list of flex ops. - tool = "//tensorflow/python/tools:print_selective_registration_header" + tool = clean_dep("//tensorflow/python/tools:print_selective_registration_header") native.genrule( name = "%s_kernel_registration" % name, srcs = [list_ops_output], outs = [header], - tools = [tool], + exec_tools = [tool], message = "Processing %s..." % list_ops_output, cmd = ("$(location " + tool + ")" + " --default_ops=\"\"" + @@ -95,7 +96,7 @@ def tflite_flex_cc_library( additional_deps: Dependencies for additional TF ops. visibility: visibility of the generated rules. """ - portable_tensorflow_lib = "//tensorflow/core:portable_tensorflow_lib" + portable_tensorflow_lib = clean_dep("//tensorflow/core:portable_tensorflow_lib") if models: CUSTOM_KERNEL_HEADER = generate_flex_kernel_header( name = "%s_tf_op_headers" % name, @@ -108,9 +109,9 @@ def tflite_flex_cc_library( native.cc_library( name = "%s_tensorflow_lib" % name, srcs = if_mobile([ - "//tensorflow/core:portable_op_registrations_and_gradients", - "//tensorflow/core/kernels:android_core_ops", - "//tensorflow/core/kernels:android_extended_ops", + clean_dep("//tensorflow/core:portable_op_registrations_and_gradients"), + clean_dep("//tensorflow/core/kernels:android_core_ops"), + clean_dep("//tensorflow/core/kernels:android_extended_ops"), ]) + [CUSTOM_KERNEL_HEADER.header], copts = tf_copts(android_optimization_level_override = None) + tf_opts_nortti_if_lite_protos() + if_ios(["-Os"]), defines = [ @@ -126,7 +127,7 @@ def tflite_flex_cc_library( CUSTOM_KERNEL_HEADER.include_path, ], textual_hdrs = [ - "//tensorflow/core/kernels:android_all_ops_textual_hdrs", + clean_dep("//tensorflow/core/kernels:android_all_ops_textual_hdrs"), ], visibility = visibility, deps = [ @@ -135,10 +136,11 @@ def tflite_flex_cc_library( "//third_party/eigen3", "@com_google_absl//absl/types:optional", "@gemmlowp", - "//tensorflow/core:protos_all_cc", "@icu//:common", - "//tensorflow/core:portable_tensorflow_lib_lite", - "//tensorflow/core/platform:strong_hash", + clean_dep("//tensorflow/core:protos_all_cc"), + clean_dep("//tensorflow/core:portable_tensorflow_lib_lite"), + clean_dep("//tensorflow/core/platform:strong_hash"), + clean_dep("//tensorflow/lite/delegates/flex:portable_images_lib"), ], alwayslink = 1, ) @@ -148,23 +150,23 @@ def tflite_flex_cc_library( native.cc_library( name = name, hdrs = [ - "//tensorflow/lite/delegates/flex:delegate.h", + clean_dep("//tensorflow/lite/delegates/flex:delegate.h"), ], visibility = visibility, deps = [ - 
"//tensorflow/lite/delegates/flex:delegate_data", - "//tensorflow/lite/delegates/flex:delegate_only_runtime", - "//tensorflow/lite/delegates/utils:simple_delegate", + clean_dep("//tensorflow/lite/delegates/flex:delegate_data"), + clean_dep("//tensorflow/lite/delegates/flex:delegate_only_runtime"), + clean_dep("//tensorflow/lite/delegates/utils:simple_delegate"), ] + select({ - "//tensorflow:android": [ + clean_dep("//tensorflow:android"): [ portable_tensorflow_lib, ], - "//tensorflow:ios": [ + clean_dep("//tensorflow:ios"): [ portable_tensorflow_lib, ], "//conditions:default": [ - "//tensorflow/core:tensorflow", - "//tensorflow/lite/c:common", + clean_dep("//tensorflow/core:tensorflow"), + clean_dep("//tensorflow/lite/c:common"), ], }) + additional_deps, alwayslink = 1, @@ -202,21 +204,21 @@ def tflite_flex_jni_library( native.cc_library( name = "%s_flex_native" % name, srcs = [ - "//tensorflow/lite/testing:init_tensorflow.h", - "//tensorflow/lite/testing:init_tensorflow.cc", - "//tensorflow/lite/delegates/flex/java/src/main/native:flex_delegate_jni.cc", + clean_dep("//tensorflow/lite/testing:init_tensorflow.h"), + clean_dep("//tensorflow/lite/testing:init_tensorflow.cc"), + clean_dep("//tensorflow/lite/delegates/flex/java/src/main/native:flex_delegate_jni.cc"), ], copts = tflite_copts(), visibility = visibility, deps = [ ":%s_flex_delegate" % name, - "//tensorflow/lite/java/jni", - "//tensorflow/lite/delegates/utils:simple_delegate", + clean_dep("//tensorflow/lite/java/jni"), + clean_dep("//tensorflow/lite/delegates/utils:simple_delegate"), ] + select({ - "//tensorflow:android": [], - "//tensorflow:ios": [], + clean_dep("//tensorflow:android"): [], + clean_dep("//tensorflow:ios"): [], "//conditions:default": [ - "//tensorflow/core:lib", + clean_dep("//tensorflow/core:lib"), ], }), alwayslink = 1, @@ -264,14 +266,14 @@ def tflite_flex_android_library( android_library( name = name, - srcs = ["//tensorflow/lite/delegates/flex/java/src/main/java/org/tensorflow/lite/flex:flex_delegate"], - manifest = "//tensorflow/lite/java:AndroidManifest.xml", - proguard_specs = ["//tensorflow/lite/java:proguard.flags"], + srcs = [clean_dep("//tensorflow/lite/delegates/flex/java/src/main/java/org/tensorflow/lite/flex:flex_delegate")], + manifest = clean_dep("//tensorflow/lite/java:AndroidManifest.xml"), + proguard_specs = [clean_dep("//tensorflow/lite/java:proguard.flags")], custom_package = custom_package, deps = [ ":%s_native" % name, - "//tensorflow/lite/java:tensorflowlite_java", - "@org_checkerframework_qual", + clean_dep("//tensorflow/lite/java:tensorflowlite_java"), + clean_dep("@org_checkerframework_qual"), ], visibility = visibility, ) diff --git a/tensorflow/lite/delegates/flex/delegate.cc b/tensorflow/lite/delegates/flex/delegate.cc index 4664ab34700..f7d07af6595 100644 --- a/tensorflow/lite/delegates/flex/delegate.cc +++ b/tensorflow/lite/delegates/flex/delegate.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/context_util.h" #include "tensorflow/lite/core/macros.h" #include "tensorflow/lite/delegates/flex/buffer_map.h" @@ -142,14 +143,16 @@ TfLiteStatus FlexDelegate::CopyFromBufferHandle( } // namespace tflite +// LINT.IfChange // Exported C interface function which is used by AcquireFlexDelegate() at -// interpreter_build.cc. 
To export the function name globally, the function name -// must be matched with patterns in tf_version_script.lds +// interpreter_builder.cc. To export the function name globally, the function +// name must be matched with patterns in tf_version_script.lds. In Android, we +// don't use this feature so skip building. +#if !defined(__ANDROID__) extern "C" { -#if defined(_WIN32) -__declspec(dllexport) -#endif - tflite::TfLiteDelegateUniquePtr TF_AcquireFlexDelegate() { +TFL_CAPI_EXPORT tflite::TfLiteDelegateUniquePtr TF_AcquireFlexDelegate() { return tflite::FlexDelegate::Create(); } } // extern "C" +#endif // !defined(__ANDROID__) +// LINT.ThenChange(//tensorflow/lite/interpreter_builder.cc) diff --git a/tensorflow/lite/delegates/flex/delegate_data.cc b/tensorflow/lite/delegates/flex/delegate_data.cc index 2be928073ff..8e3ed964e01 100644 --- a/tensorflow/lite/delegates/flex/delegate_data.cc +++ b/tensorflow/lite/delegates/flex/delegate_data.cc @@ -46,7 +46,6 @@ tensorflow::Status DelegateData::Prepare( eager_context_ = new tensorflow::EagerContext( session_options, tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, - tensorflow::ContextMirroringPolicy::MIRRORING_NONE, /*async=*/false, /*lazy_copy_function_remote_inputs=*/false, device_mgr.release(), /*device_mgr_owned*/ true, rendezvous, nullptr); return tensorflow::Status(); diff --git a/tensorflow/lite/delegates/flex/delegate_test.cc b/tensorflow/lite/delegates/flex/delegate_test.cc index d574d8fabbb..6450848bf0e 100644 --- a/tensorflow/lite/delegates/flex/delegate_test.cc +++ b/tensorflow/lite/delegates/flex/delegate_test.cc @@ -14,9 +14,13 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/delegates/flex/delegate.h" +#include +#include + #include #include #include "tensorflow/lite/delegates/flex/test_util.h" +#include "tensorflow/lite/shared_library.h" namespace tflite { namespace flex { @@ -301,6 +305,100 @@ TEST_F(DelegateTest, MultiThreaded) { ASSERT_EQ(GetType(8), kTfLiteFloat32); } +#if !defined(__ANDROID__) +TEST_F(DelegateTest, TF_AcquireFlexDelegate) { + auto TF_AcquireFlexDelegate = + reinterpret_cast( + SharedLibrary::GetSymbol("TF_AcquireFlexDelegate")); + ASSERT_TRUE(TF_AcquireFlexDelegate); + auto delegate_ptr = TF_AcquireFlexDelegate(); + ASSERT_TRUE(delegate_ptr != nullptr); +} +#endif // !defined(__ANDROID__) + +TEST_F(DelegateTest, StaticOutput) { + // Define the graph with input, output shapes of [2]. + AddTensors(7, {0, 1, 2, 3}, {6}, kTfLiteFloat32, {2}); + + AddTfOp(testing::kAdd, {0, 2}, {4}); + AddTfOp(testing::kAdd, {1, 3}, {5}); + AddTfOp(testing::kMul, {4, 5}, {6}); + + // Apply the delegate. + ConfigureDelegate(); + + // Define inputs which matech with the original shapes. + SetShape(0, {2}); + SetShape(1, {2}); + SetShape(2, {2}); + SetShape(3, {2}); + SetValues(0, {1.1f, 2.2f}); + SetValues(1, {3.3f, 4.4f}); + SetValues(2, {1.1f, 2.2f}); + SetValues(3, {3.3f, 4.4f}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(6), ElementsAre(2)); + ASSERT_THAT(GetValues(6), ElementsAre(14.52f, 38.72f)); + ASSERT_EQ(GetType(6), kTfLiteFloat32); + // Since shapes are consistent, static output tensor is used. + ASSERT_FALSE(IsDynamicTensor(6)); +} + +TEST_F(DelegateTest, StaticOutputRFFT) { + // Define the graph with input, output shapes of [3, 257]. 
+ AddTensors(4, {0, 1}, {3}, kTfLiteFloat32, {3, 257}); + int32_t rfft_length[] = {512}; + SetConstTensor(1, {1}, kTfLiteInt32, + reinterpret_cast(&rfft_length), + sizeof(rfft_length)); + + AddTfOp(testing::kRfft, {0, 1}, {2}); + AddTfOp(testing::kImag, {2}, {3}); + + // Apply the delegate. + ConfigureDelegate(); + + // Define inputs. + SetShape(0, {3, 512}); + SetValues(0, std::vector(3 * 512, 1.0f)); + + ASSERT_TRUE(Invoke()); + + ASSERT_EQ(GetType(3), kTfLiteFloat32); + // Since shapes are consistent, static output tensor is used. + ASSERT_FALSE(IsDynamicTensor(3)); +} + +TEST_F(DelegateTest, DynamicOutputAfterReshape) { + // Define the graph. + AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3}); + + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfOp(testing::kUnpack, {3}, {4, 5}); + AddTfOp(testing::kAdd, {1, 4}, {6}); + AddTfOp(testing::kAdd, {2, 5}, {7}); + AddTfOp(testing::kMul, {6, 7}, {8}); + + // Apply the delegate. + ConfigureDelegate(); + + // Define inputs with reshape. + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + SetShape(3, {2, 2, 1}); + SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(8), ElementsAre(2, 1)); + ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f)); + ASSERT_EQ(GetType(8), kTfLiteFloat32); + // Since shapes are inconsistent, dynamic output tensor is used. + ASSERT_TRUE(IsDynamicTensor(8)); +} + } // namespace } // namespace flex } // namespace tflite diff --git a/tensorflow/lite/delegates/flex/kernel.cc b/tensorflow/lite/delegates/flex/kernel.cc index b3e978908bd..f21c984fe3e 100644 --- a/tensorflow/lite/delegates/flex/kernel.cc +++ b/tensorflow/lite/delegates/flex/kernel.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/lite/delegates/flex/delegate_data.h" #include "tensorflow/lite/delegates/flex/util.h" #include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/minimal_logging.h" #include "tensorflow/lite/string_type.h" // Note: this is part of TF Lite's Flex delegation code which is to be @@ -48,6 +49,16 @@ limitations under the License. // retrieve the associated NodeDef, which is then used to configure the // corresponding TensorFlow/Eager Op. +using tensorflow::shape_inference::DimensionHandle; +using tensorflow::shape_inference::InferenceContext; +using tensorflow::shape_inference::ShapeAndType; +using tensorflow::shape_inference::ShapeHandle; + +const std::string GetDimsDebugString(const TfLiteIntArray* dims) { + return absl::StrCat("[", absl::StrJoin(tflite::TfLiteIntArrayView(dims), ","), + "]"); +} + namespace tflite { namespace flex { @@ -188,6 +199,9 @@ class OpNode { void set_index(int index) { index_ = index; } const tensorflow::NodeDef& nodedef() const { return nodedef_; } + const tensorflow::OpRegistrationData* op_reg_data() const { + return op_reg_data_; + } const OpInputs& inputs() const { return inputs_; } OpInputs* mutable_inputs() { return &inputs_; } @@ -222,10 +236,9 @@ class OpNode { } // Fill NodeDef with defaults if it's a valid op. - const tensorflow::OpRegistrationData* op_reg_data; TF_RETURN_IF_ERROR( - tensorflow::OpRegistry::Global()->LookUp(nodedef_.op(), &op_reg_data)); - AddDefaultsToNodeDef(op_reg_data->op_def, &nodedef_); + tensorflow::OpRegistry::Global()->LookUp(nodedef_.op(), &op_reg_data_)); + AddDefaultsToNodeDef(op_reg_data_->op_def, &nodedef_); return tensorflow::Status::OK(); } @@ -312,6 +325,8 @@ class OpNode { int index_; // The corresponding NodeDef, containing the attributes for the op. 
tensorflow::NodeDef nodedef_; + // The corresponding OpRegistrationData pointer. + const tensorflow::OpRegistrationData* op_reg_data_; // List of inputs, as TF Lite tensor indices. OpInputs inputs_; // List of outputs, as TF Lite tensor indices. @@ -455,10 +470,22 @@ TfLiteStatus DelegateKernel::Prepare(TfLiteContext* context, TfLiteNode* node) { tensor_ref_count[tensor_index] += 2; } + const bool shapes_are_valid = + (ValidateOutputTensorShapeConsistency(context) == kTfLiteOk); + if (shapes_are_valid) { + TFLITE_LOG(tflite::TFLITE_LOG_INFO, + "FlexDelegate: All tensor shapes are consistent."); + } else { + TFLITE_LOG(tflite::TFLITE_LOG_WARNING, + "FlexDelegate: Some tensor shapes are inconsistent."); + } + // All output tensors are allocated by TensorFlow/Eager, so we // mark them as kTfLiteDynamic. for (auto tensor_index : op_data_->subgraph_outputs) { - SetTensorToDynamic(&context->tensors[tensor_index]); + if (!shapes_are_valid) { + SetTensorToDynamic(&context->tensors[tensor_index]); + } ++tensor_ref_count[tensor_index]; } @@ -488,6 +515,85 @@ TfLiteStatus DelegateKernel::Prepare(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +TfLiteStatus DelegateKernel::ValidateOutputTensorShapeConsistency( + TfLiteContext* context) const { + for (const auto& node_data : op_data_->nodes) { + auto op_name = node_data->name().c_str(); + // Create an InferenceContext object. + auto num_inputs = node_data->inputs().Size(); + std::vector input_tensors_vector(num_inputs, + nullptr); + InferenceContext c( + TF_GRAPH_DEF_VERSION, node_data->nodedef(), + node_data->op_reg_data()->op_def, std::vector(num_inputs), + input_tensors_vector, {}, + std::vector>>()); + + // Set input_shapes for ShapeInferenceFn. + for (int i = 0; i < num_inputs; ++i) { + const auto input_tensor_index = node_data->inputs().TfLiteIndex(i); + TfLiteTensor* tfl_tensor = &context->tensors[input_tensor_index]; + // Provide constant input tensors since some op ("RFFT") needs it to + // calculate the output shape. + if (IsConstantTensor(tfl_tensor)) { + input_tensors_vector[i] = + op_data_->buffer_map->GetTensorPtr(input_tensor_index); + } + const auto dims_array = tfl_tensor->dims; + std::vector dims(dims_array->size); + for (int j = 0; j < dims_array->size; ++j) { + dims[j] = c.MakeDim(dims_array->data[j]); + } + c.SetInput(i, c.MakeShape(dims)); + } + c.set_input_tensors(input_tensors_vector); + + tensorflow::Status status = c.construction_status(); + if (!status.ok()) { + TFLITE_LOG(tflite::TFLITE_LOG_WARNING, + "Shape construction failed for op '%s'", op_name); + return kTfLiteError; + } + + // Run ShapeInferenceFn to calculate output shapes. + if (node_data->op_reg_data()->shape_inference_fn == nullptr) { + TFLITE_LOG(tflite::TFLITE_LOG_WARNING, + "No shape inference function exists for op '%s'", op_name); + return kTfLiteError; + } + status = c.Run(node_data->op_reg_data()->shape_inference_fn); + + // Compare calculated output shapes with node_data->outputs + auto num_outputs = node_data->outputs().Size(); + if (num_outputs != c.num_outputs()) { + TFLITE_LOG(tflite::TFLITE_LOG_WARNING, + "Number of output tensors are mismatched for op '%s' %d != %d", + op_name, num_outputs, c.num_outputs()); + return kTfLiteError; + } + for (int i = 0; i < num_outputs; ++i) { + const auto output_tensor_index = node_data->outputs().TfLiteIndex(i); + TfLiteTensor* tfl_tensor = &context->tensors[output_tensor_index]; + // tfl_tensor->dims only has valid information if the given model is + // converted by the MLIR converter. 
Also when ResizeInputTensor() is + // called the dims information becomes invalid. + const std::string tfl_shape_string = GetDimsDebugString(tfl_tensor->dims); + const std::string calculated_shape_string = c.DebugString(c.output(i)); + // Getting a shape string via c.DebugString() is the easiest way to get + // the shape information of the given ShapeHandle for now. + // TODO(b/169017408): Find a better approach without using debug string. + if (tfl_shape_string != calculated_shape_string) { + TFLITE_LOG(tflite::TFLITE_LOG_WARNING, + "op '%s' output%d tensor#%d shape mismatch for %s != %s", + op_name, i, output_tensor_index, tfl_shape_string.c_str(), + calculated_shape_string.c_str()); + return kTfLiteError; + } + } + } + return kTfLiteOk; +} + TfLiteStatus DelegateKernel::Eval(TfLiteContext* context, TfLiteNode* node) { BufferMap* buffer_map = op_data_->buffer_map; @@ -522,12 +628,30 @@ TfLiteStatus DelegateKernel::Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteError; } + // Copy TF tensor data to TFL allocated buffer for non dynamic tensors. + // For dynamic tensors, copy shape and put buffer_handle for the later + // CopyFromBufferHandle() call. TfLiteTensor* tensor = &context->tensors[tensor_index]; - TF_LITE_ENSURE_OK( - context, - CopyShapeAndType(context, buffer_map->GetTensor(tensor_index), tensor)); - tensor->buffer_handle = tensor_index; - tensor->data_is_stale = true; + const tensorflow::Tensor& tf_tensor = buffer_map->GetTensor(tensor_index); + if (tensor->allocation_type == kTfLiteDynamic) { + TF_LITE_ENSURE_OK(context, CopyShapeAndType(context, tf_tensor, tensor)); + tensor->buffer_handle = tensor_index; + tensor->data_is_stale = true; + continue; + } + // If the tensor isn't dynamic, we can copy data directly to the buffer of + // the tensor. Before copying the data, check if the target buffer has + // expected size. + if (tf_tensor.NumElements() != NumElements(tensor) || + tf_tensor.TotalBytes() != tensor->bytes) { + TF_LITE_KERNEL_LOG( + context, "Tensor: %s(%d) buffer size mismatch %zu(%lld) != %ld(%ld)", + tensor->name, tensor_index, tf_tensor.TotalBytes(), + tf_tensor.NumElements(), tensor->bytes, NumElements(tensor)); + return kTfLiteError; + } + tensorflow::StringPiece t_data = tf_tensor.tensor_data(); + memcpy(tensor->data.raw, t_data.data(), t_data.size()); } return kTfLiteOk; diff --git a/tensorflow/lite/delegates/flex/kernel.h b/tensorflow/lite/delegates/flex/kernel.h index 9a7b93e31f2..b2ab485bdaa 100644 --- a/tensorflow/lite/delegates/flex/kernel.h +++ b/tensorflow/lite/delegates/flex/kernel.h @@ -35,6 +35,11 @@ class DelegateKernel : public SimpleDelegateKernelInterface { TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) override; private: + // Validate that the computed output tensor shape for the Flex node matches + // the existing output shape assigned to the output tensor. + TfLiteStatus ValidateOutputTensorShapeConsistency( + TfLiteContext* context) const; + std::unique_ptr op_data_; }; diff --git a/tensorflow/lite/delegates/flex/kernel_test.cc b/tensorflow/lite/delegates/flex/kernel_test.cc index f7234075c95..adc65c3ced9 100644 --- a/tensorflow/lite/delegates/flex/kernel_test.cc +++ b/tensorflow/lite/delegates/flex/kernel_test.cc @@ -18,6 +18,8 @@ limitations under the License. 
#include "tensorflow/lite/delegates/flex/delegate_data.h" #include "tensorflow/lite/delegates/flex/test_util.h" +extern const std::string GetDimsDebugString(const TfLiteIntArray* dims); + namespace tflite { namespace flex { namespace testing { @@ -351,6 +353,62 @@ TEST_F(MultipleSubgraphsTest, DoNotForwardInputTensors) { }))); } +tensorflow::OpDef MakeOpDef(int num_inputs, int num_outputs) { + tensorflow::OpRegistrationData op_reg_data; + tensorflow::OpDefBuilder b("dummy"); + for (int i = 0; i < num_inputs; ++i) { + b.Input(tensorflow::strings::StrCat("i", i, ": float")); + } + for (int i = 0; i < num_outputs; ++i) { + b.Output(tensorflow::strings::StrCat("o", i, ": float")); + } + CHECK(b.Attr("foo:string").Finalize(&op_reg_data).ok()); + return op_reg_data.op_def; +} + +tensorflow::PartialTensorShape S( + std::initializer_list dims) { + return tensorflow::PartialTensorShape(dims); +} + +TEST(ValidateOutputTensorShapeConsistencyTest, ShapeHandleDebugString) { + // Setup test to contain an input tensor list of size 3. + tensorflow::OpDef op_def = MakeOpDef(4, 1); + tensorflow::NodeDef def; + tensorflow::shape_inference::InferenceContext c( + 0, def, op_def, {S({1}), S({2, 3}), S({4, 5, 6}), {}}, {}, {}, {}); + c.SetInput(3, c.UnknownShape()); + + std::vector shapes; + EXPECT_EQ("[1]", c.DebugString(c.input(0))); + EXPECT_EQ("[2,3]", c.DebugString(c.input(1))); + EXPECT_EQ("[4,5,6]", c.DebugString(c.input(2))); + // c.DebugString() returns "?" for the unknown shape which is different with + // "-1" of TFLite. But this is intended behavior since we should use dynamic + // tensor for unknown shape so the shape comparison must fail. + EXPECT_EQ("?", c.DebugString(c.input(3))); +} + +TEST(ValidateOutputTensorShapeConsistencyTest, GetDimsDebugString) { + TfLiteIntArray* dims1 = TfLiteIntArrayCreate(1); + dims1->data[0] = 1; + EXPECT_EQ("[1]", GetDimsDebugString(dims1)); + free(dims1); + + TfLiteIntArray* dims2 = TfLiteIntArrayCreate(2); + dims2->data[0] = 2; + dims2->data[1] = 3; + EXPECT_EQ("[2,3]", GetDimsDebugString(dims2)); + free(dims2); + + TfLiteIntArray* dims3 = TfLiteIntArrayCreate(3); + dims3->data[0] = 4; + dims3->data[1] = 5; + dims3->data[2] = 6; + EXPECT_EQ("[4,5,6]", GetDimsDebugString(dims3)); + free(dims3); +} + } // namespace testing } // namespace flex } // namespace tflite diff --git a/tensorflow/lite/delegates/flex/test_util.cc b/tensorflow/lite/delegates/flex/test_util.cc index 8c0e40b58dd..02685aa0502 100644 --- a/tensorflow/lite/delegates/flex/test_util.cc +++ b/tensorflow/lite/delegates/flex/test_util.cc @@ -67,6 +67,10 @@ TfLiteType FlexModelTest::GetType(int tensor_index) { return interpreter_->tensor(tensor_index)->type; } +bool FlexModelTest::IsDynamicTensor(int tensor_index) { + return interpreter_->tensor(tensor_index)->allocation_type == kTfLiteDynamic; +} + void FlexModelTest::AddTensors(int num_tensors, const std::vector& inputs, const std::vector& outputs, TfLiteType type, const std::vector& dims) { @@ -88,6 +92,18 @@ void FlexModelTest::AddTensors(int num_tensors, const std::vector& inputs, CHECK_EQ(interpreter_->SetOutputs(outputs), kTfLiteOk); } +void FlexModelTest::SetConstTensor(int tensor_index, + const std::vector& values, + TfLiteType type, const char* buffer, + size_t bytes) { + TfLiteQuantizationParams quant; + CHECK_EQ(interpreter_->SetTensorParametersReadOnly(tensor_index, type, + /*name=*/"", + /*dims=*/values, quant, + buffer, bytes), + kTfLiteOk); +} + void FlexModelTest::AddTfLiteMulOp(const std::vector& inputs, const std::vector& outputs) { 
++next_op_index_; @@ -154,6 +170,10 @@ void FlexModelTest::AddTfOp(TfOpType op, const std::vector& inputs, } else if (op == kMul) { string attributes = type_attribute; AddTfOp("FlexMul", "Mul", attributes, inputs, outputs); + } else if (op == kRfft) { + AddTfOp("FlexRFFT", "RFFT", "", inputs, outputs); + } else if (op == kImag) { + AddTfOp("FlexImag", "Imag", "", inputs, outputs); } else if (op == kNonExistent) { AddTfOp("NonExistentOp", "NonExistentOp", "", inputs, outputs); } else if (op == kIncompatibleNodeDef) { diff --git a/tensorflow/lite/delegates/flex/test_util.h b/tensorflow/lite/delegates/flex/test_util.h index 1913a406e83..c00adbfe9b3 100644 --- a/tensorflow/lite/delegates/flex/test_util.h +++ b/tensorflow/lite/delegates/flex/test_util.h @@ -28,6 +28,8 @@ enum TfOpType { kIdentity, kAdd, kMul, + kRfft, + kImag, // Represents an op that does not exist in TensorFlow. kNonExistent, // Represents an valid TensorFlow op where the NodeDef is incompatible. @@ -80,6 +82,9 @@ class FlexModelTest : public ::testing::Test { // Returns the tensor's type at the given index. TfLiteType GetType(int tensor_index); + // Returns if the tensor at the given index is dynamic. + bool IsDynamicTensor(int tensor_index); + const TestErrorReporter& error_reporter() const { return error_reporter_; } // Adds `num_tensor` tensors to the model. `inputs` contains the indices of @@ -89,6 +94,11 @@ class FlexModelTest : public ::testing::Test { const std::vector& outputs, TfLiteType type, const std::vector& dims); + // Set a constant tensor of the given shape, type and buffer at the given + // index. + void SetConstTensor(int tensor_index, const std::vector& values, + TfLiteType type, const char* buffer, size_t bytes); + // Adds a TFLite Mul op. `inputs` contains the indices of the input tensors // and `outputs` contains the indices of the output tensors. 
void AddTfLiteMulOp(const std::vector& inputs, diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD index d69bed4c03a..8778653b586 100644 --- a/tensorflow/lite/delegates/gpu/BUILD +++ b/tensorflow/lite/delegates/gpu/BUILD @@ -54,7 +54,7 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:tensor", - "//tensorflow/lite/delegates/gpu/common/transformations:general_transformations", + "//tensorflow/lite/delegates/gpu/common/transformations:model_transformations", "//tensorflow/lite/delegates/gpu/gl:api", "//tensorflow/lite/delegates/gpu/gl:command_queue", "//tensorflow/lite/delegates/gpu/gl:compiler", @@ -96,7 +96,6 @@ objc_library( "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:tensor", "//tensorflow/lite/delegates/gpu/common:types", - "//tensorflow/lite/delegates/gpu/common/transformations:general_transformations", "//tensorflow/lite/delegates/gpu/metal:api", "//tensorflow/lite/delegates/gpu/metal:buffer_convert", "//tensorflow/lite/delegates/gpu/metal:compiled_model", diff --git a/tensorflow/lite/delegates/gpu/cl/BUILD b/tensorflow/lite/delegates/gpu/cl/BUILD index 9ae3836d6c4..63171348b74 100644 --- a/tensorflow/lite/delegates/gpu/cl/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/BUILD @@ -55,6 +55,7 @@ cc_library( ":cl_device", ":gpu_object", ":opencl_wrapper", + ":serialization_cc_fbs", ":tensor_type", ":util", "//tensorflow/lite/delegates/gpu/common:access_type", @@ -76,7 +77,10 @@ cc_test( ], deps = [ ":arguments", + ":buffer", + ":device_info", ":gpu_object", + ":tensor", ":tensor_type", "//tensorflow/lite/delegates/gpu/common:data_type", "@com_google_absl//absl/strings", @@ -283,7 +287,7 @@ cc_library( ":cl_command_queue", ":cl_context", ":cl_device", - ":cl_kernel", + ":device_info", ":precision", ":program_cache", ":tensor", @@ -343,7 +347,7 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:model_builder", "//tensorflow/lite/delegates/gpu/common:model_transformer", "//tensorflow/lite/delegates/gpu/common:status", - "//tensorflow/lite/delegates/gpu/common/transformations:general_transformations", + "//tensorflow/lite/delegates/gpu/common/transformations:model_transformations", "@com_google_absl//absl/types:span", ], ) @@ -355,6 +359,7 @@ cc_library( deps = [ ":cl_context", ":opencl_wrapper", + ":serialization_cc_fbs", "//tensorflow/lite/delegates/gpu/common:access_type", "//tensorflow/lite/delegates/gpu/common:data_type", "//tensorflow/lite/delegates/gpu/common:status", @@ -363,18 +368,30 @@ cc_library( cc_library( name = "inference_context", - srcs = ["inference_context.cc"], - hdrs = ["inference_context.h"], + srcs = [ + "inference_context.cc", + "serialization.cc", + ], + hdrs = [ + "inference_context.h", + "serialization.h", + ], deps = [ + ":arguments", ":buffer", ":cl_command_queue", + ":cl_context", ":cl_device", ":environment", + ":gpu_object", + ":linear_storage", ":model_hints", ":opencl_wrapper", ":precision", + ":serialization_cc_fbs", ":storage_type_util", ":tensor_type", + ":texture2d", "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation", "//tensorflow/lite/delegates/gpu/cl/selectors:operation_selector", "//tensorflow/lite/delegates/gpu/cl/selectors:special_selector", @@ -392,6 +409,7 @@ cc_library( "//tensorflow/lite/delegates/gpu/common/transformations:merge_padding_with", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + 
"@com_google_absl//absl/types:span", ], ) @@ -463,6 +481,14 @@ cc_library( ], ) +flatbuffer_cc_library( + name = "serialization_cc_fbs", + srcs = ["serialization.fbs"], + flatc_args = [ + "--scoped-enums", + ], +) + cc_library( name = "storage_type_util", srcs = ["storage_type_util.cc"], diff --git a/tensorflow/lite/delegates/gpu/cl/api.cc b/tensorflow/lite/delegates/gpu/cl/api.cc index 01d32aa9206..e2135d05b53 100644 --- a/tensorflow/lite/delegates/gpu/cl/api.cc +++ b/tensorflow/lite/delegates/gpu/cl/api.cc @@ -570,6 +570,56 @@ TensorObjectDef TensorToDef(const Tensor& tensor) { return def; } +CalculationsPrecision GetPrecision(const Environment& env, + const InferenceOptions& options) { + CalculationsPrecision precision; + switch (GetPosition(options, InferencePriority::MAX_PRECISION)) { + case 1: + precision = CalculationsPrecision::F32; + break; + case 2: + precision = CalculationsPrecision::F32_F16; + break; + case 3: + precision = CalculationsPrecision::F16; + break; + default: + precision = CalculationsPrecision::F16; + break; + } + // Increase precision if lower precision is not supported. + if (!env.IsSupported(precision)) { + precision = CalculationsPrecision::F32_F16; + if (!env.IsSupported(precision)) { + precision = CalculationsPrecision::F32; + } + } + return precision; +} + +TensorStorageType GetStorageTypeFromOptions(const Environment& env, + const InferenceOptions& options) { + // Fallback to BUFFER that should be supported by default. + std::vector preferred_storage_types; + if (GetRelativeImportance(options, InferencePriority::MIN_LATENCY, + InferencePriority::MIN_MEMORY_USAGE) == + PriorityImportance::HIGHER) { + preferred_storage_types = {GetFastestStorageType(env.device().GetInfo()), + TensorStorageType::BUFFER}; + } else { + preferred_storage_types = { + GetStorageTypeWithMinimalMemoryConsumption(env.device().GetInfo()), + TensorStorageType::BUFFER}; + } + + for (TensorStorageType storage_type : preferred_storage_types) { + if (env.IsSupported(storage_type)) { + return storage_type; + } + } + return TensorStorageType::UNKNOWN; +} + class InferenceBuilderImpl : public InferenceBuilder { public: explicit InferenceBuilderImpl(Environment* environment) @@ -580,11 +630,14 @@ class InferenceBuilderImpl : public InferenceBuilder { const GraphFloat32& graph) { context_ = absl::make_unique(); InferenceContext::CreateInferenceInfo create_info; - create_info.precision = GetPrecision(options); - create_info.storage_type = GetStorageType(options); + create_info.precision = GetPrecision(*environment_, options); + create_info.storage_type = + GetStorageTypeFromOptions(*environment_, options); if (options.usage == InferenceUsage::FAST_SINGLE_ANSWER) { create_info.hints.Add(ModelHints::kReduceKernelsCount); create_info.hints.Add(ModelHints::kFastTuning); + } else if (options.usage == InferenceUsage::SUSTAINED_SPEED) { + create_info.hints.Add(ModelHints::kAllowSpecialKernels); } RETURN_IF_ERROR(context_->InitFromGraph(create_info, graph, environment_)); @@ -601,8 +654,32 @@ class InferenceBuilderImpl : public InferenceBuilder { absl::make_unique(environment_, context_.get()); #endif - inputs_ = LinkTensors(graph, graph.inputs()); - outputs_ = LinkTensors(graph, graph.outputs()); + inputs_ = LinkTensors(context_->GetInputIds(), AccessType::READ); + outputs_ = LinkTensors(context_->GetOutputIds(), AccessType::WRITE); + return absl::OkStatus(); + } + + absl::Status Initialize(const InferenceEnvironmentOptions& env_options, + const std::vector& serialized_model) { + context_ = 
absl::make_unique(); + RETURN_IF_ERROR( + context_->RestoreDeserialized(serialized_model, environment_)); + +#ifdef CL_DELEGATE_ALLOW_GL + if (env_options.IsGlAware() && + IsGlSharingSupported(environment_->device())) { + gl_interop_fabric_ = absl::make_unique( + env_options.egl_display, environment_); + } + tie_factory_ = absl::make_unique( + environment_, context_.get(), gl_interop_fabric_.get()); +#else + tie_factory_ = + absl::make_unique(environment_, context_.get()); +#endif + + inputs_ = LinkTensors(context_->GetInputIds(), AccessType::READ); + outputs_ = LinkTensors(context_->GetOutputIds(), AccessType::WRITE); return absl::OkStatus(); } @@ -669,64 +746,14 @@ class InferenceBuilderImpl : public InferenceBuilder { } private: - TensorStorageType GetStorageType(const InferenceOptions& options) const { - // Fallback to BUFFER that should be supported by default. - std::vector preferred_storage_types; - if (GetRelativeImportance(options, InferencePriority::MIN_LATENCY, - InferencePriority::MIN_MEMORY_USAGE) == - PriorityImportance::HIGHER) { - preferred_storage_types = {GetFastestStorageType(environment_->device()), - TensorStorageType::BUFFER}; - } else { - preferred_storage_types = { - GetStorageTypeWithMinimalMemoryConsumption(environment_->device()), - TensorStorageType::BUFFER}; - } - - for (TensorStorageType storage_type : preferred_storage_types) { - if (environment_->IsSupported(storage_type)) { - return storage_type; - } - } - return TensorStorageType::UNKNOWN; - } - - CalculationsPrecision GetPrecision(const InferenceOptions& options) const { - CalculationsPrecision precision; - switch (GetPosition(options, InferencePriority::MAX_PRECISION)) { - case 1: - precision = CalculationsPrecision::F32; - break; - case 2: - precision = CalculationsPrecision::F32_F16; - break; - case 3: - precision = CalculationsPrecision::F16; - break; - default: - precision = CalculationsPrecision::F16; - break; - } - // Increase precision if lower precision is not supported. - if (!environment_->IsSupported(precision)) { - precision = CalculationsPrecision::F32_F16; - if (!environment_->IsSupported(precision)) { - precision = CalculationsPrecision::F32; - } - } - return precision; - } - // Links internal tensors with external user-facing objects. - std::vector LinkTensors(const GraphFloat32& graph, - const std::vector& values) { + std::vector LinkTensors(const std::vector& ids, + AccessType access) { std::vector links; - links.reserve(values.size()); - for (const auto& value : values) { - TensorObjectDef def = TensorToDef(*context_->GetTensor(value->id)); - AccessType access = - graph.IsGraphInput(value->id) ? AccessType::READ : AccessType::WRITE; - links.push_back({value->id, access, def, def}); + links.reserve(ids.size()); + for (const auto& id : ids) { + TensorObjectDef def = TensorToDef(*context_->GetTensor(id)); + links.push_back({id, access, def, def}); } return links; } @@ -839,6 +866,39 @@ class InferenceEnvironmentImpl : public InferenceEnvironment { return environment_.Init(); } + absl::Status BuildSerializedModel( + const InferenceOptions& options, GraphFloat32 model, + std::vector* serialized_model) final { + if (!IsValid(options)) { + return absl::InvalidArgumentError("InferenceOptions are invalid."); + } + InferenceOptions resolved_options = options; + ResolveAutoPriority(&resolved_options); + if (environment_.program_cache() && + !options_.serialized_binary_cache.empty()) { + // Ignore returned error. Cache is discarded. 
+ environment_.program_cache() + ->AddSerializedCache(environment_.context(), environment_.device(), + options_.serialized_binary_cache) + .IgnoreError(); + } + + RETURN_IF_ERROR(RunGraphTransforms(&model)); + InferenceContext context; + InferenceContext::CreateInferenceInfo create_info; + create_info.precision = GetPrecision(environment_, options); + create_info.storage_type = GetStorageTypeFromOptions(environment_, options); + if (options.usage == InferenceUsage::FAST_SINGLE_ANSWER) { + create_info.hints.Add(ModelHints::kReduceKernelsCount); + create_info.hints.Add(ModelHints::kFastTuning); + } else if (options.usage == InferenceUsage::SUSTAINED_SPEED) { + create_info.hints.Add(ModelHints::kAllowSpecialKernels); + } + RETURN_IF_ERROR(context.InitFromGraph(create_info, model, &environment_, + serialized_model)); + return absl::OkStatus(); + } + absl::Status NewInferenceBuilder( const InferenceOptions& options, GraphFloat32 model, std::unique_ptr* builder) final { @@ -864,6 +924,24 @@ class InferenceEnvironmentImpl : public InferenceEnvironment { return absl::OkStatus(); } + absl::Status NewInferenceBuilder( + const std::vector& serialized_model, + std::unique_ptr* builder) final { + if (environment_.program_cache() && + !options_.serialized_binary_cache.empty()) { + // Ignore returned error. Cache is discarded. + environment_.program_cache() + ->AddSerializedCache(environment_.context(), environment_.device(), + options_.serialized_binary_cache) + .IgnoreError(); + } + + auto builder_impl = absl::make_unique(&environment_); + RETURN_IF_ERROR(builder_impl->Initialize(options_, serialized_model)); + *builder = std::move(builder_impl); + return absl::OkStatus(); + } + std::vector GetSerializedBinaryCache() const final { std::vector data; // Is there was a problem, data would be empty. diff --git a/tensorflow/lite/delegates/gpu/cl/api.h b/tensorflow/lite/delegates/gpu/cl/api.h index 826d4f2bc78..65671117522 100644 --- a/tensorflow/lite/delegates/gpu/cl/api.h +++ b/tensorflow/lite/delegates/gpu/cl/api.h @@ -75,6 +75,20 @@ class InferenceEnvironment { public: virtual ~InferenceEnvironment() {} + // Converts GraphFloat32 into intermediate, device-specific representation. + // This serialized_model specific for device and InferenceOptions. + // serialized_model cannot be used with another device or InferenceOptions. + // Loading serialized_model is much faster than loading GraphFloat32. + // serialized_model must be used with appropriate NewInferenceBuilder + // method (see below). 
+ virtual absl::Status BuildSerializedModel( + const InferenceOptions& options, GraphFloat32 model, + std::vector* serialized_model) = 0; + + virtual absl::Status NewInferenceBuilder( + const std::vector& serialized_model, + std::unique_ptr* builder) = 0; + virtual absl::Status NewInferenceBuilder( const InferenceOptions& options, GraphFloat32 model, std::unique_ptr* builder) = 0; diff --git a/tensorflow/lite/delegates/gpu/cl/arguments.cc b/tensorflow/lite/delegates/gpu/cl/arguments.cc index 5623de2419c..7c5e635816e 100644 --- a/tensorflow/lite/delegates/gpu/cl/arguments.cc +++ b/tensorflow/lite/delegates/gpu/cl/arguments.cc @@ -256,13 +256,6 @@ void Arguments::AddObjectRef(const std::string& name, AccessType access_type, object_refs_[name] = {std::move(descriptor_ptr)}; } -void Arguments::AddObject(const std::string& name, AccessType access_type, - GPUObjectPtr&& object, - GPUObjectDescriptorPtr&& descriptor_ptr) { - descriptor_ptr->SetAccess(access_type); - objects_[name] = {std::move(object), std::move(descriptor_ptr)}; -} - void Arguments::AddObject(const std::string& name, GPUObjectDescriptorPtr&& descriptor_ptr) { descriptor_ptr->SetAccess(AccessType::READ); @@ -666,56 +659,64 @@ absl::Status Arguments::Bind(cl_kernel kernel, int offset) { std::string Arguments::AddActiveArgument(const std::string& arg_name, bool use_f32_for_halfs) { - if (auto it = int_values_.find(arg_name); it != int_values_.end()) { - int int_index; - if (it->second.active) { - int_index = it->second.offset; - } else { - it->second.active = true; - it->second.offset = shared_int4s_data_.size(); - int_index = it->second.offset; - shared_int4s_data_.push_back(it->second.value); - } - std::string index = std::to_string(int_index / 4); - std::string postfixes[4] = {"x", "y", "z", "w"}; - return "shared_int4_" + index + "." + postfixes[int_index % 4]; - } - if (auto it = float_values_.find(arg_name); it != float_values_.end()) { - int float_index; - if (it->second.active) { - float_index = it->second.offset; - } else { - it->second.active = true; - it->second.offset = shared_float4s_data_.size(); - float_index = it->second.offset; - shared_float4s_data_.push_back(it->second.value); - } - std::string index = std::to_string(float_index / 4); - std::string postfixes[4] = {"x", "y", "z", "w"}; - return "shared_float4_" + index + "." + postfixes[float_index % 4]; - } - if (auto it = half_values_.find(arg_name); it != half_values_.end()) { - int half_index; - if (it->second.active) { - half_index = it->second.offset; - } else { - it->second.active = true; - if (use_f32_for_halfs) { - it->second.store_as_f32 = true; - it->second.offset = shared_float4s_data_.size(); - shared_float4s_data_.push_back(it->second.value); + { + auto it = int_values_.find(arg_name); + if (it != int_values_.end()) { + int int_index; + if (it->second.active) { + int_index = it->second.offset; } else { - it->second.offset = shared_half4s_data_.size(); - shared_half4s_data_.push_back(it->second.value); + it->second.active = true; + it->second.offset = shared_int4s_data_.size(); + int_index = it->second.offset; + shared_int4s_data_.push_back(it->second.value); } - half_index = it->second.offset; + std::string index = std::to_string(int_index / 4); + std::string postfixes[4] = {"x", "y", "z", "w"}; + return "shared_int4_" + index + "." + postfixes[int_index % 4]; } - std::string index = std::to_string(half_index / 4); - std::string postfixes[4] = {"x", "y", "z", "w"}; - if (it->second.store_as_f32) { - return "(half)(shared_float4_" + index + "." 
+ postfixes[half_index % 4] + - ")"; - } else { + } + { + auto it = float_values_.find(arg_name); + if (it != float_values_.end()) { + int float_index; + if (it->second.active) { + float_index = it->second.offset; + } else { + it->second.active = true; + it->second.offset = shared_float4s_data_.size(); + float_index = it->second.offset; + shared_float4s_data_.push_back(it->second.value); + } + std::string index = std::to_string(float_index / 4); + std::string postfixes[4] = {"x", "y", "z", "w"}; + return "shared_float4_" + index + "." + postfixes[float_index % 4]; + } + } + { + auto it = half_values_.find(arg_name); + if (it != half_values_.end()) { + int half_index; + if (it->second.active) { + half_index = it->second.offset; + } else { + it->second.active = true; + if (use_f32_for_halfs) { + it->second.store_as_f32 = true; + it->second.offset = shared_float4s_data_.size(); + shared_float4s_data_.push_back(it->second.value); + } else { + it->second.offset = shared_half4s_data_.size(); + shared_half4s_data_.push_back(it->second.value); + } + half_index = it->second.offset; + } + std::string index = std::to_string(half_index / 4); + std::string postfixes[4] = {"x", "y", "z", "w"}; + if (it->second.store_as_f32) { + return "(half)(shared_float4_" + index + "." + + postfixes[half_index % 4] + ")"; + } return "shared_half4_" + index + "." + postfixes[half_index % 4]; } } @@ -755,24 +756,38 @@ void Arguments::ResolveObjectNames(const std::string& object_name, } } +GPUObjectDescriptor* Arguments::GetObjectDescriptor( + const std::string& object_name) const { + { + auto it = object_refs_.find(object_name); + if (it != object_refs_.end()) { + return it->second.descriptor.get(); + } + } + { + auto it = objects_.find(object_name); + if (it != objects_.end()) { + return it->second.descriptor.get(); + } + } + return nullptr; +} + absl::Status Arguments::ResolveSelector( const std::map& linkables, const std::string& object_name, const std::string& selector, const std::vector& args, const std::vector& template_args, std::string* result) { - const GPUObjectDescriptor* desc_ptr; - if (auto it = object_refs_.find(object_name); it != object_refs_.end()) { - desc_ptr = it->second.descriptor.get(); - } else if (auto it = objects_.find(object_name); it != objects_.end()) { - desc_ptr = it->second.descriptor.get(); - } else { + const GPUObjectDescriptor* desc_ptr = GetObjectDescriptor(object_name); + if (!desc_ptr) { return absl::NotFoundError( absl::StrCat("No object with name - ", object_name)); } auto names = desc_ptr->GetGPUResources().GetNames(); const auto* tensor_desc = dynamic_cast(desc_ptr); if (tensor_desc && selector == "Write") { - if (auto it = linkables.find(object_name); it != linkables.end()) { + auto it = linkables.find(object_name); + if (it != linkables.end()) { if (desc_ptr->GetAccess() != AccessType::WRITE && desc_ptr->GetAccess() != AccessType::READ_WRITE) { return absl::FailedPreconditionError(absl::StrCat( @@ -850,11 +865,16 @@ absl::Status Arguments::AllocateObjects(CLContext* context) { for (auto& t : objects_) { RETURN_IF_ERROR( t.second.descriptor->CreateGPUObject(context, &t.second.obj_ptr)); - t.second.descriptor->Release(); } return absl::OkStatus(); } +void Arguments::ReleaseCPURepresentation() { + for (auto& t : objects_) { + t.second.descriptor->Release(); + } +} + absl::Status Arguments::AddObjectArgs() { for (auto& t : objects_) { AddGPUResources(t.first, t.second.descriptor->GetGPUResources()); diff --git a/tensorflow/lite/delegates/gpu/cl/arguments.h 
b/tensorflow/lite/delegates/gpu/cl/arguments.h index 643e1b7655d..a5435c4fc2f 100644 --- a/tensorflow/lite/delegates/gpu/cl/arguments.h +++ b/tensorflow/lite/delegates/gpu/cl/arguments.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/cl/cl_device.h" #include "tensorflow/lite/delegates/gpu/cl/gpu_object.h" #include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h" +#include "tensorflow/lite/delegates/gpu/cl/serialization_generated.h" #include "tensorflow/lite/delegates/gpu/cl/util.h" #include "tensorflow/lite/delegates/gpu/common/access_type.h" #include "tensorflow/lite/delegates/gpu/common/status.h" @@ -33,49 +34,37 @@ namespace tflite { namespace gpu { namespace cl { -class Arguments { +class ArgumentsBinder { + public: + virtual absl::Status SetInt(const std::string& name, int value) = 0; + virtual absl::Status SetFloat(const std::string& name, float value) = 0; + virtual absl::Status SetHalf(const std::string& name, half value) = 0; + virtual ~ArgumentsBinder() = default; +}; + +class Arguments : public ArgumentsBinder { public: Arguments() = default; void AddFloat(const std::string& name, float value = 0.0f); void AddHalf(const std::string& name, half value = half(0.0f)); void AddInt(const std::string& name, int value = 0); - void AddBuffer(const std::string& name, const GPUBufferDescriptor& desc); - void AddImage2D(const std::string& name, const GPUImage2DDescriptor& desc); - void AddImage2DArray(const std::string& name, - const GPUImage2DArrayDescriptor& desc); - void AddImage3D(const std::string& name, const GPUImage3DDescriptor& desc); - void AddImageBuffer(const std::string& name, - const GPUImageBufferDescriptor& desc); - void AddCustomMemory(const std::string& name, - const GPUCustomMemoryDescriptor& desc); - void AddObjectRef(const std::string& name, AccessType access_type, GPUObjectDescriptorPtr&& descriptor_ptr); - void AddObject(const std::string& name, AccessType access_type, - GPUObjectPtr&& object, - GPUObjectDescriptorPtr&& descriptor_ptr); void AddObject(const std::string& name, GPUObjectDescriptorPtr&& descriptor_ptr); - absl::Status SetInt(const std::string& name, int value); - absl::Status SetFloat(const std::string& name, float value); - absl::Status SetHalf(const std::string& name, half value); - absl::Status SetImage2D(const std::string& name, cl_mem memory); - absl::Status SetBuffer(const std::string& name, cl_mem memory); - absl::Status SetImage2DArray(const std::string& name, cl_mem memory); - absl::Status SetImage3D(const std::string& name, cl_mem memory); - absl::Status SetImageBuffer(const std::string& name, cl_mem memory); - absl::Status SetCustomMemory(const std::string& name, cl_mem memory); + absl::Status SetInt(const std::string& name, int value) override; + absl::Status SetFloat(const std::string& name, float value) override; + absl::Status SetHalf(const std::string& name, half value) override; absl::Status SetObjectRef(const std::string& name, const GPUObject* object); - std::string GetListOfArgs(); - absl::Status Bind(cl_kernel kernel, int offset = 0); void RenameArgs(const std::string& postfix, std::string* code) const; absl::Status Merge(Arguments&& args, const std::string& postfix); absl::Status AllocateObjects(CLContext* context); + void ReleaseCPURepresentation(); absl::Status TransformToCLCode( const DeviceInfo& device_info, const std::map& linkables, std::string* code); @@ -86,7 +75,33 @@ class Arguments { Arguments(const Arguments&) = delete; Arguments& operator=(const Arguments&) = delete; + 
~Arguments() override = default; + private: + friend flatbuffers::Offset Encode( + const Arguments& args, flatbuffers::FlatBufferBuilder* builder); + friend absl::Status Decode(CLContext* context, const data::Arguments* fb_args, + Arguments* args); + + void AddBuffer(const std::string& name, const GPUBufferDescriptor& desc); + void AddImage2D(const std::string& name, const GPUImage2DDescriptor& desc); + void AddImage2DArray(const std::string& name, + const GPUImage2DArrayDescriptor& desc); + void AddImage3D(const std::string& name, const GPUImage3DDescriptor& desc); + void AddImageBuffer(const std::string& name, + const GPUImageBufferDescriptor& desc); + void AddCustomMemory(const std::string& name, + const GPUCustomMemoryDescriptor& desc); + + absl::Status SetImage2D(const std::string& name, cl_mem memory); + absl::Status SetBuffer(const std::string& name, cl_mem memory); + absl::Status SetImage2DArray(const std::string& name, cl_mem memory); + absl::Status SetImage3D(const std::string& name, cl_mem memory); + absl::Status SetImageBuffer(const std::string& name, cl_mem memory); + absl::Status SetCustomMemory(const std::string& name, cl_mem memory); + + std::string GetListOfArgs(); + std::string AddActiveArgument(const std::string& arg_name, bool use_f32_for_halfs); void AddGPUResources(const std::string& name, const GPUResources& resources); @@ -110,6 +125,9 @@ class Arguments { const std::vector& member_names, std::string* code); + GPUObjectDescriptor* GetObjectDescriptor( + const std::string& object_name) const; + static constexpr char kArgsPrefix[] = "args."; struct IntValue { diff --git a/tensorflow/lite/delegates/gpu/cl/arguments_test.cc b/tensorflow/lite/delegates/gpu/cl/arguments_test.cc index 29a15e16a57..722ca5b1827 100644 --- a/tensorflow/lite/delegates/gpu/cl/arguments_test.cc +++ b/tensorflow/lite/delegates/gpu/cl/arguments_test.cc @@ -14,85 +14,58 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/lite/delegates/gpu/cl/arguments.h" +#include #include #include #include +#include "absl/strings/match.h" +#include "tensorflow/lite/delegates/gpu/cl/buffer.h" +#include "tensorflow/lite/delegates/gpu/cl/device_info.h" #include "tensorflow/lite/delegates/gpu/cl/gpu_object.h" -#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h" namespace tflite { namespace gpu { namespace cl { -namespace { -struct TestDescriptor : public GPUObjectDescriptor { - absl::Status PerformSelector(const std::string& selector, - const std::vector& args, - const std::vector& template_args, - std::string* result) const override { - if (selector == "Length") { - *result = "length"; - return absl::OkStatus(); - } else if (selector == "Read") { - if (args.size() != 1) { - return absl::NotFoundError( - absl::StrCat("TestDescriptor Read require one argument, but ", - args.size(), " was passed")); - } - *result = absl::StrCat("buffer[", args[0], "]"); - return absl::OkStatus(); - } else { - return absl::NotFoundError(absl::StrCat( - "TestDescriptor don't have selector with name - ", selector)); - } - } - - GPUResources GetGPUResources(AccessType access_type) const override { - GPUResources resources; - resources.ints.push_back("length"); - GPUBufferDescriptor desc; - desc.data_type = DataType::FLOAT32; - desc.element_size = 4; - resources.buffers.push_back({"buffer", desc}); - return resources; - } -}; -} // namespace - TEST(ArgumentsTest, TestSelectorResolve) { - TestDescriptor descriptor; - Arguments args; - args.AddObjectRef("object", AccessType::WRITE, - absl::make_unique(descriptor)); - std::string sample_code = R"( - if (a < 3) { - value = args.object.Read(id); - } -)"; - const std::string expected_result = R"( - if (a < 3) { - value = object_buffer[id]; - } -)"; - ASSERT_OK(args.TransformToCLCode({}, &sample_code)); - EXPECT_EQ(sample_code, expected_result); + BufferDescriptor desc; + desc.element_type = DataType::FLOAT32; + desc.element_size = 4; + desc.memory_type = MemoryType::GLOBAL; - std::string cl_arguments = args.GetListOfArgs(); - EXPECT_TRUE(cl_arguments.find("__global float4* object_buffer") != - std::string::npos); + Arguments args; + args.AddObjectRef("weights", AccessType::READ, + absl::make_unique(std::move(desc))); + std::string sample_code = R"( +__kernel void main_function($0) { + if (a < 3) { + value = args.weights.Read(id); + } +})"; + + DeviceInfo device_info; + ASSERT_OK(args.TransformToCLCode(device_info, {}, &sample_code)); + EXPECT_TRUE(absl::StrContains(sample_code, "value = weights_buffer[id];")); + EXPECT_TRUE( + absl::StrContains(sample_code, "__global float4* weights_buffer")); } TEST(ArgumentsTest, TestNoSelector) { - TestDescriptor descriptor; + BufferDescriptor desc; + desc.element_type = DataType::FLOAT32; + desc.element_size = 4; + desc.memory_type = MemoryType::GLOBAL; + Arguments args; - args.AddObjectRef("object", AccessType::WRITE, - absl::make_unique(descriptor)); + args.AddObjectRef("weights", AccessType::READ, + absl::make_unique(std::move(desc))); std::string sample_code = R"( if (a < 3) { - value = args.object.Write(id); + value = args.weights.UnknownSelector(id); } )"; - EXPECT_FALSE(args.TransformToCLCode({}, &sample_code).ok()); + DeviceInfo device_info; + EXPECT_FALSE(args.TransformToCLCode(device_info, {}, &sample_code).ok()); } TEST(ArgumentsTest, TestRenameArgs) { diff --git a/tensorflow/lite/delegates/gpu/cl/cl_command_queue.cc 
b/tensorflow/lite/delegates/gpu/cl/cl_command_queue.cc index a1795b18b27..10937cfc56b 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_command_queue.cc +++ b/tensorflow/lite/delegates/gpu/cl/cl_command_queue.cc @@ -56,14 +56,15 @@ void CLCommandQueue::Release() { } } -absl::Status CLCommandQueue::DispatchImplicit(const CLKernel& kernel, int3 grid, - int3 work_group_size, - CLEvent* event) { +absl::Status CLCommandQueue::Dispatch(const CLKernel& kernel, + const int3& work_groups_count, + const int3& work_group_size, + CLEvent* event) { std::vector local(3); std::vector global(3); for (int i = 0; i < 3; ++i) { local[i] = work_group_size[i]; - global[i] = AlignByN(grid[i], work_group_size[i]); + global[i] = work_groups_count[i] * work_group_size[i]; } cl_event resulting_event; const int error_code = clEnqueueNDRangeKernel( @@ -80,9 +81,10 @@ absl::Status CLCommandQueue::DispatchImplicit(const CLKernel& kernel, int3 grid, return absl::OkStatus(); } -absl::Status CLCommandQueue::DispatchImplicit(const CLKernel& kernel, int3 grid, - int3 work_group_size) { - return DispatchImplicit(kernel, grid, work_group_size, nullptr); +absl::Status CLCommandQueue::Dispatch(const CLKernel& kernel, + const int3& work_groups_count, + const int3& work_group_size) { + return Dispatch(kernel, work_groups_count, work_group_size, nullptr); } absl::Status CLCommandQueue::EnqueueEvent(CLEvent* event) { @@ -191,12 +193,13 @@ void ProfilingCommandQueue::SetEventsLabel(const std::string& name) { void ProfilingCommandQueue::ResetMeasurements() { events_.clear(); } -absl::Status ProfilingCommandQueue::DispatchImplicit(const CLKernel& kernel, - int3 grid, - int3 work_group_size) { +absl::Status ProfilingCommandQueue::Dispatch(const CLKernel& kernel, + const int3& work_groups_count, + const int3& work_group_size) { events_.push_back(CLEvent()); - RETURN_IF_ERROR(CLCommandQueue::DispatchImplicit( - kernel, grid, work_group_size, &events_[events_.size() - 1])); + RETURN_IF_ERROR(CLCommandQueue::Dispatch(kernel, work_groups_count, + work_group_size, + &events_[events_.size() - 1])); events_.back().SetName(current_label_); return absl::OkStatus(); } @@ -213,14 +216,15 @@ ProfilingInfo ProfilingCommandQueue::GetProfilingInfo() const { } absl::Status ProfilingCommandQueue::GetBestWorkGroupIndex( - const CLKernel& kernel, const DeviceInfo& device_info, const int3& grid, + const CLKernel& kernel, const DeviceInfo& device_info, + const std::vector& work_groups_count, const std::vector& work_group_sizes, int* index) { // Some Adreno 3xx can have wrong numbers for some events const bool possible_bug_with_events = device_info.IsAdreno3xx(); events_.resize(work_group_sizes.size()); for (int i = 0; i < work_group_sizes.size(); ++i) { - RETURN_IF_ERROR(CLCommandQueue::DispatchImplicit( - kernel, grid, work_group_sizes[i], &events_[i])); + RETURN_IF_ERROR(CLCommandQueue::Dispatch(kernel, work_groups_count[i], + work_group_sizes[i], &events_[i])); // reducing the speed of memory leak on Mali for some kernels if (device_info.IsMali() && i % 8 == 7) { @@ -330,24 +334,34 @@ absl::Duration ProfilingInfo::GetTotalTime() const { std::string ProfilingInfo::GetDetailedReport() const { std::string result; - std::map timing; + struct OpStatistic { + int count; + double total_time; + }; + std::map statistics; result += "Per kernel timing(" + std::to_string(dispatches.size()) + " kernels):\n"; for (const auto& dispatch : dispatches) { result += " " + dispatch.label + " - " + std::to_string(absl::ToDoubleMilliseconds(dispatch.duration)) + - "ms\n"; + " 
ms\n"; auto name = dispatch.label.substr(0, dispatch.label.find(" ")); - if (timing.find(name) != timing.end()) { - timing[name] += absl::ToDoubleMilliseconds(dispatch.duration); + if (statistics.find(name) != statistics.end()) { + statistics[name].count++; + statistics[name].total_time += + absl::ToDoubleMilliseconds(dispatch.duration); } else { - timing[name] = absl::ToDoubleMilliseconds(dispatch.duration); + statistics[name].count = 1; + statistics[name].total_time = + absl::ToDoubleMilliseconds(dispatch.duration); } } result += "--------------------\n"; result += "Accumulated time per operation type:\n"; - for (auto& t : timing) { - result += " " + t.first + " - " + std::to_string(t.second) + "ms\n"; + for (auto& t : statistics) { + auto stat = t.second; + result += " " + t.first + "(x" + std::to_string(stat.count) + ") - " + + std::to_string(stat.total_time) + " ms\n"; } result += "--------------------\n"; result += "Ideal total time: " + diff --git a/tensorflow/lite/delegates/gpu/cl/cl_command_queue.h b/tensorflow/lite/delegates/gpu/cl/cl_command_queue.h index 178e3b21a1e..519b87640e7 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_command_queue.h +++ b/tensorflow/lite/delegates/gpu/cl/cl_command_queue.h @@ -74,14 +74,15 @@ class CLCommandQueue { cl_command_queue queue() const { return queue_; } - virtual absl::Status DispatchImplicit(const CLKernel& kernel, int3 grid, - int3 work_group_size); + virtual absl::Status Dispatch(const CLKernel& kernel, + const int3& work_groups_count, + const int3& work_group_size); + + absl::Status Dispatch(const CLKernel& kernel, const int3& work_groups_count, + const int3& work_group_size, CLEvent* event); absl::Status EnqueueEvent(CLEvent* event); - absl::Status DispatchImplicit(const CLKernel& kernel, int3 grid, - int3 work_group_size, CLEvent* event); - absl::Status EnqueueWriteImage(cl_mem memory, int3 region, const void* data); absl::Status EnqueueReadImage(cl_mem memory, int3 region, void* data); @@ -110,13 +111,13 @@ class ProfilingCommandQueue : public CLCommandQueue { ProfilingCommandQueue(const ProfilingCommandQueue&) = delete; ProfilingCommandQueue& operator=(const ProfilingCommandQueue&) = delete; - absl::Status DispatchImplicit(const CLKernel& kernel, int3 grid, - int3 work_group_size) override; + absl::Status Dispatch(const CLKernel& kernel, const int3& work_groups_count, + const int3& work_group_size) override; // will write index for fastest work_group among work_group_sizes absl::Status GetBestWorkGroupIndex(const CLKernel& kernel, const DeviceInfo& device_info, - const int3& grid, + const std::vector& work_groups_count, const std::vector& work_group_sizes, int* index); diff --git a/tensorflow/lite/delegates/gpu/cl/cl_device.h b/tensorflow/lite/delegates/gpu/cl/cl_device.h index e7cd274661d..79335a61aff 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_device.h +++ b/tensorflow/lite/delegates/gpu/cl/cl_device.h @@ -73,6 +73,7 @@ class CLDevice { bool SupportsOneLayerTextureArray() const; void DisableOneLayerTextureArray(); + const DeviceInfo& GetInfo() const { return info_; } // We update device info during context creation, so as supported texture // formats can be requested from context only. 
mutable DeviceInfo info_; diff --git a/tensorflow/lite/delegates/gpu/cl/compiled_program_cache_generated.h b/tensorflow/lite/delegates/gpu/cl/compiled_program_cache_generated.h new file mode 100644 index 00000000000..8a12bf2a9db --- /dev/null +++ b/tensorflow/lite/delegates/gpu/cl/compiled_program_cache_generated.h @@ -0,0 +1,207 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// automatically generated by the FlatBuffers compiler, do not modify + + +#ifndef FLATBUFFERS_GENERATED_COMPILEDPROGRAMCACHE_TFLITE_GPU_CL_DATA_H_ +#define FLATBUFFERS_GENERATED_COMPILEDPROGRAMCACHE_TFLITE_GPU_CL_DATA_H_ + +#include "flatbuffers/flatbuffers.h" + +namespace tflite { +namespace gpu { +namespace cl { +namespace data { + +struct Program; + +struct CompiledCache; + +struct Program FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_FINGERPRINT = 4, + VT_BINARY = 6 + }; + uint64_t fingerprint() const { + return GetField(VT_FINGERPRINT, 0); + } + const flatbuffers::Vector *binary() const { + return GetPointer *>(VT_BINARY); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_FINGERPRINT) && + VerifyOffset(verifier, VT_BINARY) && + verifier.VerifyVector(binary()) && + verifier.EndTable(); + } +}; + +struct ProgramBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_fingerprint(uint64_t fingerprint) { + fbb_.AddElement(Program::VT_FINGERPRINT, fingerprint, 0); + } + void add_binary(flatbuffers::Offset> binary) { + fbb_.AddOffset(Program::VT_BINARY, binary); + } + explicit ProgramBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ProgramBuilder &operator=(const ProgramBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateProgram( + flatbuffers::FlatBufferBuilder &_fbb, + uint64_t fingerprint = 0, + flatbuffers::Offset> binary = 0) { + ProgramBuilder builder_(_fbb); + builder_.add_fingerprint(fingerprint); + builder_.add_binary(binary); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateProgramDirect( + flatbuffers::FlatBufferBuilder &_fbb, + uint64_t fingerprint = 0, + const std::vector *binary = nullptr) { + auto binary__ = binary ? 
_fbb.CreateVector(*binary) : 0; + return tflite::gpu::cl::data::CreateProgram( + _fbb, + fingerprint, + binary__); +} + +struct CompiledCache FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DRIVER_VERSION = 4, + VT_PROGRAMS = 6 + }; + const flatbuffers::String *driver_version() const { + return GetPointer(VT_DRIVER_VERSION); + } + const flatbuffers::Vector> *programs() const { + return GetPointer> *>(VT_PROGRAMS); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DRIVER_VERSION) && + verifier.VerifyString(driver_version()) && + VerifyOffset(verifier, VT_PROGRAMS) && + verifier.VerifyVector(programs()) && + verifier.VerifyVectorOfTables(programs()) && + verifier.EndTable(); + } +}; + +struct CompiledCacheBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_driver_version(flatbuffers::Offset driver_version) { + fbb_.AddOffset(CompiledCache::VT_DRIVER_VERSION, driver_version); + } + void add_programs(flatbuffers::Offset>> programs) { + fbb_.AddOffset(CompiledCache::VT_PROGRAMS, programs); + } + explicit CompiledCacheBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + CompiledCacheBuilder &operator=(const CompiledCacheBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateCompiledCache( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset driver_version = 0, + flatbuffers::Offset>> programs = 0) { + CompiledCacheBuilder builder_(_fbb); + builder_.add_programs(programs); + builder_.add_driver_version(driver_version); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateCompiledCacheDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *driver_version = nullptr, + const std::vector> *programs = nullptr) { + auto driver_version__ = driver_version ? _fbb.CreateString(driver_version) : 0; + auto programs__ = programs ? 
_fbb.CreateVector>(*programs) : 0; + return tflite::gpu::cl::data::CreateCompiledCache( + _fbb, + driver_version__, + programs__); +} + +inline const tflite::gpu::cl::data::CompiledCache *GetCompiledCache(const void *buf) { + return flatbuffers::GetRoot(buf); +} + +inline const tflite::gpu::cl::data::CompiledCache *GetSizePrefixedCompiledCache(const void *buf) { + return flatbuffers::GetSizePrefixedRoot(buf); +} + +inline const char *CompiledCacheIdentifier() { + return "AFCM"; +} + +inline bool CompiledCacheBufferHasIdentifier(const void *buf) { + return flatbuffers::BufferHasIdentifier( + buf, CompiledCacheIdentifier()); +} + +inline bool VerifyCompiledCacheBuffer( + flatbuffers::Verifier &verifier) { + return verifier.VerifyBuffer(CompiledCacheIdentifier()); +} + +inline bool VerifySizePrefixedCompiledCacheBuffer( + flatbuffers::Verifier &verifier) { + return verifier.VerifySizePrefixedBuffer(CompiledCacheIdentifier()); +} + +inline const char *CompiledCacheExtension() { + return "jetbin"; +} + +inline void FinishCompiledCacheBuffer( + flatbuffers::FlatBufferBuilder &fbb, + flatbuffers::Offset root) { + fbb.Finish(root, CompiledCacheIdentifier()); +} + +inline void FinishSizePrefixedCompiledCacheBuffer( + flatbuffers::FlatBufferBuilder &fbb, + flatbuffers::Offset root) { + fbb.FinishSizePrefixed(root, CompiledCacheIdentifier()); +} + +} // namespace data +} // namespace cl +} // namespace gpu +} // namespace tflite + +#endif // FLATBUFFERS_GENERATED_COMPILEDPROGRAMCACHE_TFLITE_GPU_CL_DATA_H_ diff --git a/tensorflow/lite/delegates/gpu/cl/device_info.cc b/tensorflow/lite/delegates/gpu/cl/device_info.cc index 5d035e34617..43d050e8371 100644 --- a/tensorflow/lite/delegates/gpu/cl/device_info.cc +++ b/tensorflow/lite/delegates/gpu/cl/device_info.cc @@ -40,7 +40,8 @@ MaliGPU GetMaliGPUVersion(const std::string& device_name) { {"T830", MaliGPU::T830}, {"T860", MaliGPU::T860}, {"T880", MaliGPU::T880}, {"G31", MaliGPU::G31}, {"G51", MaliGPU::G51}, {"G71", MaliGPU::G71}, {"G52", MaliGPU::G52}, {"G72", MaliGPU::G72}, {"G76", MaliGPU::G76}, - {"G57", MaliGPU::G57}, {"G77", MaliGPU::G77}, + {"G57", MaliGPU::G57}, {"G77", MaliGPU::G77}, {"G68", MaliGPU::G68}, + {"G78", MaliGPU::G78}, }; for (const auto& v : kMapping) { if (device_name.find(v.first) != std::string::npos) { @@ -212,7 +213,8 @@ bool MaliInfo::IsBifrost() const { } bool MaliInfo::IsValhall() const { - return gpu_version == MaliGPU::G57 || gpu_version == MaliGPU::G77; + return gpu_version == MaliGPU::G57 || gpu_version == MaliGPU::G77 || + gpu_version == MaliGPU::G68 || gpu_version == MaliGPU::G78; } bool DeviceInfo::SupportsTextureArray() const { diff --git a/tensorflow/lite/delegates/gpu/cl/device_info.h b/tensorflow/lite/delegates/gpu/cl/device_info.h index abb3feb07b1..f28f4719232 100644 --- a/tensorflow/lite/delegates/gpu/cl/device_info.h +++ b/tensorflow/lite/delegates/gpu/cl/device_info.h @@ -95,6 +95,8 @@ enum class MaliGPU { G76, G57, G77, + G68, + G78, UNKNOWN }; diff --git a/tensorflow/lite/delegates/gpu/cl/environment.cc b/tensorflow/lite/delegates/gpu/cl/environment.cc index 785e88299a7..5b06b307133 100644 --- a/tensorflow/lite/delegates/gpu/cl/environment.cc +++ b/tensorflow/lite/delegates/gpu/cl/environment.cc @@ -18,7 +18,6 @@ limitations under the License. 
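[Editor's note] A short sketch of how the generated compiled_program_cache API added above can be used to write and then verify/read a program-binary cache. The builder contents here are placeholders; in the real code the binaries would come from the compiled OpenCL programs:

namespace data = tflite::gpu::cl::data;

// Writing the cache.
flatbuffers::FlatBufferBuilder builder;
std::vector<uint8_t> binary = {0xDE, 0xAD};  // placeholder program binary bytes
std::vector<flatbuffers::Offset<data::Program>> programs = {
    data::CreateProgramDirect(builder, /*fingerprint=*/0x1234, &binary)};
auto cache = data::CreateCompiledCacheDirect(
    builder, /*driver_version=*/"placeholder-driver-version", &programs);
data::FinishCompiledCacheBuffer(builder, cache);  // tags the buffer with "AFCM"

// Reading it back.
flatbuffers::Verifier verifier(builder.GetBufferPointer(), builder.GetSize());
if (data::VerifyCompiledCacheBuffer(verifier)) {
  const data::CompiledCache* cache_fb =
      data::GetCompiledCache(builder.GetBufferPointer());
  // cache_fb->driver_version(), cache_fb->programs(), ...
}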
#include #include -#include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h" #include "tensorflow/lite/delegates/gpu/cl/util.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" @@ -26,59 +25,6 @@ namespace tflite { namespace gpu { namespace cl { namespace { - -std::string GetKernelOneLayerTextureArray() { - return R"( - -__kernel void main_function(__write_only image2d_array_t dst) { - int X = (int)(get_global_id(0)); - int Y = (int)(get_global_id(1)); - - write_imagef(dst, (int4)(X, Y, 0, 0), (float4)(2.0, 2.0, 2.0, 2.0)); -} -)"; -} - -// Some Adreno < 600 have bug with one layer texture array. b/131099086 -// If we have one layer texture array and will write smt from kernel to this -// texture, we will get zeroes instead of actual values. -// The same kernel will work, if we use texture array with more than one layer. -// With help of this code we can detect this bug. -absl::Status CheckKernelSupportOfOneLayerTextureArray(Environment* env, - bool* result) { - // No bug on Adreno 6xx - if (env->device().info_.adreno_info.gpu_version >= 600) { - *result = true; - return absl::OkStatus(); - } - CLKernel kernel; - RETURN_IF_ERROR(env->program_cache()->GetOrCreateCLKernel( - GetKernelOneLayerTextureArray(), "main_function", env->context(), - env->device(), &kernel)); - - Tensor tensor; - const BHWC shape(1, 4, 4, 4); - RETURN_IF_ERROR(CreateTensor( - env->context(), shape, - {DataType::FLOAT32, TensorStorageType::TEXTURE_ARRAY, Layout::HWC}, - &tensor)); - RETURN_IF_ERROR(kernel.SetMemory(0, tensor.GetMemoryPtr())); - RETURN_IF_ERROR(env->queue()->DispatchImplicit(kernel, {4, 4, 1}, {4, 4, 1})); - TensorFloat32 tensor_gpu; - tensor_gpu.shape = shape; - tensor_gpu.data.resize(shape.DimensionsProduct()); - RETURN_IF_ERROR(tensor.ReadData(env->queue(), &tensor_gpu)); - - *result = true; - for (int i = 0; i < 64; ++i) { - if (tensor_gpu.data[i] != 2.0) { - *result = false; - break; - } - } - return absl::OkStatus(); -} - absl::Status CreateEnvironment(Environment* result, bool shared, cl_context_properties egl_context, cl_context_properties egl_display) { @@ -99,16 +45,7 @@ absl::Status CreateEnvironment(Environment* result, bool shared, *result = Environment(std::move(gpu), std::move(context), std::move(queue), std::move(profiling_queue)); - if (result->device().IsAdreno() && result->device().SupportsTextureArray()) { - bool supports_one_layer; - RETURN_IF_ERROR( - CheckKernelSupportOfOneLayerTextureArray(result, &supports_one_layer)); - if (!supports_one_layer) { - result->GetDevicePtr()->DisableOneLayerTextureArray(); - } - } - - return absl::OkStatus(); + return result->Init(); } } // namespace @@ -141,10 +78,12 @@ Environment& Environment::operator=(Environment&& environment) { absl::Status Environment::Init() { if (device().IsAdreno() && device().SupportsTextureArray()) { - bool supports_one_layer; - RETURN_IF_ERROR( - CheckKernelSupportOfOneLayerTextureArray(this, &supports_one_layer)); - if (!supports_one_layer) { + // Some Adreno < 600 have bug with one layer texture array. b/131099086 + // If we have one layer texture array and will write smt from kernel to this + // texture, we will get zeroes instead of actual values. + // The same kernel will work, if we use texture array with more than one + // layer. 
+ if (device().info_.adreno_info.gpu_version < 600) { GetDevicePtr()->DisableOneLayerTextureArray(); } } @@ -232,54 +171,54 @@ bool Environment::IsSupported(TensorStorageType storage_type) const { return false; } -TensorStorageType GetFastestStorageType(const CLDevice& gpu) { - if (gpu.IsAdreno()) { - if (gpu.IsAdreno6xxOrHigher()) { +TensorStorageType GetFastestStorageType(const DeviceInfo& gpu_info) { + if (gpu_info.IsAdreno()) { + if (gpu_info.IsAdreno6xxOrHigher()) { return TensorStorageType::TEXTURE_ARRAY; } else { return TensorStorageType::TEXTURE_2D; } - } else if (gpu.IsPowerVR()) { + } else if (gpu_info.IsPowerVR()) { return TensorStorageType::TEXTURE_2D; - } else if (gpu.IsMali()) { - const MaliInfo mali_info = gpu.info_.mali_info; + } else if (gpu_info.IsMali()) { + const MaliInfo mali_info = gpu_info.mali_info; if (mali_info.IsMaliT8xx() || mali_info.IsBifrostGen3() || mali_info.IsValhall()) { return TensorStorageType::TEXTURE_2D; } else { return TensorStorageType::BUFFER; } - } else if (gpu.IsNvidia()) { - return gpu.SupportsImageBuffer() ? TensorStorageType::IMAGE_BUFFER - : TensorStorageType::BUFFER; - } else if (gpu.IsAMD()) { - return gpu.SupportsImageBuffer() ? TensorStorageType::IMAGE_BUFFER - : TensorStorageType::BUFFER; - } else if (gpu.IsIntel()) { + } else if (gpu_info.IsNvidia()) { + return gpu_info.SupportsImageBuffer() ? TensorStorageType::IMAGE_BUFFER + : TensorStorageType::BUFFER; + } else if (gpu_info.IsAMD()) { + return gpu_info.SupportsImageBuffer() ? TensorStorageType::IMAGE_BUFFER + : TensorStorageType::BUFFER; + } else if (gpu_info.IsIntel()) { return TensorStorageType::BUFFER; } return TensorStorageType::BUFFER; } TensorStorageType GetStorageTypeWithMinimalMemoryConsumption( - const CLDevice& gpu) { - if (gpu.IsAdreno()) { - if (gpu.IsAdreno3xx() || gpu.IsAdreno4xx()) { + const DeviceInfo& gpu_info) { + if (gpu_info.IsAdreno()) { + if (gpu_info.IsAdreno3xx() || gpu_info.IsAdreno4xx()) { return TensorStorageType::BUFFER; } else { return TensorStorageType::IMAGE_BUFFER; } - } else if (gpu.IsPowerVR()) { + } else if (gpu_info.IsPowerVR()) { return TensorStorageType::BUFFER; - } else if (gpu.IsMali()) { + } else if (gpu_info.IsMali()) { return TensorStorageType::BUFFER; - } else if (gpu.IsNvidia()) { - return gpu.SupportsImageBuffer() ? TensorStorageType::IMAGE_BUFFER - : TensorStorageType::BUFFER; - } else if (gpu.IsAMD()) { - return gpu.SupportsImageBuffer() ? TensorStorageType::IMAGE_BUFFER - : TensorStorageType::BUFFER; - } else if (gpu.IsIntel()) { + } else if (gpu_info.IsNvidia()) { + return gpu_info.SupportsImageBuffer() ? TensorStorageType::IMAGE_BUFFER + : TensorStorageType::BUFFER; + } else if (gpu_info.IsAMD()) { + return gpu_info.SupportsImageBuffer() ? TensorStorageType::IMAGE_BUFFER + : TensorStorageType::BUFFER; + } else if (gpu_info.IsIntel()) { return TensorStorageType::BUFFER; } return TensorStorageType::BUFFER; diff --git a/tensorflow/lite/delegates/gpu/cl/environment.h b/tensorflow/lite/delegates/gpu/cl/environment.h index 640f2d8cac3..1f5b4befdce 100644 --- a/tensorflow/lite/delegates/gpu/cl/environment.h +++ b/tensorflow/lite/delegates/gpu/cl/environment.h @@ -19,9 +19,9 @@ limitations under the License. 
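[Editor's note] Since GetFastestStorageType and GetStorageTypeWithMinimalMemoryConsumption now take a DeviceInfo rather than a CLDevice, call sites can go through the new CLDevice::GetInfo() accessor. Illustrative call site, assuming an initialized Environment named env:

const DeviceInfo& gpu_info = env.device().GetInfo();
// Fastest layout for this GPU family (texture-based where beneficial).
TensorStorageType storage_type = GetFastestStorageType(gpu_info);
// Or trade speed for a smaller memory footprint.
TensorStorageType low_mem_type =
    GetStorageTypeWithMinimalMemoryConsumption(gpu_info);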
#include "tensorflow/lite/delegates/gpu/cl/cl_command_queue.h" #include "tensorflow/lite/delegates/gpu/cl/cl_context.h" #include "tensorflow/lite/delegates/gpu/cl/cl_device.h" +#include "tensorflow/lite/delegates/gpu/cl/device_info.h" #include "tensorflow/lite/delegates/gpu/cl/precision.h" #include "tensorflow/lite/delegates/gpu/cl/program_cache.h" -#include "tensorflow/lite/delegates/gpu/cl/tensor.h" #include "tensorflow/lite/delegates/gpu/cl/tensor_type.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/status.h" @@ -75,9 +75,9 @@ class Environment { ProgramCache program_cache_; }; -TensorStorageType GetFastestStorageType(const CLDevice& gpu); +TensorStorageType GetFastestStorageType(const DeviceInfo& gpu_info); TensorStorageType GetStorageTypeWithMinimalMemoryConsumption( - const CLDevice& gpu); + const DeviceInfo& gpu_info); absl::Status CreateEnvironment(Environment* result); diff --git a/tensorflow/lite/delegates/gpu/cl/gpu_api_delegate.cc b/tensorflow/lite/delegates/gpu/cl/gpu_api_delegate.cc index fc8fcde439b..e0933ed56e1 100644 --- a/tensorflow/lite/delegates/gpu/cl/gpu_api_delegate.cc +++ b/tensorflow/lite/delegates/gpu/cl/gpu_api_delegate.cc @@ -27,7 +27,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/model_builder.h" #include "tensorflow/lite/delegates/gpu/common/model_transformer.h" #include "tensorflow/lite/delegates/gpu/common/status.h" -#include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h" +#include "tensorflow/lite/delegates/gpu/common/transformations/model_transformations.h" namespace tflite { namespace gpu { @@ -97,8 +97,8 @@ class Delegate { // Apply general transformations on the graph. NullTransformationReporter reporter; ModelTransformer transformer(&graph, &reporter); - if (!ApplyGeneralTransformations(&transformer)) { - return absl::InternalError("Graph general transformations failed"); + if (!ApplyModelTransformations(&transformer)) { + return absl::InternalError("Graph transformations failed"); } InferenceEnvironmentOptions env_options; diff --git a/tensorflow/lite/delegates/gpu/cl/gpu_object.h b/tensorflow/lite/delegates/gpu/cl/gpu_object.h index 297a5f70858..abd77a4489b 100644 --- a/tensorflow/lite/delegates/gpu/cl/gpu_object.h +++ b/tensorflow/lite/delegates/gpu/cl/gpu_object.h @@ -23,6 +23,7 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/cl/cl_context.h" #include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h" +#include "tensorflow/lite/delegates/gpu/cl/serialization_generated.h" #include "tensorflow/lite/delegates/gpu/common/access_type.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/status.h" @@ -164,6 +165,10 @@ class GPUObjectDescriptor { AccessType GetAccess() const { return access_type_; } protected: + friend flatbuffers::Offset Encode( + const GPUObjectDescriptor& desc, flatbuffers::FlatBufferBuilder* builder); + friend void Decode(const data::GPUObjectDescriptor* fb_obj, + GPUObjectDescriptor* obj); mutable std::map state_vars_; AccessType access_type_; }; diff --git a/tensorflow/lite/delegates/gpu/cl/inference_context.cc b/tensorflow/lite/delegates/gpu/cl/inference_context.cc index 9cb8ddee818..ca0c0319f54 100644 --- a/tensorflow/lite/delegates/gpu/cl/inference_context.cc +++ b/tensorflow/lite/delegates/gpu/cl/inference_context.cc @@ -153,7 +153,7 @@ CLNode& CLNode::operator=(CLNode&& node) { absl::Status InferenceContext::InitFromGraph( const CreateInferenceInfo& create_info, const GraphFloat32& graph, - Environment* env) { + Environment* env, std::vector* serialized_model) { CreationContext creation_context; creation_context.device = env->GetDevicePtr(); creation_context.context = &env->context(); @@ -188,15 +188,63 @@ absl::Status InferenceContext::InitFromGraph( if (create_info.hints.Check(ModelHints::kFastTuning)) { tuning_parameters.tuning_type = TuningType::FAST; } + if (tuning_parameters.info->IsMali()) { + const MaliInfo& info = tuning_parameters.info->mali_info; + if (info.IsMaliT6xx()) { + // Mali T628 hangs forever in clFinish when used profiling queue + // TuningType::FAST does not use profiling queue. 
+ tuning_parameters.tuning_type = TuningType::FAST; + } + } RETURN_IF_ERROR(Tune(tuning_parameters)); + + if (serialized_model) { + flatbuffers::FlatBufferBuilder builder; + auto encoded_fb = Encode(*this, &builder); + data::FinishInferenceContextBuffer(builder, encoded_fb); + serialized_model->resize(builder.GetSize()); + std::memcpy(serialized_model->data(), builder.GetBufferPointer(), + builder.GetSize()); + } + for (auto& node : nodes_) { + node.operation->args_.ReleaseCPURepresentation(); + } + return absl::OkStatus(); +} + +absl::Status InferenceContext::RestoreDeserialized( + const std::vector& serialized_model, Environment* env) { + flatbuffers::Verifier verifier(serialized_model.data(), + serialized_model.size()); + if (!data::VerifyInferenceContextBuffer(verifier)) { + return absl::DataLossError("Deserialization failed."); + } + auto decoded_fb = data::GetInferenceContext(serialized_model.data()); + RETURN_IF_ERROR(Decode(&env->context(), decoded_fb, this)); + + CreationContext creation_context; + creation_context.device = env->GetDevicePtr(); + creation_context.context = &env->context(); + creation_context.queue = env->queue(); + creation_context.cache = env->program_cache(); + + RETURN_IF_ERROR(AllocateMemory(creation_context.context)); + BindMemoryToOperations(); + for (auto& node : nodes_) { + RETURN_IF_ERROR(node.operation->CompileDeserialized(creation_context)); + } + RETURN_IF_ERROR(UpdateParams()); + for (auto& node : nodes_) { + node.operation->args_.ReleaseCPURepresentation(); + } return absl::OkStatus(); } absl::Status InferenceContext::InitFromGraphWithTransforms( const CreateInferenceInfo& create_info, GraphFloat32* graph, - Environment* env) { + Environment* env, std::vector* serialized_model) { RETURN_IF_ERROR(RunGraphTransforms(graph)); - RETURN_IF_ERROR(InitFromGraph(create_info, *graph, env)); + RETURN_IF_ERROR(InitFromGraph(create_info, *graph, env, serialized_model)); return absl::OkStatus(); } @@ -206,6 +254,11 @@ void InferenceContext::CopyInAndOutIds(const GraphFloat32& graph) { input_ids_.push_back(input->id); } + const auto variable_inputs = graph.variable_inputs(); + for (const auto& variable_input : variable_inputs) { + variable_ids_and_refs_[variable_input->id] = variable_input->tensor.ref; + } + const auto outputs = graph.outputs(); for (const auto& output : outputs) { output_ids_.push_back(output->id); @@ -261,10 +314,12 @@ absl::Status InferenceContext::ConvertOperations(const DeviceInfo& device_info, if (consumed_nodes.find(node.id) != consumed_nodes.end()) { continue; } + std::string op_name = node.operation.type + " " + std::to_string(node.id); GPUOperationsSubgraph gpu_subgraph; if (hints.Check(ModelHints::kAllowSpecialKernels) && GPUSubgraphFromGraph(device_info, precision_, graph, node.id, - tensor_descriptors, &consumed_nodes, &gpu_subgraph) + tensor_descriptors, &consumed_nodes, &gpu_subgraph, + &op_name) .ok()) { // Mapping of subgraph (set of nodes) to GPU operations. Should happen // before straigtforward mapping. 
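[Editor's note] A minimal sketch of the serialization round-trip introduced here, assuming a prepared CreateInferenceInfo, GraphFloat32, and Environment (variable names are illustrative). InitFromGraph optionally emits the serialized model; RestoreDeserialized verifies the buffer (returning DataLossError on a parse failure), rebuilds the nodes, recompiles them via CompileDeserialized, and rebinds memory, so the original graph is not needed again:

// Build once and capture the serialized form.
InferenceContext context;
std::vector<uint8_t> serialized_model;
RETURN_IF_ERROR(
    context.InitFromGraph(create_info, graph, &env, &serialized_model));
// serialized_model can now be cached, e.g. keyed by the driver version.

// Later, on the same device/driver, restore without the GraphFloat32.
InferenceContext restored;
RETURN_IF_ERROR(restored.RestoreDeserialized(serialized_model, &env));

// Run and read back the first output (accessors added in inference_context.h);
// input tensors are assumed to have been filled before this point.
RETURN_IF_ERROR(restored.AddToQueue(env.queue()));
RETURN_IF_ERROR(env.queue()->WaitForCompletion());
TensorFloat32 output;
RETURN_IF_ERROR(
    restored.GetOutputTensor(restored.GetOutputIds()[0], env.queue(), &output));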
@@ -333,7 +388,7 @@ absl::Status InferenceContext::ConvertOperations(const DeviceInfo& device_info, cl_node.outputs[j] = mapping_to_global_ids[-(id + 1)]; } } - cl_node.name = node.operation.type + " " + std::to_string(node.id); + cl_node.name = op_name; nodes_.push_back(std::move(cl_node)); } } @@ -387,41 +442,71 @@ absl::Status InferenceContext::Merge() { return absl::OkStatus(); } -void InferenceContext::GetUsages( - const std::function& functor, - std::map* usages) { +void InferenceContext::GetUsages(const std::function& functor, + std::map* usages) { for (ValueId in_id : input_ids_) { - const auto& desc = tensor_reserver_.Get(in_id).descriptor; - if (functor(desc)) { + if (functor(in_id)) { AddUsage(in_id, 0, usages); } } for (int op_index = 0; op_index < nodes_.size(); ++op_index) { auto tensors = GetCLNodeTensors(nodes_[op_index]); for (auto& tensor : tensors) { - if (functor(tensor.second)) { + if (functor(tensor.first)) { AddUsage(tensor.first, op_index, usages); } } } for (ValueId out_id : output_ids_) { - const auto& desc = tensor_reserver_.Get(out_id).descriptor; - if (functor(desc)) { + if (functor(out_id)) { AddUsage(out_id, nodes_.size(), usages); } } } +InferenceContext::TensorMemoryType InferenceContext::GetTensorMemoryType( + ValueId id) { + if (variable_ids_and_refs_.find(id) != variable_ids_and_refs_.end()) { + return TensorMemoryType::VARIABLE; + } else if (IsBufferBased(tensor_reserver_.Get(id).descriptor.storage_type)) { + return TensorMemoryType::BUFFER; + } else { + return TensorMemoryType::STRONG_SHAPE; + } +} + absl::Status InferenceContext::AllocateMemory(CLContext* context) { + RETURN_IF_ERROR(AllocateMemoryForVariableTensors(context)); RETURN_IF_ERROR(AllocateMemoryForBuffers(context)); RETURN_IF_ERROR(AllocateMemoryForStrongShapes(context)); return absl::OkStatus(); } +absl::Status InferenceContext::AllocateMemoryForVariableTensors( + CLContext* context) { + std::map ref_value_to_tensor_index; + + for (auto value_and_ref_value : variable_ids_and_refs_) { + if (ref_value_to_tensor_index.find(value_and_ref_value.second) == + ref_value_to_tensor_index.end()) { + const auto& t = tensor_reserver_.Get(value_and_ref_value.first); + const auto& shape = t.shape; + const auto& descriptor = t.descriptor; + + RETURN_IF_ERROR( + CreateTensor(*context, shape, descriptor, + &variable_tensors_[value_and_ref_value.second])); + } + } + return absl::OkStatus(); +} + absl::Status InferenceContext::AllocateMemoryForBuffers(CLContext* context) { std::map buffer_usages; GetUsages( - [](const TensorDescriptor& t) { return IsBufferBased(t.storage_type); }, + [this](ValueId id) { + return GetTensorMemoryType(id) == TensorMemoryType::BUFFER; + }, &buffer_usages); std::vector> buffer_usage_records; @@ -455,7 +540,7 @@ absl::Status InferenceContext::AllocateMemoryForBuffers(CLContext* context) { for (auto& node : nodes_) { auto tensors = GetCLNodeTensors(node); for (auto& t : tensors) { - if (!IsBufferBased(t.second.storage_type)) continue; + if (GetTensorMemoryType(t.first) != TensorMemoryType::BUFFER) continue; const int tensor_index = graph_ids_to_shared_buffer_tensors_[t.first]; if (created_tensors[tensor_index]) continue; const auto& shape = tensor_reserver_.Get(t.first).shape; @@ -473,7 +558,9 @@ absl::Status InferenceContext::AllocateMemoryForStrongShapes( CLContext* context) { std::map usages; GetUsages( - [](const TensorDescriptor& t) { return !IsBufferBased(t.storage_type); }, + [this](ValueId id) { + return GetTensorMemoryType(id) == TensorMemoryType::STRONG_SHAPE; + }, 
&usages); std::vector> usage_records; @@ -492,7 +579,9 @@ absl::Status InferenceContext::AllocateMemoryForStrongShapes( for (auto& node : nodes_) { auto tensors = GetCLNodeTensors(node); for (auto& t : tensors) { - if (IsBufferBased(t.second.storage_type)) continue; + if (GetTensorMemoryType(t.first) != TensorMemoryType::STRONG_SHAPE) { + continue; + } const auto& shape = tensor_reserver_.Get(t.first).shape; const auto id = assignment.object_ids[remap_from_graph_ids[t.first]]; graph_ids_to_strong_shape_tensors_[t.first] = id; @@ -581,13 +670,18 @@ uint64_t InferenceContext::GetSizeOfMemoryAllocatedForIntermediateTensors() for (const auto& b : shared_buffers_) { total_memory += b.GetMemorySizeInBytes(); } + for (const auto& t : variable_tensors_) { + total_memory += t.second.GetMemorySizeInBytes(); + } return total_memory; } Tensor* InferenceContext::GetTensor(ValueId id) { - if (graph_ids_to_shared_buffer_tensors_.find(id) != - graph_ids_to_shared_buffer_tensors_.end()) { + if (variable_ids_and_refs_.find(id) != variable_ids_and_refs_.end()) { + return &variable_tensors_[variable_ids_and_refs_[id]]; + } else if (graph_ids_to_shared_buffer_tensors_.find(id) != + graph_ids_to_shared_buffer_tensors_.end()) { return &shared_buffer_tensors_[graph_ids_to_shared_buffer_tensors_[id]]; } else { return &strong_shape_tensors_[graph_ids_to_strong_shape_tensors_[id]]; diff --git a/tensorflow/lite/delegates/gpu/cl/inference_context.h b/tensorflow/lite/delegates/gpu/cl/inference_context.h index 8486f2ddcd3..ec8055ebcde 100644 --- a/tensorflow/lite/delegates/gpu/cl/inference_context.h +++ b/tensorflow/lite/delegates/gpu/cl/inference_context.h @@ -26,10 +26,12 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/cl/buffer.h" #include "tensorflow/lite/delegates/gpu/cl/cl_command_queue.h" #include "tensorflow/lite/delegates/gpu/cl/environment.h" +#include "tensorflow/lite/delegates/gpu/cl/gpu_object.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/cl/model_hints.h" #include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h" #include "tensorflow/lite/delegates/gpu/cl/precision.h" +#include "tensorflow/lite/delegates/gpu/cl/serialization_generated.h" #include "tensorflow/lite/delegates/gpu/cl/tensor_type.h" #include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/status.h" @@ -62,15 +64,17 @@ class InferenceContext { TensorStorageType storage_type; ModelHints hints; }; + absl::Status InitFromGraph(const CreateInferenceInfo& create_info, - const GraphFloat32& graph, Environment* env); + const GraphFloat32& graph, Environment* env, + std::vector* serialized_model = nullptr); // Applies OpenCL-specific transformations to the graph before the // initialization. These transformations are either impossible or useless in // other backends. 
absl::Status InitFromGraphWithTransforms( const CreateInferenceInfo& create_info, GraphFloat32* graph, - Environment* env); + Environment* env, std::vector* serialized_model = nullptr); absl::Status AddToQueue(CLCommandQueue* queue); absl::Status Profile(ProfilingCommandQueue* queue, ProfilingInfo* result); @@ -87,7 +91,22 @@ class InferenceContext { absl::Status GetOutputTensor(ValueId id, CLCommandQueue* queue, TensorFloat32* result); + const std::vector& GetInputIds() const { return input_ids_; } + const std::vector& GetOutputIds() const { return output_ids_; } + + absl::Status RestoreDeserialized(const std::vector& serialized_model, + Environment* env); + private: + enum TensorMemoryType { STRONG_SHAPE = 0, BUFFER = 1, VARIABLE = 2 }; + + friend flatbuffers::Offset Encode( + const InferenceContext& inference, + flatbuffers::FlatBufferBuilder* builder); + friend absl::Status Decode(CLContext* context, + const data::InferenceContext* fb_inference, + InferenceContext* inference); + void CopyInAndOutIds(const GraphFloat32& graph); absl::Status ConvertOperations(const DeviceInfo& device_info, const GraphFloat32& graph, ModelHints hints); @@ -98,14 +117,18 @@ class InferenceContext { absl::Status Merge(); absl::Status AllocateMemory(CLContext* context); + absl::Status AllocateMemoryForVariableTensors(CLContext* context); + absl::Status AllocateMemoryForBuffers(CLContext* context); absl::Status AllocateMemoryForStrongShapes(CLContext* context); // utility function - void GetUsages(const std::function& functor, + void GetUsages(const std::function& functor, std::map* usages); + TensorMemoryType GetTensorMemoryType(ValueId id); + void BindMemoryToOperations(); absl::Status Compile(const CreationContext& creation_context); absl::Status Tune(const TuningParameters& tuning_parameters); @@ -154,12 +177,39 @@ class InferenceContext { void SetNext(ValueId id) { next_ = id; } DummyTensor Get(ValueId id) { return reservations_[id]; } + std::vector> GetTensorDescs() const { + std::vector> result; + for (auto& v : reservations_) { + TensorDescriptor desc = v.second.descriptor; + desc.shape.b = v.second.shape.b; + desc.shape.h = v.second.shape.h; + desc.shape.w = v.second.shape.w; + desc.shape.d = 1; + desc.shape.c = v.second.shape.c; + result.push_back({v.first, desc}); + } + return result; + } + + void Add(const std::vector>& tensors) { + for (auto& v : tensors) { + DummyTensor dummy; + dummy.descriptor = v.second; + dummy.shape.b = v.second.shape.b; + dummy.shape.h = v.second.shape.h; + dummy.shape.w = v.second.shape.w; + dummy.shape.c = v.second.shape.c; + Add(v.first, dummy); + } + } + private: absl::flat_hash_map reservations_; ValueId next_; }; TensorReserver tensor_reserver_; + std::map variable_tensors_; std::vector shared_buffers_; std::vector shared_buffer_tensors_; // use references to memory from shared_buffers_ @@ -169,6 +219,7 @@ class InferenceContext { std::map graph_ids_to_strong_shape_tensors_; std::vector input_ids_; + std::map variable_ids_and_refs_; std::vector output_ids_; }; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/BUILD b/tensorflow/lite/delegates/gpu/cl/kernels/BUILD index 02f5f9c4a4a..d7e7c7dd498 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/kernels/BUILD @@ -104,32 +104,6 @@ cc_library( ], ) -cc_library( - name = "conv_3d", - srcs = ["conv_3d.cc"], - hdrs = ["conv_3d.h"], - deps = [ - ":gpu_operation", - ":util", - ":work_group_picking", - "//tensorflow/lite/delegates/gpu/cl:buffer", - 
"//tensorflow/lite/delegates/gpu/cl:cl_device", - "//tensorflow/lite/delegates/gpu/cl:linear_storage", - "//tensorflow/lite/delegates/gpu/cl:precision", - "//tensorflow/lite/delegates/gpu/cl:tensor", - "//tensorflow/lite/delegates/gpu/cl:tensor_type", - "//tensorflow/lite/delegates/gpu/cl:texture2d", - "//tensorflow/lite/delegates/gpu/cl:util", - "//tensorflow/lite/delegates/gpu/common:data_type", - "//tensorflow/lite/delegates/gpu/common:operations", - "//tensorflow/lite/delegates/gpu/common:shape", - "//tensorflow/lite/delegates/gpu/common:status", - "//tensorflow/lite/delegates/gpu/common:tensor", - "//tensorflow/lite/delegates/gpu/common:types", - "@com_google_absl//absl/strings", - ], -) - cc_library( name = "conv_buffer_1x1", srcs = ["conv_buffer_1x1.cc"], @@ -233,6 +207,7 @@ cc_library( "//tensorflow/lite/delegates/gpu/cl:precision", "//tensorflow/lite/delegates/gpu/cl:tensor", "//tensorflow/lite/delegates/gpu/cl:tensor_type", + "//tensorflow/lite/delegates/gpu/cl:texture2d", "//tensorflow/lite/delegates/gpu/cl:util", "//tensorflow/lite/delegates/gpu/common:data_type", "//tensorflow/lite/delegates/gpu/common:operations", @@ -241,6 +216,7 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:tensor", "//tensorflow/lite/delegates/gpu/common:types", "//tensorflow/lite/delegates/gpu/common:winograd_util", + "@com_google_absl//absl/strings", ], ) @@ -263,50 +239,6 @@ cc_test( ], ) -cc_library( - name = "conv_texture", - srcs = ["conv_texture.cc"], - hdrs = ["conv_texture.h"], - deps = [ - ":gpu_operation", - ":util", - ":work_group_picking", - "//tensorflow/lite/delegates/gpu/cl:cl_command_queue", - "//tensorflow/lite/delegates/gpu/cl:cl_context", - "//tensorflow/lite/delegates/gpu/cl:linear_storage", - "//tensorflow/lite/delegates/gpu/cl:precision", - "//tensorflow/lite/delegates/gpu/cl:tensor", - "//tensorflow/lite/delegates/gpu/cl:tensor_type", - "//tensorflow/lite/delegates/gpu/cl:texture2d", - "//tensorflow/lite/delegates/gpu/cl:util", - "//tensorflow/lite/delegates/gpu/common:data_type", - "//tensorflow/lite/delegates/gpu/common:operations", - "//tensorflow/lite/delegates/gpu/common:shape", - "//tensorflow/lite/delegates/gpu/common:status", - "//tensorflow/lite/delegates/gpu/common:tensor", - "//tensorflow/lite/delegates/gpu/common:types", - "//tensorflow/lite/delegates/gpu/common:winograd_util", - "@com_google_absl//absl/strings", - ], -) - -cc_test( - name = "conv_texture_test", - srcs = ["conv_texture_test.cc"], - linkstatic = True, - tags = tf_gpu_tests_tags() + [ - "linux", - "local", - ], - deps = [ - ":cl_test", - ":conv_texture", - "//tensorflow/lite/delegates/gpu/common:operations", - "//tensorflow/lite/delegates/gpu/common:status", - "@com_google_googletest//:gtest_main", - ], -) - cc_library( name = "conv_weights_converter", srcs = ["conv_weights_converter.cc"], @@ -384,30 +316,6 @@ cc_test( ], ) -cc_library( - name = "convolution_transposed_3d", - srcs = ["convolution_transposed_3d.cc"], - hdrs = ["convolution_transposed_3d.h"], - deps = [ - ":gpu_operation", - ":util", - ":work_group_picking", - "//tensorflow/lite/delegates/gpu/cl:buffer", - "//tensorflow/lite/delegates/gpu/cl:linear_storage", - "//tensorflow/lite/delegates/gpu/cl:tensor", - "//tensorflow/lite/delegates/gpu/cl:tensor_type", - "//tensorflow/lite/delegates/gpu/cl:texture2d", - "//tensorflow/lite/delegates/gpu/cl:util", - "//tensorflow/lite/delegates/gpu/common:data_type", - "//tensorflow/lite/delegates/gpu/common:operations", - "//tensorflow/lite/delegates/gpu/common:shape", - 
"//tensorflow/lite/delegates/gpu/common:status", - "//tensorflow/lite/delegates/gpu/common:tensor", - "//tensorflow/lite/delegates/gpu/common:types", - "@com_google_absl//absl/strings", - ], -) - cc_library( name = "convolution_transposed_3x3", srcs = ["convolution_transposed_3x3.cc"], @@ -681,17 +589,24 @@ cc_library( hdrs = ["fully_connected.h"], deps = [ ":gpu_operation", + ":tuning_parameters", ":util", + "//tensorflow/lite/delegates/gpu/cl:arguments", "//tensorflow/lite/delegates/gpu/cl:buffer", + "//tensorflow/lite/delegates/gpu/cl:cl_kernel", + "//tensorflow/lite/delegates/gpu/cl:device_info", "//tensorflow/lite/delegates/gpu/cl:linear_storage", + "//tensorflow/lite/delegates/gpu/cl:precision", "//tensorflow/lite/delegates/gpu/cl:tensor", - "//tensorflow/lite/delegates/gpu/cl:util", + "//tensorflow/lite/delegates/gpu/cl:tensor_type", + "//tensorflow/lite/delegates/gpu/cl:texture2d", "//tensorflow/lite/delegates/gpu/common:data_type", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:shape", - "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:tensor", "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + "@com_google_absl//absl/memory", ], ) @@ -706,8 +621,14 @@ cc_test( deps = [ ":cl_test", ":fully_connected", + ":gpu_operation", + "//tensorflow/lite/delegates/gpu/cl:environment", + "//tensorflow/lite/delegates/gpu/cl:precision", + "//tensorflow/lite/delegates/gpu/cl:tensor_type", + "//tensorflow/lite/delegates/gpu/common:data_type", "//tensorflow/lite/delegates/gpu/common:operations", - "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:tensor", "@com_google_googletest//:gtest_main", ], ) @@ -722,13 +643,19 @@ cc_library( ":work_group_picking", "//tensorflow/lite/delegates/gpu/cl:arguments", "//tensorflow/lite/delegates/gpu/cl:buffer", + "//tensorflow/lite/delegates/gpu/cl:cl_command_queue", "//tensorflow/lite/delegates/gpu/cl:cl_context", "//tensorflow/lite/delegates/gpu/cl:cl_device", + "//tensorflow/lite/delegates/gpu/cl:cl_kernel", + "//tensorflow/lite/delegates/gpu/cl:cl_program", + "//tensorflow/lite/delegates/gpu/cl:device_info", "//tensorflow/lite/delegates/gpu/cl:precision", "//tensorflow/lite/delegates/gpu/cl:program_cache", + "//tensorflow/lite/delegates/gpu/cl:serialization_cc_fbs", "//tensorflow/lite/delegates/gpu/cl:tensor", "//tensorflow/lite/delegates/gpu/cl:tensor_type", "//tensorflow/lite/delegates/gpu/common:access_type", + "//tensorflow/lite/delegates/gpu/common:data_type", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:types", "@com_google_absl//absl/strings", @@ -767,6 +694,25 @@ cc_test( ], ) +cc_test( + name = "lstm_full_test", + srcs = ["lstm_full_test.cc"], + linkstatic = True, + tags = tf_gpu_tests_tags() + [ + "linux", + "local", + ], + deps = [ + "//tensorflow/lite:framework", + "//tensorflow/lite/delegates/gpu:delegate", + "//tensorflow/lite/kernels:test_main", + "//tensorflow/lite/kernels:test_util", + "//tensorflow/lite/schema:schema_fbs", + "@com_google_googletest//:gtest", + "@flatbuffers", + ], +) + cc_library( name = "mean_stddev_normalization", srcs = ["mean_stddev_normalization.cc"], @@ -1005,6 +951,37 @@ cc_test( ], ) +cc_library( + name = "reduce", + srcs = ["reduce.cc"], + hdrs = ["reduce.h"], + deps = [ + ":gpu_operation", + ":util", + "//tensorflow/lite/delegates/gpu/cl:precision", + 
"//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + ], +) + +cc_test( + name = "reduce_test", + srcs = ["reduce_test.cc"], + linkstatic = True, + tags = tf_gpu_tests_tags() + [ + "linux", + "local", + ], + deps = [ + ":cl_test", + ":reduce", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:status", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "relu", srcs = ["relu.cc"], @@ -1259,7 +1236,7 @@ cc_library( hdrs = ["tuning_parameters.h"], deps = [ "//tensorflow/lite/delegates/gpu/cl:cl_command_queue", - "//tensorflow/lite/delegates/gpu/cl:cl_device", + "//tensorflow/lite/delegates/gpu/cl:device_info", ], ) @@ -1302,11 +1279,8 @@ cc_library( deps = [ "//tensorflow/lite/delegates/gpu/cl:device_info", "//tensorflow/lite/delegates/gpu/cl:precision", - "//tensorflow/lite/delegates/gpu/cl:tensor_type", - "//tensorflow/lite/delegates/gpu/common:access_type", "//tensorflow/lite/delegates/gpu/common:data_type", "//tensorflow/lite/delegates/gpu/common:shape", - "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:tensor", "//tensorflow/lite/delegates/gpu/common:types", "//tensorflow/lite/delegates/gpu/common:util", @@ -1364,9 +1338,8 @@ cc_library( hdrs = ["work_group_picking.h"], deps = [ ":tuning_parameters", - "//tensorflow/lite/delegates/gpu/cl:cl_command_queue", "//tensorflow/lite/delegates/gpu/cl:cl_kernel", - "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/cl:device_info", "//tensorflow/lite/delegates/gpu/common:types", "//tensorflow/lite/delegates/gpu/common:util", "//tensorflow/lite/delegates/gpu/common:workgroup_selection", @@ -1381,7 +1354,6 @@ test_suite( "conv_buffer_1x1_test", "conv_constants_test", "conv_powervr_test", - "conv_texture_test", "convolution_transposed_3x3_thin_test", "convolution_transposed_4x4_test", "convolution_transposed_test", @@ -1397,6 +1369,7 @@ test_suite( "padding_test", "pooling_test", "prelu_test", + "reduce_test", "relu_test", "reshape_test", "reshapex4_test", diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/cl_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/cl_test.cc index 0112241117e..efe97f9931b 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/cl_test.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/cl_test.cc @@ -55,6 +55,7 @@ absl::Status ExecuteGPUOperation(const std::vector& src_cpu, RETURN_IF_ERROR(operation->Compile(creation_context)); RETURN_IF_ERROR(operation->UpdateParams()); + operation->args_.ReleaseCPURepresentation(); RETURN_IF_ERROR(operation->AddToQueue(creation_context.queue)); RETURN_IF_ERROR(creation_context.queue->WaitForCompletion()); diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_3d.cc b/tensorflow/lite/delegates/gpu/cl/kernels/conv_3d.cc deleted file mode 100644 index 06664f67768..00000000000 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_3d.cc +++ /dev/null @@ -1,863 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_3d.h" - -#include -#include -#include - -#include "absl/strings/substitute.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" -#include "tensorflow/lite/delegates/gpu/cl/precision.h" -#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h" -#include "tensorflow/lite/delegates/gpu/common/data_type.h" -#include "tensorflow/lite/delegates/gpu/common/status.h" - -namespace tflite { -namespace gpu { -namespace cl { -namespace { -std::string GenerateUploadByThreads(const std::string& local_ptr_name, - const std::string& global_ptr_name, - const std::string& global_offset_name, - const std::string& lid_name, - int total_work_items, - int elements_to_upload) { - std::string c; - std::string offset = - global_offset_name.empty() ? "" : global_offset_name + " + "; - const int groups = elements_to_upload / total_work_items; - const int reminder = elements_to_upload % total_work_items; - for (int i = 0; i < groups; ++i) { - c += " " + local_ptr_name + "[" + lid_name + " + " + - std::to_string(total_work_items * i) + "] = " + global_ptr_name + "[" + - offset + lid_name + " + " + std::to_string(total_work_items * i) + - "];\n"; - } - if (reminder != 0) { - c += " if (" + lid_name + " < " + std::to_string(reminder) + ") {\n"; - c += " " + local_ptr_name + "[" + lid_name + " + " + - std::to_string(total_work_items * groups) + "] = " + global_ptr_name + - "[" + offset + lid_name + " + " + - std::to_string(total_work_items * groups) + "];\n"; - c += " }\n"; - } - return c; -} - -std::string GenerateAsyncUpload(const std::string& local_ptr_name, - const std::string& global_ptr_name, - const std::string& global_offset_name, - int elements_to_upload) { - std::string c; - std::string offset = - global_offset_name.empty() ? 
"" : " + " + global_offset_name; - c += " async_work_group_copy(" + local_ptr_name + ", " + global_ptr_name + - offset + ", " + std::to_string(elements_to_upload) + ", 0);\n"; - return c; -} - -std::string GenerateGlobalCoordinates(const int4& block_size, - const int3& work_group_launch_order) { - std::string c; - int3 launch_remap; - launch_remap[work_group_launch_order.x] = 0; - launch_remap[work_group_launch_order.y] = 1; - launch_remap[work_group_launch_order.z] = 2; - if (work_group_launch_order[0] == 0) { - c += " int DST_X = get_global_id(0) * " + std::to_string(block_size.x) + - ";\n"; - } else { - c += " int DST_X = (get_group_id(" + std::to_string(launch_remap[0]) + - ") * get_local_size(0) + get_local_id(0)) * " + - std::to_string(block_size.x) + ";\n"; - } - if (work_group_launch_order[1] == 1) { - c += " int DST_Y = get_global_id(1) * " + std::to_string(block_size.y) + - ";\n"; - } else { - c += " int DST_Y = (get_group_id(" + std::to_string(launch_remap[1]) + - ") * get_local_size(1) + get_local_id(1)) * " + - std::to_string(block_size.y) + ";\n"; - } - if (work_group_launch_order[2] == 2) { - c += " int linear_id_z = get_global_id(2);\n"; - } else { - c += " int linear_id_z = get_group_id(" + std::to_string(launch_remap[2]) + - ") * get_local_size(2) + get_local_id(2);\n"; - } - c += " int DST_S = (linear_id_z % args.grid_size_s) * " + - std::to_string(block_size.w) + ";\n"; - c += " int DST_Z = (linear_id_z / args.grid_size_s) * " + - std::to_string(block_size.z) + ";\n"; - return c; -} - -std::string GenerateConv(CalculationsPrecision precision, - const int4& block_size, int offset, - bool weights_are_buffer) { - std::string c; - const std::string channels[] = {"x", "y", "z", "w"}; - for (int s = 0; s < block_size.w; ++s) { - switch (precision) { - case CalculationsPrecision::F32: - case CalculationsPrecision::F16: - for (int ch = 0; ch < 4; ++ch) { - const std::string weight_id = std::to_string(s * 4 + ch + offset); - std::string weight_name; - if (weights_are_buffer) { - weight_name = "weights_cache[" + weight_id + "]"; - } else { - weight_name = "f" + weight_id; - } - for (int z = 0; z < block_size.z; ++z) { - for (int y = 0; y < block_size.y; ++y) { - for (int x = 0; x < block_size.x; ++x) { - std::string id = - std::to_string(z) + std::to_string(y) + std::to_string(x); - c += " r" + std::to_string(s) + id + " += " + weight_name + - " * src" + id + "." 
+ channels[ch] + ";\n"; - } - } - } - } - break; - case CalculationsPrecision::F32_F16: - for (int z = 0; z < block_size.z; ++z) { - for (int y = 0; y < block_size.y; ++y) { - for (int x = 0; x < block_size.x; ++x) { - std::string id = - std::to_string(z) + std::to_string(y) + std::to_string(x); - std::vector weight_names(4); - for (int i = 0; i < 4; ++i) { - std::string weight_id = std::to_string(s * 4 + i + offset); - if (weights_are_buffer) { - weight_names[i] = "weights_cache[" + weight_id + "]"; - } else { - weight_names[i] = "f" + weight_id; - } - } - c += absl::Substitute( - " $0 += convert_float4($1.x * $2 + $1.y * $3 + $1.z * " - "$4 + $1.w * $5);\n", - "r" + std::to_string(s) + id, "src" + id, weight_names[0], - weight_names[1], weight_names[2], weight_names[3]); - } - } - } - break; - } - } - return c; -} -} // namespace - -Conv3D::Conv3D(const OperationDef& definition, - const Convolution3DAttributes& attr, - const DeviceInfo& device_info) - : GPUOperation(definition), - stride_(attr.strides.w, attr.strides.h, attr.strides.d), - padding_(-attr.padding.prepended.w, -attr.padding.prepended.h, - -attr.padding.prepended.d), - kernel_size_(attr.weights.shape.w, attr.weights.shape.h, - attr.weights.shape.d), - dilation_(attr.dilations.w, attr.dilations.h, attr.dilations.d), - conv_params_(GuessBestParams(device_info, definition, attr)) { - const bool stride_correction = - definition_.IsBatchSupported() && stride_.x != 1; - code_ = GenerateConv3D(definition_, stride_correction, conv_params_); - if (definition_.precision == CalculationsPrecision::F16 && - device_info.IsPowerVR()) { - compiler_options_.push_back(CompilerOptions::POWERVR_FP16); - } -} - -Conv3D::Conv3D(Conv3D&& operation) - : GPUOperation(std::move(operation)), - stride_(operation.stride_), - padding_(operation.padding_), - kernel_size_(operation.kernel_size_), - dilation_(operation.dilation_), - conv_params_(operation.conv_params_) {} - -Conv3D& Conv3D::operator=(Conv3D&& operation) { - if (this != &operation) { - std::swap(stride_, operation.stride_); - std::swap(padding_, operation.padding_); - std::swap(kernel_size_, operation.kernel_size_); - std::swap(dilation_, operation.dilation_); - std::swap(conv_params_, operation.conv_params_); - GPUOperation::operator=(std::move(operation)); - } - return *this; -} - -absl::Status Conv3D::BindArguments() { - if (!conv_params_.x_kernel_is_1) { - RETURN_IF_ERROR(args_.SetInt("stride_x", stride_.x)); - RETURN_IF_ERROR(args_.SetInt("padding_x", padding_.x * src_[0]->Batch())); - RETURN_IF_ERROR(args_.SetInt("kernel_size_x", kernel_size_.x)); - RETURN_IF_ERROR(args_.SetInt("dilation_x", dilation_.x * src_[0]->Batch())); - } - if (!conv_params_.y_kernel_is_1) { - RETURN_IF_ERROR(args_.SetInt("stride_y", stride_.y)); - RETURN_IF_ERROR(args_.SetInt("padding_y", padding_.y)); - RETURN_IF_ERROR(args_.SetInt("kernel_size_y", kernel_size_.y)); - RETURN_IF_ERROR(args_.SetInt("dilation_y", dilation_.y)); - } - if (!conv_params_.z_kernel_is_1) { - RETURN_IF_ERROR(args_.SetInt("stride_z", stride_.z)); - RETURN_IF_ERROR(args_.SetInt("padding_z", padding_.z)); - RETURN_IF_ERROR(args_.SetInt("kernel_size_z", kernel_size_.z)); - RETURN_IF_ERROR(args_.SetInt("dilation_z", dilation_.z)); - } - return args_.SetInt("grid_size_s", DivideRoundUp(dst_[0]->Slices(), - conv_params_.block_size.w)); -} - -int3 Conv3D::GetGridSize() const { - const int grid_x = DivideRoundUp(dst_[0]->Width() * dst_[0]->Batch(), - conv_params_.block_size.x); - const int grid_y = - DivideRoundUp(dst_[0]->Height(), 
conv_params_.block_size.y); - const int grid_z = - DivideRoundUp(dst_[0]->Slices(), conv_params_.block_size.w) * - DivideRoundUp(dst_[0]->Depth(), conv_params_.block_size.z); - int3 wg; - wg.x = DivideRoundUp(grid_x, work_group_size_.x); - wg.y = DivideRoundUp(grid_y, work_group_size_.y); - wg.z = DivideRoundUp(grid_z, work_group_size_.z); - return int3(wg[conv_params_.work_group_launch_order[0]] * work_group_size_.x, - wg[conv_params_.work_group_launch_order[1]] * work_group_size_.y, - wg[conv_params_.work_group_launch_order[2]] * work_group_size_.z); -} - -void Conv3D::GetPossibleKernelWorkGroups(TuningType tuning_type, - const DeviceInfo& device_info, - const KernelInfo& kernel_info, - std::vector* work_groups) const { - if (conv_params_.weights_upload_type == - WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP || - conv_params_.weights_upload_type == - WeightsUploadType::LOCAL_MEM_BY_THREADS) { - work_groups->push_back(work_group_size_); - return; - } - if (conv_params_.work_group_launch_order[0] == 0 && - conv_params_.work_group_launch_order[1] == 1 && - conv_params_.work_group_launch_order[2] == 2) { - GetPossibleWorkGroupsConv(tuning_type, device_info, kernel_info, grid_size_, - work_groups); - } else { - work_groups->push_back(work_group_size_); - } -} - -std::string Conv3D::GenerateConv3D(const OperationDef& op_def, - bool stride_correction, - const Conv3D::ConvParams& conv_params) { - auto src_desc = op_def.src_tensors[0]; - src_desc.SetTextureAddressMode(TextureAddressMode::ZERO); - if (op_def.IsBatchSupported()) { - src_desc.SetStateVar("BatchedWidth", "true"); - } - AddSrcTensor("src_tensor", src_desc); - - auto dst_desc = op_def.dst_tensors[0]; - if (op_def.IsBatchSupported()) { - dst_desc.SetStateVar("BatchedWidth", "true"); - } - AddDstTensor("dst_tensor", dst_desc); - - if (!conv_params_.x_kernel_is_1) { - args_.AddInt("stride_x"); - args_.AddInt("padding_x"); - args_.AddInt("kernel_size_x"); - args_.AddInt("dilation_x"); - } - if (!conv_params_.y_kernel_is_1) { - args_.AddInt("stride_y"); - args_.AddInt("padding_y"); - args_.AddInt("kernel_size_y"); - args_.AddInt("dilation_y"); - } - if (!conv_params_.z_kernel_is_1) { - args_.AddInt("stride_z"); - args_.AddInt("padding_z"); - args_.AddInt("kernel_size_z"); - args_.AddInt("dilation_z"); - } - args_.AddInt("grid_size_s"); - - const auto src_tensor_type = op_def.src_tensors[0].storage_type; - const bool buffer_type = src_tensor_type == TensorStorageType::BUFFER || - src_tensor_type == TensorStorageType::IMAGE_BUFFER; - - const bool manual_clamp_x = buffer_type && !conv_params.x_kernel_is_1; - const bool manual_clamp_y = buffer_type && !conv_params.y_kernel_is_1; - const bool manual_clamp_z = - src_tensor_type != TensorStorageType::TEXTURE_3D && - !conv_params.z_kernel_is_1; - - const bool can_read_out_of_x = !buffer_type; - const bool can_read_out_of_y = !buffer_type; - const bool can_read_out_of_z = - src_tensor_type == TensorStorageType::TEXTURE_3D || - src_tensor_type == TensorStorageType::TEXTURE_2D || - src_tensor_type == TensorStorageType::SINGLE_TEXTURE_2D; - - const bool is1x1x1 = conv_params.x_kernel_is_1 && conv_params.y_kernel_is_1 && - conv_params.z_kernel_is_1; - - const bool need_local_mem = - conv_params.weights_upload_type == - Conv3D::WeightsUploadType::LOCAL_MEM_BY_THREADS || - conv_params.weights_upload_type == - Conv3D::WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP; - - const int4 block_size = conv_params.block_size; - std::string c = GetCommonDefines(op_def.precision); - if (need_local_mem) { // we use fixed 
workgroup size when use local mem - c += "__attribute__((reqd_work_group_size(" + - std::to_string(work_group_size_.x) + ", " + - std::to_string(work_group_size_.y) + ", " + - std::to_string(work_group_size_.z) + ")))\n"; - } - c += "__kernel void main_function(\n"; - c += "$0) {\n"; - c += GenerateGlobalCoordinates(block_size, - conv_params.work_group_launch_order); - if (!need_local_mem) { - c += " if (DST_X >= args.dst_tensor.Width() || DST_Y >= " - "args.dst_tensor.Height() || DST_Z >= args.dst_tensor.Depth()) " - "return;\n"; - } - if (conv_params.weights_upload_type == - Conv3D::WeightsUploadType::LOCAL_MEM_BY_THREADS) { - c += " int lid = get_local_id(1) * " + std::to_string(work_group_size_.x) + - " + get_local_id(0);\n"; - } - for (int s = 0; s < block_size.w; ++s) { - for (int z = 0; z < block_size.z; ++z) { - for (int y = 0; y < block_size.y; ++y) { - for (int x = 0; x < block_size.x; ++x) { - c += " ACCUM_FLT4 r" + std::to_string(s) + std::to_string(z) + - std::to_string(y) + std::to_string(x) + - " = (ACCUM_FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n"; - } - } - } - } - if (!conv_params.x_kernel_is_1) { - for (int x = 0; x < block_size.x; ++x) { - const std::string xc = "(DST_X + " + std::to_string(x) + ")"; - if (stride_correction) { - c += " int xc" + std::to_string(x) + " = " + - GetXStrideCorrected(xc, "args.src_tensor.Batch()", "args.stride_x", - "args.padding_x") + - ";\n"; - } else { - c += " int xc" + std::to_string(x) + " = " + xc + - " * args.stride_x + args.padding_x;\n"; - } - } - } else if (!can_read_out_of_x) { - for (int x = 0; x < block_size.x; ++x) { - const std::string xc = "(DST_X + " + std::to_string(x) + ")"; - c += " int xc" + std::to_string(x) + " = clamp(" + xc + - ", 0, args.src_tensor.Width() - 1);\n"; - } - } - if (!conv_params.y_kernel_is_1) { - for (int y = 0; y < block_size.y; ++y) { - const std::string yc = "(DST_Y + " + std::to_string(y) + ")"; - c += " int yc" + std::to_string(y) + " = " + yc + - " * args.stride_y + args.padding_y;\n"; - } - } else if (!can_read_out_of_y) { - for (int y = 0; y < block_size.y; ++y) { - const std::string yc = "(DST_Y + " + std::to_string(y) + ")"; - c += " int yc" + std::to_string(y) + " = clamp(" + yc + - ", 0, args.src_tensor.Height() - 1);\n"; - } - } - if (!conv_params.z_kernel_is_1) { - for (int z = 0; z < block_size.z; ++z) { - const std::string zc = "(DST_Z + " + std::to_string(z) + ")"; - c += " int zc" + std::to_string(z) + " = " + zc + - " * args.stride_z + args.padding_z;\n"; - } - } else if (!can_read_out_of_z) { - for (int z = 0; z < block_size.z; ++z) { - const std::string zc = "(DST_Z + " + std::to_string(z) + ")"; - c += " int zc" + std::to_string(z) + " = clamp(" + zc + - ", 0, args.src_tensor.Depth() - 1);\n"; - } - } - if (need_local_mem) { - c += " __local FLT4 weights_cache[" + - std::to_string(block_size.w * 4 * conv_params.src_depth_loop_size) + - "];\n"; - } - if (conv_params.weights_upload_type == - Conv3D::WeightsUploadType::GLOBAL_MEM) { - c += " __global FLT4* weights_cache;\n"; - } - std::string kernel_size; - kernel_size += conv_params.x_kernel_is_1 ? "" : " * args.kernel_size_x"; - kernel_size += conv_params.y_kernel_is_1 ? "" : " * args.kernel_size_y"; - kernel_size += conv_params.z_kernel_is_1 ? 
"" : " * args.kernel_size_z"; - if (conv_params.AreWeightsBuffer()) { - c += " __global FLT4* filters_loc = args.weights.GetPtr() + DST_S * 4 * " - "args.src_tensor.Slices()" + - kernel_size + ";\n"; - } - if (buffer_type) { - c += " const int src_layer_offset = args.src_tensor.SliceStride();\n"; - } - if (!is1x1x1) { - c += " int filter_offset = 0;\n"; - } - if (!conv_params.z_kernel_is_1) { - c += " for (int kz = 0; kz < args.kernel_size_z; ++kz) {\n"; - for (int z = 0; z < block_size.z; ++z) { - const std::string zck = "zck" + std::to_string(z); - c += " int zck" + std::to_string(z) + " = kz * args.dilation_z + zc" + - std::to_string(z) + ";\n"; - if (manual_clamp_z) { - c += " bool mz" + std::to_string(z) + " = " + zck + " >= 0 && " + zck + - " < args.src_tensor.Depth();\n"; - c += " " + zck + " = clamp(" + zck + - ", 0, args.src_tensor.Depth() - 1);\n"; - } - } - } - if (!conv_params.y_kernel_is_1) { - c += " for (int ky = 0; ky < args.kernel_size_y; ++ky) {\n"; - for (int y = 0; y < block_size.y; ++y) { - const std::string yck = "yck" + std::to_string(y); - c += " int " + yck + " = ky * args.dilation_y + yc" + std::to_string(y) + - ";\n"; - if (manual_clamp_y) { - c += " bool my" + std::to_string(y) + " = " + yck + " >= 0 && " + yck + - " < args.src_tensor.Height();\n"; - c += " " + yck + " = clamp(" + yck + - ", 0, args.src_tensor.Height() - 1);\n"; - } - } - } - if (!conv_params.x_kernel_is_1) { - c += " for (int kx = 0; kx < args.kernel_size_x; ++kx) {\n"; - for (int x = 0; x < block_size.x; ++x) { - const std::string xck = "xck" + std::to_string(x); - c += " int xck" + std::to_string(x) + " = kx * args.dilation_x + xc" + - std::to_string(x) + ";\n"; - if (manual_clamp_x) { - c += " bool mx" + std::to_string(x) + " = " + xck + " >= 0 && " + xck + - " < args.src_tensor.Width();\n"; - c += " " + xck + " = clamp(" + xck + - ", 0, args.src_tensor.Width() - 1);\n"; - } - } - } - - auto get_src_x_coord = [&](int id) { - std::string xs = std::to_string(id); - std::string xc = "xck" + xs; - if (conv_params.x_kernel_is_1) { - if (can_read_out_of_x) { - xc = "DST_X + " + xs; - } else { - xc = "xc" + xs; - } - } - return xc; - }; - auto get_src_y_coord = [&](int id) { - std::string ys = std::to_string(id); - std::string yc = "yck" + ys; - if (conv_params.y_kernel_is_1) { - if (can_read_out_of_y) { - yc = "DST_Y + " + ys; - } else { - yc = "yc" + ys; - } - } - return yc; - }; - auto get_src_z_coord = [&](int id) { - std::string zs = std::to_string(id); - std::string zc = "zck" + zs; - if (conv_params.z_kernel_is_1) { - if (can_read_out_of_z) { - zc = "DST_Z + " + zs; - } else { - zc = "zc" + zs; - } - } - return zc; - }; - - if (buffer_type) { - for (int z = 0; z < block_size.z; ++z) { - const std::string zs = std::to_string(z); - const std::string zc = get_src_z_coord(z); - for (int y = 0; y < block_size.y; ++y) { - const std::string ys = std::to_string(y); - const std::string yc = get_src_y_coord(y); - for (int x = 0; x < block_size.x; ++x) { - const std::string xs = std::to_string(x); - const std::string xc = get_src_x_coord(x); - const std::string id = zs + ys + xs; - c += " args.src_tensor.GetAddress(src_a_" + id + ", " + xc + ", " + - yc + ", " + zc + ", 0);\n"; - if (!is1x1x1 && src_tensor_type == TensorStorageType::IMAGE_BUFFER) { - std::string condition; - if (manual_clamp_x) { - if (!condition.empty()) { - condition += " && "; - } - condition += "mx" + xs; - } - if (manual_clamp_y) { - if (!condition.empty()) { - condition += " && "; - } - condition += "my" + ys; - } - if 
(manual_clamp_z) { - if (!condition.empty()) { - condition += " && "; - } - condition += "mz" + zs; - } - c += " src_a_" + id + " = select(-1, src_a_" + id + ", " + - condition + ");\n"; - c += " int dz_" + id + " = select(0, src_layer_offset, " + - condition + ");\n"; - } - } - } - } - } - - auto declare_src = [&]() { - for (int z = 0; z < block_size.z; ++z) { - const std::string zs = std::to_string(z); - for (int y = 0; y < block_size.y; ++y) { - const std::string ys = std::to_string(y); - for (int x = 0; x < block_size.x; ++x) { - const std::string xs = std::to_string(x); - const std::string id = zs + ys + xs; - c += " FLT4 src" + id + ";\n"; - } - } - } - }; - - auto read_src = [&]() { - for (int z = 0; z < block_size.z; ++z) { - const std::string zs = std::to_string(z); - const std::string zc = get_src_z_coord(z); - for (int y = 0; y < block_size.y; ++y) { - const std::string ys = std::to_string(y); - const std::string yc = get_src_y_coord(y); - for (int x = 0; x < block_size.x; ++x) { - const std::string xs = std::to_string(x); - const std::string xc = get_src_x_coord(x); - std::string multiplier; - multiplier += manual_clamp_x ? " * (FLT)(mx" + xs + ")" : ""; - multiplier += manual_clamp_y ? " * (FLT)(my" + ys + ")" : ""; - multiplier += manual_clamp_z ? " * (FLT)(mz" + zs + ")" : ""; - const std::string id = zs + ys + xs; - if (buffer_type) { - if (src_tensor_type == TensorStorageType::IMAGE_BUFFER) { - multiplier = ""; - } - c += " src" + id + " = args.src_tensor.Read(src_a_" + id + ")" + - multiplier + ";\n"; - if (!is1x1x1 && - src_tensor_type == TensorStorageType::IMAGE_BUFFER) { - c += " src_a_" + id + " += dz_" + id + ";\n"; - } else { - c += " src_a_" + id + " += src_layer_offset;\n"; - } - } else { - c += " src" + id + " = args.src_tensor.Read(" + xc + ", " + yc + - ", " + zc + ", s)" + multiplier + ";\n"; - } - } - } - } - }; - c += " int s = 0;\n"; - declare_src(); - c += " do {\n"; - const int total_work_items = - work_group_size_.x * work_group_size_.y * work_group_size_.z; - if (conv_params.weights_upload_type == - Conv3D::WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP) { - c += - GenerateAsyncUpload("weights_cache", "filters_loc", - /*global_offset_name*/ "", - block_size.w * 4 * conv_params.src_depth_loop_size); - } else if (conv_params.weights_upload_type == - Conv3D::WeightsUploadType::LOCAL_MEM_BY_THREADS) { - c += " barrier(CLK_LOCAL_MEM_FENCE);\n"; - c += GenerateUploadByThreads( - "weights_cache", "filters_loc", - /*global_offset_name*/ "", "lid", total_work_items, - block_size.w * 4 * conv_params.src_depth_loop_size); - } else if (conv_params.weights_upload_type == - Conv3D::WeightsUploadType::GLOBAL_MEM) { - c += " weights_cache = filters_loc;\n"; - } else { // TEXTURES_MEM - for (int dst_s = 0; dst_s < block_size.w; ++dst_s) { - const std::string f_y = is1x1x1 ? 
"s" : "filter_offset"; - c += absl::Substitute( - R"( FLT4 f$2 = args.weights0.Read(DST_S + $0, $1); - FLT4 f$3 = args.weights1.Read(DST_S + $0, $1); - FLT4 f$4 = args.weights2.Read(DST_S + $0, $1); - FLT4 f$5 = args.weights3.Read(DST_S + $0, $1); -)", - dst_s, f_y, dst_s * 4 + 0, dst_s * 4 + 1, dst_s * 4 + 2, - dst_s * 4 + 3); - } - if (!is1x1x1) { - c += " filter_offset++;\n"; - } - } - read_src(); - c += " s += 1;\n"; - if (conv_params.weights_upload_type == - Conv3D::WeightsUploadType::LOCAL_MEM_BY_THREADS) { - c += " barrier(CLK_LOCAL_MEM_FENCE);\n"; - } - c += GenerateConv(op_def.precision, block_size, 0, - conv_params.AreWeightsBuffer()); - for (int i = 1; i < conv_params.src_depth_loop_size; ++i) { - read_src(); - c += GenerateConv(op_def.precision, block_size, i * block_size.w * 4, - conv_params.AreWeightsBuffer()); - c += " s += 1;\n"; - } - if (conv_params.AreWeightsBuffer()) { - c += " filters_loc += " + - std::to_string(block_size.w * 4 * conv_params.src_depth_loop_size) + - ";\n"; - } - c += " } while (s < args.src_tensor.Slices());\n"; - if (!conv_params.z_kernel_is_1) { - c += " }\n"; - } - if (!conv_params.y_kernel_is_1) { - c += " }\n"; - } - if (!conv_params.x_kernel_is_1) { - c += " }\n"; - } - if (conv_params.weights_upload_type == - Conv3D::WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP) { - c += GenerateAsyncUpload("weights_cache", "args.biases.GetPtr()", "DST_S", - block_size.w); - } else if (conv_params.weights_upload_type == - Conv3D::WeightsUploadType::LOCAL_MEM_BY_THREADS) { - c += " barrier(CLK_LOCAL_MEM_FENCE);\n"; - c += - GenerateUploadByThreads("weights_cache", "args.biases.GetPtr()", - "DST_S", "lid", total_work_items, block_size.w); - c += " barrier(CLK_LOCAL_MEM_FENCE);\n"; - } else if (conv_params.weights_upload_type == - Conv3D::WeightsUploadType::GLOBAL_MEM) { - c += " weights_cache = args.biases.GetPtr() + DST_S;\n"; - } - if (need_local_mem) { - c += " if (DST_X >= args.dst_tensor.Width() || DST_Y >= " - "args.dst_tensor.Height() || DST_Z >= args.dst_tensor.Depth()) " - "return;\n"; - } - for (int s = 0; s < block_size.w; ++s) { - const std::string dsts = - "DST_S" + (s == 0 ? "" : " + " + std::to_string(s)); - c += " if (" + dsts + " >= args.dst_tensor.Slices()) return;\n"; - for (int z = 0; z < block_size.z; ++z) { - const std::string dstz = - "DST_Z" + (z == 0 ? "" : " + " + std::to_string(z)); - for (int y = 0; y < block_size.y; ++y) { - const std::string dsty = - "DST_Y" + (y == 0 ? "" : " + " + std::to_string(y)); - for (int x = 0; x < block_size.x; ++x) { - const std::string dstx = - "DST_X" + (x == 0 ? 
"" : " + " + std::to_string(x)); - const std::string r_id = std::to_string(s) + std::to_string(z) + - std::to_string(y) + std::to_string(x); - c += " if (" + dstx + " < args.dst_tensor.Width() && " + dsty + - " < args.dst_tensor.Height() && " + dstz + - " < args.dst_tensor.Depth()) {\n"; - if (conv_params.AreWeightsBuffer()) { - c += " FLT4 res = TO_FLT4(r" + r_id + ") + weights_cache[" + - std::to_string(s) + "];\n"; - } else { - c += " FLT4 res = TO_FLT4(r" + r_id + ") + args.biases.Read(" + - dsts + ");\n"; - } - c += " args.dst_tensor.Write(res, " + dstx + ", " + dsty + ", " + - dstz + ", " + dsts + ");\n"; - c += " }\n"; - } - } - } - } - c += "}\n"; - return c; -} - -Conv3D::ConvParams Conv3D::GuessBestParams(const DeviceInfo& device_info, - const OperationDef& definition, - int src_slices, int dst_slices, - bool x_kernel_is_1, - bool y_kernel_is_1, - bool z_kernel_is_1) { - ConvParams conv_params; - conv_params.x_kernel_is_1 = x_kernel_is_1; - conv_params.y_kernel_is_1 = y_kernel_is_1; - conv_params.z_kernel_is_1 = z_kernel_is_1; - if (device_info.IsNvidia()) { - conv_params.block_size = int4(1, 1, 1, 4); - work_group_size_ = int3(8, 4, 1); - conv_params.work_group_launch_order = int3(2, 0, 1); - conv_params.src_depth_loop_size = 1; - conv_params.weights_upload_type = WeightsUploadType::LOCAL_MEM_BY_THREADS; - if (dst_slices % 4 == 0 || dst_slices >= 8) { - conv_params.block_size.w = 4; - } else if (dst_slices % 2 == 0 || dst_slices >= 4) { - conv_params.block_size.w = 2; - } else { - conv_params.block_size.w = dst_slices; - } - if (src_slices % 2 == 0) { - conv_params.src_depth_loop_size = 2; - } - if (src_slices % 4 == 0 && conv_params.block_size.w <= 2) { - conv_params.src_depth_loop_size = 4; - } - } else if (device_info.IsPowerVR()) { - conv_params.block_size = int4(1, 1, 1, 4); - work_group_size_ = int3(8, 4, 1); - conv_params.work_group_launch_order = int3(2, 0, 1); - conv_params.src_depth_loop_size = 1; - conv_params.weights_upload_type = - WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP; - if (dst_slices % 8 == 0 || dst_slices >= 32) { - conv_params.block_size.w = 8; - } else if (dst_slices % 4 == 0 || dst_slices >= 8) { - conv_params.block_size.w = 4; - } else if (dst_slices % 2 == 0 || dst_slices >= 4) { - conv_params.block_size.w = 2; - } else { - conv_params.block_size.w = dst_slices; - } - if (definition.precision == CalculationsPrecision::F16) { - conv_params.block_size.w = std::min(4, conv_params.block_size.w); - if (src_slices % 2 == 0) { - conv_params.src_depth_loop_size = 2; - } - if (src_slices % 4 == 0 && conv_params.block_size.w <= 2) { - conv_params.src_depth_loop_size = 4; - } - if (conv_params.block_size.w == 1) { - if (src_slices % 2 == 0) { - conv_params.src_depth_loop_size = 2; - } - if (src_slices % 4 == 0) { - conv_params.src_depth_loop_size = 4; - } - if (src_slices <= 8) { - conv_params.src_depth_loop_size = src_slices; - } - } - conv_params.block_size.x = 2; - work_group_size_ = int3(4, 8, 1); - } - } else if (device_info.IsAdreno()) { - conv_params.block_size = int4(2, 2, 1, 2); - work_group_size_ = int3(8, 4, 1); - conv_params.work_group_launch_order = int3(0, 1, 2); - conv_params.src_depth_loop_size = 1; - conv_params.weights_upload_type = WeightsUploadType::TEXTURES_MEM; - } else if (device_info.IsMali()) { - conv_params.block_size = int4(1, 1, 1, 4); - work_group_size_ = int3(8, 4, 1); - conv_params.work_group_launch_order = int3(0, 1, 2); - conv_params.src_depth_loop_size = 1; - conv_params.weights_upload_type = WeightsUploadType::GLOBAL_MEM; - if 
(dst_slices % 4 == 0 || dst_slices >= 8) { - conv_params.block_size.w = 4; - } else if (dst_slices % 2 == 0 || dst_slices >= 4) { - conv_params.block_size.w = 2; - } else { - conv_params.block_size.w = dst_slices; - } - if (src_slices % 2 == 0) { - conv_params.src_depth_loop_size = 2; - } - if (src_slices % 4 == 0 && conv_params.block_size.w <= 2) { - conv_params.src_depth_loop_size = 4; - } - } else { - conv_params.block_size = int4(2, 2, 1, 2); - work_group_size_ = int3(8, 4, 1); - conv_params.work_group_launch_order = int3(0, 1, 2); - conv_params.src_depth_loop_size = 1; - conv_params.weights_upload_type = WeightsUploadType::TEXTURES_MEM; - } - - return conv_params; -} - -Conv3D::ConvParams Conv3D::GuessBestParams( - const DeviceInfo& device_info, const OperationDef& definition, - const Convolution3DAttributes& attr) { - const int dst_slices = DivideRoundUp(attr.weights.shape.o, 4); - const int src_slices = DivideRoundUp(attr.weights.shape.i, 4); - const bool x_kernel_is_1 = attr.weights.shape.w == 1 && attr.strides.w == 1 && - attr.dilations.w == 1 && - attr.padding.prepended.w == 0 && - attr.padding.appended.w == 0; - const bool y_kernel_is_1 = attr.weights.shape.h == 1 && attr.strides.h == 1 && - attr.dilations.h == 1 && - attr.padding.prepended.h == 0 && - attr.padding.appended.h == 0; - const bool z_kernel_is_1 = attr.weights.shape.d == 1 && attr.strides.d == 1 && - attr.dilations.d == 1 && - attr.padding.prepended.d == 0 && - attr.padding.appended.d == 0; - return GuessBestParams(device_info, definition, src_slices, dst_slices, - x_kernel_is_1, y_kernel_is_1, z_kernel_is_1); -} - -Conv3D CreateConv3D(const DeviceInfo& device_info, - const OperationDef& definition, - const Convolution3DAttributes& attr) { - Conv3D result(definition, attr, device_info); - result.UploadData(attr.weights, attr.bias); - return result; -} - -} // namespace cl -} // namespace gpu -} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_3d.h b/tensorflow/lite/delegates/gpu/cl/kernels/conv_3d.h deleted file mode 100644 index d4a86b0ca5e..00000000000 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_3d.h +++ /dev/null @@ -1,269 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
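The vendor-specific branches of GuessBestParams above all derive their tile sizes from the same slice arithmetic. A minimal standalone sketch of that arithmetic follows, assuming nothing beyond what the deleted code shows; DivideRoundUp and AlignByN mirror the TFLite GPU utilities of the same names, and PickDstSliceBlock is a hypothetical name for the repeated block_size.w heuristic, not a function in the patch.

// Illustration only: standalone copies of the slice math used by GuessBestParams.
int DivideRoundUp(int n, int d) { return (n + d - 1) / d; }                // ceil(n / d)
int AlignByN(int n, int align) { return DivideRoundUp(n, align) * align; }  // round up to a multiple

// The Nvidia/Mali branches pick how many output slices one work-item produces:
// prefer 4 when the slice count divides evenly (or is large enough to tolerate
// padding), otherwise fall back to 2, otherwise process exactly dst_slices.
int PickDstSliceBlock(int dst_slices) {
  if (dst_slices % 4 == 0 || dst_slices >= 8) return 4;
  if (dst_slices % 2 == 0 || dst_slices >= 4) return 2;
  return dst_slices;
}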
-==============================================================================*/ - -#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_CONV_3D_H_ -#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_CONV_3D_H_ - -#include - -#include "tensorflow/lite/delegates/gpu/cl/buffer.h" -#include "tensorflow/lite/delegates/gpu/cl/cl_device.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/linear_storage.h" -#include "tensorflow/lite/delegates/gpu/cl/tensor.h" -#include "tensorflow/lite/delegates/gpu/cl/texture2d.h" -#include "tensorflow/lite/delegates/gpu/cl/util.h" -#include "tensorflow/lite/delegates/gpu/common/data_type.h" -#include "tensorflow/lite/delegates/gpu/common/operations.h" -#include "tensorflow/lite/delegates/gpu/common/shape.h" -#include "tensorflow/lite/delegates/gpu/common/status.h" -#include "tensorflow/lite/delegates/gpu/common/tensor.h" -#include "tensorflow/lite/delegates/gpu/common/types.h" - -namespace tflite { -namespace gpu { -namespace cl { - -class Conv3D : public GPUOperation { - public: - Conv3D() = default; - void GetPossibleKernelWorkGroups( - TuningType tuning_type, const DeviceInfo& device_info, - const KernelInfo& kernel_info, - std::vector* work_groups) const override; - absl::Status BindArguments() override; - int3 GetGridSize() const override; - - // Move only - Conv3D(Conv3D&& operation); - Conv3D& operator=(Conv3D&& operation); - Conv3D(const Conv3D&) = delete; - Conv3D& operator=(const Conv3D&) = delete; - - private: - enum class WeightsUploadType { - LOCAL_MEM_ASYNC_SUBGROUP, // we use it for PowerVR with workgroup size = 32 - LOCAL_MEM_BY_THREADS, - GLOBAL_MEM, - TEXTURES_MEM, - }; - - struct ConvParams { - int4 block_size; // WHDS - int3 work_group_launch_order; - int src_depth_loop_size; - WeightsUploadType weights_upload_type; - bool AreWeightsBuffer() const { - return weights_upload_type != WeightsUploadType::TEXTURES_MEM; - } - bool x_kernel_is_1; - bool y_kernel_is_1; - bool z_kernel_is_1; - }; - - Conv3D(const OperationDef& definition, const Convolution3DAttributes& attr, - const DeviceInfo& device_info); - - template - void UploadData(const tflite::gpu::Tensor& weights, - const tflite::gpu::Tensor& biases); - template - void UploadWeights(const tflite::gpu::Tensor& weights); - - template - void RearrangeWeightsData(const tflite::gpu::Tensor& weights, - absl::Span dst); - - friend Conv3D CreateConv3D(const DeviceInfo& device_info, - const OperationDef& definition, - const Convolution3DAttributes& attr); - - friend std::string GenerateConv3D(const OperationDef& op_def, - bool stride_correction, - const ConvParams& conv_params, - Arguments* args); - - ConvParams GuessBestParams(const DeviceInfo& device_info, - const OperationDef& definition, - const Convolution3DAttributes& attr); - - ConvParams GuessBestParams(const DeviceInfo& device_info, - const OperationDef& definition, int src_slices, - int dst_slices, bool x_kernel_is_1, - bool y_kernel_is_1, bool z_kernel_is_1); - - std::string GenerateConv3D(const OperationDef& op_def, bool stride_correction, - const Conv3D::ConvParams& conv_params); - - int3 stride_; - int3 padding_; - int3 kernel_size_; - int3 dilation_; - ConvParams conv_params_; -}; - -template -void Conv3D::UploadData(const tflite::gpu::Tensor& weights, - const tflite::gpu::Tensor& biases) { - UploadWeights(weights); - TensorLinearDescriptor desc; - desc.storage_type = conv_params_.AreWeightsBuffer() - ? 
LinearStorageType::BUFFER - : LinearStorageType::TEXTURE_2D; - desc.element_type = definition_.GetDataType(); - desc.UploadLinearData(biases); - args_.AddObject("biases", - absl::make_unique(std::move(desc))); -} - -template -void Conv3D::UploadWeights(const tflite::gpu::Tensor& weights) { - const int block_size = conv_params_.block_size.w; - const int dst_slices = - AlignByN(DivideRoundUp(weights.shape.o, 4), block_size); - const int src_slices = DivideRoundUp(weights.shape.i, 4); - const int kernel_x = kernel_size_.x; - const int kernel_y = kernel_size_.y; - const int kernel_z = kernel_size_.z; - const int texture_width = dst_slices; - const int texture_height = src_slices * kernel_x * kernel_y * kernel_z; - - const int elements_count = - kernel_x * kernel_y * kernel_z * src_slices * dst_slices * 4; - const bool f32_weights = definition_.precision == CalculationsPrecision::F32; - - const int float4_size = f32_weights ? 16 : 8; - - std::vector data(float4_size * elements_count); - - if (f32_weights) { - float4* ptr = reinterpret_cast(data.data()); - RearrangeWeightsData(weights, absl::MakeSpan(ptr, elements_count)); - } else { - half4* ptr = reinterpret_cast(data.data()); - RearrangeWeightsData(weights, absl::MakeSpan(ptr, elements_count)); - } - - if (conv_params_.AreWeightsBuffer()) { - BufferDescriptor desc; - desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16; - desc.element_size = 4; - desc.size = float4_size * elements_count; - desc.data = std::move(data); - args_.AddObject("weights", - absl::make_unique(std::move(desc))); - } else { - int sub_size = float4_size * elements_count / 4; - Texture2DDescriptor desc0; - desc0.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16; - desc0.size = int2(texture_width, texture_height); - desc0.data.resize(sub_size); - memcpy(desc0.data.data(), data.data(), sub_size); - args_.AddObject("weights0", - absl::make_unique(std::move(desc0))); - - Texture2DDescriptor desc1; - desc1.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16; - desc1.size = int2(texture_width, texture_height); - desc1.data.resize(sub_size); - memcpy(desc1.data.data(), data.data() + sub_size, sub_size); - args_.AddObject("weights1", - absl::make_unique(std::move(desc1))); - - Texture2DDescriptor desc2; - desc2.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16; - desc2.size = int2(texture_width, texture_height); - desc2.data.resize(sub_size); - memcpy(desc2.data.data(), data.data() + sub_size * 2, sub_size); - args_.AddObject("weights2", - absl::make_unique(std::move(desc2))); - - Texture2DDescriptor desc3; - desc3.element_type = f32_weights ? 
DataType::FLOAT32 : DataType::FLOAT16; - desc3.size = int2(texture_width, texture_height); - desc3.data.resize(sub_size); - memcpy(desc3.data.data(), data.data() + sub_size * 3, sub_size); - args_.AddObject("weights3", - absl::make_unique(std::move(desc3))); - } -} - -template -void Conv3D::RearrangeWeightsData(const tflite::gpu::Tensor& weights, - absl::Span dst) { - const int block_size = conv_params_.block_size.w; - const int dst_slices = - AlignByN(DivideRoundUp(weights.shape.o, 4), block_size); - const int src_slices = DivideRoundUp(weights.shape.i, 4); - const int kernel_x = kernel_size_.x; - const int kernel_y = kernel_size_.y; - const int kernel_z = kernel_size_.z; - const int texture_width = dst_slices; - const int texture_height = src_slices * kernel_x * kernel_y * kernel_z; - - int counter = 0; - for (int d = 0; d < dst_slices / block_size; ++d) { - for (int z = 0; z < kernel_z; ++z) { - for (int y = 0; y < kernel_y; ++y) { - for (int x = 0; x < kernel_x; ++x) { - for (int s = 0; s < src_slices; ++s) { - for (int sub_d = 0; sub_d < block_size; ++sub_d) { - T filters[4]; - for (int i = 0; i < 4; ++i) { - for (int j = 0; j < 4; ++j) { - const int s_ch = s * 4 + j; - const int d_ch = (d * block_size + sub_d) * 4 + i; - if (s_ch < weights.shape.i && d_ch < weights.shape.o) { - const int f_index = - weights.shape.LinearIndex({d_ch, y, x, z, s_ch}); - filters[j][i] = weights.data[f_index]; - } else { - filters[j][i] = 0.0f; - } - } - } - if (conv_params_.AreWeightsBuffer()) { - dst[counter++] = filters[0]; - dst[counter++] = filters[1]; - dst[counter++] = filters[2]; - dst[counter++] = filters[3]; - } else { - int x_coord = d * block_size + sub_d; - int y_coord = - ((z * kernel_y + y) * kernel_x + x) * src_slices + s; - int offset = y_coord * dst_slices + x_coord; - dst[offset + texture_width * texture_height * 0] = filters[0]; - dst[offset + texture_width * texture_height * 1] = filters[1]; - dst[offset + texture_width * texture_height * 2] = filters[2]; - dst[offset + texture_width * texture_height * 3] = filters[3]; - } - } - } - } - } - } - } -} - -Conv3D CreateConv3D(const DeviceInfo& device_info, - const OperationDef& definition, - const Convolution3DAttributes& attr); - -} // namespace cl -} // namespace gpu -} // namespace tflite - -#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_CONV_3D_H_ diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.cc b/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.cc index dc54286c0fc..c3663634177 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.cc @@ -45,84 +45,29 @@ int GetOptimalMaxConstantSize(const DeviceInfo& info) { return GetAdrenoOptimalMaxConstantSize(info.adreno_info.gpu_version); } } -} // namespace -ConvConstants::ConvConstants(const OperationDef& definition, - const Convolution2DAttributes& attr, - const DeviceInfo& device_info) - : GPUOperation(definition), - kernel_size_(attr.weights.shape.w, attr.weights.shape.h), - stride_(attr.strides.w, attr.strides.h), - padding_(-attr.padding.prepended.w, -attr.padding.prepended.h), - dilation_(attr.dilations.w, attr.dilations.h), - src_channels_(attr.weights.shape.i), - dst_channels_(attr.weights.shape.o) { - const bool stride_correction = - definition_.IsBatchSupported() && stride_.x != 1; - code_ = - GenerateConvolutionConstantCode(definition_, kernel_size_, src_channels_, - dst_channels_, stride_correction); - if (definition_.precision == CalculationsPrecision::F16 && - 
device_info.IsAdreno3xx()) { - compiler_options_.push_back(CompilerOptions::ADRENO_FULL_SIMD_LINE); - } - if (definition_.precision != CalculationsPrecision::F32 && - device_info.IsPowerVR()) { - // BUG, some PowerVRs (GE8320) produce incorrect result without it - compiler_options_.push_back(CompilerOptions::CL_OPT_DISABLE); - } -} - -ConvConstants::ConvConstants(ConvConstants&& kernel) - : GPUOperation(std::move(kernel)), - kernel_size_(kernel.kernel_size_), - stride_(kernel.stride_), - padding_(kernel.padding_), - dilation_(kernel.dilation_), - src_channels_(kernel.src_channels_), - dst_channels_(kernel.dst_channels_) {} - -ConvConstants& ConvConstants::operator=(ConvConstants&& kernel) { - if (this != &kernel) { - std::swap(kernel_size_, kernel.kernel_size_); - std::swap(stride_, kernel.stride_); - std::swap(padding_, kernel.padding_); - std::swap(dilation_, kernel.dilation_); - std::swap(src_channels_, kernel.src_channels_); - std::swap(dst_channels_, kernel.dst_channels_); - GPUOperation::operator=(std::move(kernel)); - } - return *this; -} - -std::string ConvConstants::GenerateConvolutionConstantCode( - const OperationDef& op_def, const int2& kernel_size, int src_channels, - int dst_channels, bool stride_correction) { +std::string GenerateConvolutionConstantCode(const OperationDef& op_def, + const OHWI& weights_shape, + bool stride_correction, + GPUOperation* op) { auto src_desc = op_def.src_tensors[0]; src_desc.SetTextureAddressMode(TextureAddressMode::ZERO); if (op_def.IsBatchSupported()) { src_desc.SetStateVar("BatchedWidth", "true"); } - AddSrcTensor("src_tensor", src_desc); + op->AddSrcTensor("src_tensor", src_desc); auto dst_desc = op_def.dst_tensors[0]; if (op_def.IsBatchSupported()) { dst_desc.SetStateVar("BatchedWidth", "true"); } - AddDstTensor("dst_tensor", dst_desc); - - args_.AddInt("stride_x"); - args_.AddInt("stride_y"); - args_.AddInt("padding_x"); - args_.AddInt("padding_y"); - args_.AddInt("dilation_x"); - args_.AddInt("dilation_y"); + op->AddDstTensor("dst_tensor", dst_desc); std::string c = GetCommonDefines(op_def.precision); - const int out_z = DivideRoundUp(dst_channels, 4); + const int out_z = DivideRoundUp(weights_shape.o, 4); const std::string kOutZ = std::to_string(out_z); - const int src_depth = DivideRoundUp(src_channels, 4); + const int src_depth = DivideRoundUp(weights_shape.i, 4); const auto src_tensor_type = op_def.src_tensors[0].storage_type; const bool manual_clamp = src_tensor_type == TensorStorageType::BUFFER || @@ -176,11 +121,16 @@ std::string ConvConstants::GenerateConvolutionConstantCode( "return;\n"; if (stride_correction) { c += " int start_x = " + - GetXStrideCorrected("X", "args.src_tensor.Batch()", "args.stride_x", - "args.padding_x") + + GetXStrideCorrectedV2("X", "args.src_tensor.Batch()", "args.stride_x", + "args.padding_x") + ";\n"; } else { - c += " int start_x = X * args.stride_x + args.padding_x;\n"; + if (op_def.IsBatchSupported()) { + c += " int start_x = X * args.stride_x + args.padding_x * " + "args.src_tensor.Batch();\n"; + } else { + c += " int start_x = X * args.stride_x + args.padding_x;\n"; + } } c += " int start_y = Y * args.stride_y + args.padding_y;\n"; c += " ACCUM_FLT4 r[" + kOutZ + "];\n"; @@ -189,22 +139,25 @@ std::string ConvConstants::GenerateConvolutionConstantCode( c += " }\n"; int filters_counter = 0; for (int s = 0; s < src_depth; ++s) { - const int ch_count = std::min(4, src_channels - s * 4); + const int ch_count = std::min(4, weights_shape.i - s * 4); const std::string s_conv = "CONV" + 
std::to_string(ch_count); const std::string s_count = ch_count == 1 ? "" : std::to_string(ch_count); const std::string s_type = absl::StrCat("FLT", s_count); const std::string s_postfix = postfixes[ch_count - 1]; - for (int ky = 0; ky < kernel_size.y; ++ky) { + const std::string dilation_x = + op_def.IsBatchSupported() ? "args.dilation_x * args.src_tensor.Batch()" + : "args.dilation_x"; + for (int ky = 0; ky < weights_shape.h; ++ky) { std::string s_y = absl::StrCat("(start_y + ", ky, " * args.dilation_y)"); if (manual_clamp) { c += " {\n"; c += " bool y_out = " + s_y + " < 0 || " + s_y + " >= args.src_tensor.Height();\n"; } - for (int kx = 0; kx < kernel_size.x; ++kx) { + for (int kx = 0; kx < weights_shape.w; ++kx) { c += " {\n"; std::string s_x = - absl::StrCat("(start_x + ", kx, " * args.dilation_x)"); + absl::StrCat("(start_x + ", kx, " * " + dilation_x + ")"); if (manual_clamp) { c += " bool x_out = " + s_x + "< 0 || " + s_x + ">= args.src_tensor.Width();\n"; @@ -240,20 +193,7 @@ std::string ConvConstants::GenerateConvolutionConstantCode( return c; } -absl::Status ConvConstants::BindArguments() { - RETURN_IF_ERROR(args_.SetInt("stride_x", stride_.x)); - RETURN_IF_ERROR(args_.SetInt("stride_y", stride_.y)); - RETURN_IF_ERROR(args_.SetInt("padding_x", padding_.x * src_[0]->Batch())); - RETURN_IF_ERROR(args_.SetInt("padding_y", padding_.y)); - RETURN_IF_ERROR(args_.SetInt("dilation_x", dilation_.x * src_[0]->Batch())); - return args_.SetInt("dilation_y", dilation_.y); -} - -int3 ConvConstants::GetGridSize() const { - const int grid_x = dst_[0]->Width() * dst_[0]->Batch(); - const int grid_y = dst_[0]->Height(); - return int3(grid_x, grid_y, 1); -} +} // namespace bool IsConvConstantsSupported(const DeviceInfo& device_info, const OperationDef& definition, @@ -277,20 +217,41 @@ bool IsConvConstantsSupported(const DeviceInfo& device_info, return filters_buffer_size <= kConstantMaxSize && flt4_registers <= 8; } -ConvConstants CreateConvConstants(const DeviceInfo& device_info, - const OperationDef& definition, - const Convolution2DAttributes& attr) { - ConvConstants result(definition, attr, device_info); - result.UploadWeights(attr.weights); +GPUOperation CreateConvConstants(const DeviceInfo& device_info, + const OperationDef& definition, + const Convolution2DAttributes& attr) { + GPUOperation op(definition); + UploadWeightsForConvConstants(attr.weights, definition.precision, &op); + op.args_.AddInt("stride_x", attr.strides.w); + op.args_.AddInt("stride_y", attr.strides.h); + op.args_.AddInt("padding_x", -attr.padding.prepended.w); + op.args_.AddInt("padding_y", -attr.padding.prepended.h); + op.args_.AddInt("dilation_x", attr.dilations.w); + op.args_.AddInt("dilation_y", attr.dilations.h); + op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_ZIs1; + + const bool stride_correction = + definition.IsBatchSupported() && attr.strides.w != 1; + op.code_ = GenerateConvolutionConstantCode(definition, attr.weights.shape, + stride_correction, &op); + if (definition.precision == CalculationsPrecision::F16 && + device_info.IsAdreno3xx()) { + op.compiler_options_.push_back(CompilerOptions::ADRENO_FULL_SIMD_LINE); + } + if (definition.precision != CalculationsPrecision::F32 && + device_info.IsPowerVR()) { + // BUG, some PowerVRs (GE8320) produce incorrect result without it + op.compiler_options_.push_back(CompilerOptions::CL_OPT_DISABLE); + } TensorLinearDescriptor desc; desc.storage_type = LinearStorageType::BUFFER; desc.element_type = definition.GetDataType(); desc.memory_type = MemoryType::CONSTANT; 
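When BatchedWidth is set, the generated code above folds the batch dimension into X, which is why padding_x and dilation_x get scaled by the batch size. A small host-side sketch of that index math, under the assumption of a fused width*batch axis; SrcXForBatchedWidth is an illustrative helper, not part of the patch.

int SrcXForBatchedWidth(int dst_x, int stride_x, int padding_x, int dilation_x,
                        int kx, int batch) {
  // One logical pixel step spans `batch` elements on the fused axis, so the
  // padding and dilation offsets are pre-multiplied by the batch size,
  // exactly as the generated OpenCL does.
  const int start_x = dst_x * stride_x + padding_x * batch;
  return start_x + kx * dilation_x * batch;
}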
desc.UploadLinearData(attr.bias); - result.args_.AddObject( + op.args_.AddObject( "biases", absl::make_unique(std::move(desc))); - return result; + return op; } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.h b/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.h index 5be433588ce..c341ecb5753 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.h @@ -32,78 +32,8 @@ namespace tflite { namespace gpu { namespace cl { -class ConvConstants : public GPUOperation { - public: - ConvConstants() = default; - absl::Status BindArguments() override; - int3 GetGridSize() const override; - - // Move only - ConvConstants(ConvConstants&& kernel); - ConvConstants& operator=(ConvConstants&& kernel); - ConvConstants(const ConvConstants&) = delete; - ConvConstants& operator=(const ConvConstants&) = delete; - - private: - friend ConvConstants CreateConvConstants(const DeviceInfo& device_info, - const OperationDef& definition, - const Convolution2DAttributes& attr); - ConvConstants(const OperationDef& definition, - const Convolution2DAttributes& attr, - const DeviceInfo& device_info); - - template - void UploadWeights(const tflite::gpu::Tensor& weights); - - template - void RearrangeWeightsData(const tflite::gpu::Tensor& weights, - absl::Span dst); - - std::string GenerateConvolutionConstantCode(const OperationDef& op_def, - const int2& kernel_size, - int src_channels, - int dst_channels, - bool stride_correction); - - int2 kernel_size_; - int2 stride_; - int2 padding_; - int2 dilation_; - int src_channels_; - int dst_channels_; -}; - -template -void ConvConstants::UploadWeights(const tflite::gpu::Tensor& weights) { - const int dst_depth = DivideRoundUp(weights.shape.o, 4); - const int kernel_x = weights.shape.w; - const int kernel_y = weights.shape.h; - - const bool f32_weights = definition_.precision == CalculationsPrecision::F32; - const int float_size = f32_weights ? 4 : 2; - const int float_count = src_channels_ * dst_depth * 4 * kernel_x * kernel_y; - - BufferDescriptor desc; - desc.element_type = f32_weights ? 
DataType::FLOAT32 : DataType::FLOAT16; - desc.element_size = 4; - desc.memory_type = MemoryType::CONSTANT; - desc.size = float_size * float_count; - desc.data.resize(desc.size); - - if (f32_weights) { - float4* ptr = reinterpret_cast(desc.data.data()); - RearrangeWeightsData(weights, absl::MakeSpan(ptr, float_count / 4)); - } else { - half4* ptr = reinterpret_cast(desc.data.data()); - RearrangeWeightsData(weights, absl::MakeSpan(ptr, float_count / 4)); - } - - args_.AddObject("weigths", - absl::make_unique(std::move(desc))); -} - template -void ConvConstants::RearrangeWeightsData( +void RearrangeWeightsForConvConstants( const tflite::gpu::Tensor& weights, absl::Span dst) { const int dst_depth = DivideRoundUp(weights.shape.o, 4); const int src_depth = DivideRoundUp(weights.shape.i, 4); @@ -115,7 +45,7 @@ void ConvConstants::RearrangeWeightsData( for (int y = 0; y < kernel_y; ++y) { for (int x = 0; x < kernel_x; ++x) { for (int d = 0; d < dst_depth; ++d) { - const int channels_count = std::min(4, src_channels_ - s * 4); + const int channels_count = std::min(4, weights.shape.i - s * 4); T filters[4]; for (int i = 0; i < 4; ++i) { for (int j = 0; j < channels_count; ++j) { @@ -145,13 +75,46 @@ void ConvConstants::RearrangeWeightsData( } } +template +void UploadWeightsForConvConstants(const tflite::gpu::Tensor& weights, + CalculationsPrecision precision, + GPUOperation* op) { + const int dst_depth = DivideRoundUp(weights.shape.o, 4); + const int kernel_x = weights.shape.w; + const int kernel_y = weights.shape.h; + + const bool f32_weights = precision == CalculationsPrecision::F32; + const int float_size = f32_weights ? 4 : 2; + const int float_count = weights.shape.i * dst_depth * 4 * kernel_x * kernel_y; + + BufferDescriptor desc; + desc.element_type = f32_weights ? 
DataType::FLOAT32 : DataType::FLOAT16; + desc.element_size = 4; + desc.memory_type = MemoryType::CONSTANT; + desc.size = float_size * float_count; + desc.data.resize(desc.size); + + if (f32_weights) { + float4* ptr = reinterpret_cast(desc.data.data()); + RearrangeWeightsForConvConstants(weights, + absl::MakeSpan(ptr, float_count / 4)); + } else { + half4* ptr = reinterpret_cast(desc.data.data()); + RearrangeWeightsForConvConstants(weights, + absl::MakeSpan(ptr, float_count / 4)); + } + + op->args_.AddObject("weigths", + absl::make_unique(std::move(desc))); +} + bool IsConvConstantsSupported(const DeviceInfo& device_info, const OperationDef& definition, const Convolution2DAttributes& attr); -ConvConstants CreateConvConstants(const DeviceInfo& device_info, - const OperationDef& definition, - const Convolution2DAttributes& attr); +GPUOperation CreateConvConstants(const DeviceInfo& device_info, + const OperationDef& definition, + const Convolution2DAttributes& attr); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants_test.cc index 4aa60b8d334..17821e14e0a 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants_test.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants_test.cc @@ -55,7 +55,7 @@ TEST_F(OpenCLOperationTest, ConvConstantsSimpleWeights) { op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); TensorFloat32 dst_tensor; - ConvConstants operation = + GPUOperation operation = CreateConvConstants(creation_context_.GetDeviceInfo(), op_def, attr); ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation, BHWC(1, 2, 2, 1), &dst_tensor)); @@ -90,7 +90,7 @@ TEST_F(OpenCLOperationTest, ConvConstants) { op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); TensorFloat32 dst_tensor; - ConvConstants operation = + GPUOperation operation = CreateConvConstants(creation_context_.GetDeviceInfo(), op_def, attr); ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation, BHWC(1, 2, 2, 2), &dst_tensor)); diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.cc b/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.cc index bd4f6d70994..8952504bda0 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.cc @@ -19,11 +19,13 @@ limitations under the License. 
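For context on the conv_constants refactor above: the weights stay eligible for __constant memory only while their packed size fits the device limit that IsConvConstantsSupported checks. A rough sketch of that footprint calculation, assuming the same packing as UploadWeightsForConvConstants; FitsInConstantMemory and its constant_max_size parameter are illustrative placeholders, not part of the patch.

bool FitsInConstantMemory(int src_channels, int dst_channels, int kernel_w,
                          int kernel_h, bool f32_weights, int constant_max_size) {
  // Weights are packed as 4-wide groups per destination slice and kernel tap,
  // in fp32 (4 bytes) or fp16 (2 bytes), mirroring float_count above.
  const int dst_depth = (dst_channels + 3) / 4;
  const int float_size = f32_weights ? 4 : 2;
  const int float_count = src_channels * dst_depth * 4 * kernel_w * kernel_h;
  return float_size * float_count <= constant_max_size;
}

The real IsConvConstantsSupported additionally bounds the number of FLT4 registers the generated kernel would need.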
#include #include +#include "absl/strings/substitute.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" #include "tensorflow/lite/delegates/gpu/cl/precision.h" #include "tensorflow/lite/delegates/gpu/cl/tensor_type.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" namespace tflite { @@ -70,57 +72,77 @@ std::string GenerateAsyncUpload(const std::string& local_ptr_name, return c; } -std::string GenerateBlockCoords(const int3& block_size, +std::string GenerateBlockCoords(const int4& block_size, const int3& work_group_launch_order, - bool linear_hw) { + bool linear_spatial, bool need_depth) { std::string c; int3 launch_remap; launch_remap[work_group_launch_order.x] = 0; launch_remap[work_group_launch_order.y] = 1; launch_remap[work_group_launch_order.z] = 2; - if (linear_hw) { + if (linear_spatial) { if (work_group_launch_order[0] == 0) { - c += " int linear_hw = get_global_id(0);\n"; + c += " int linear_spatial = get_global_id(0);\n"; } else { - c += " int linear_hw = get_group_id(" + std::to_string(launch_remap[0]) + + c += " int linear_spatial = get_group_id(" + + std::to_string(launch_remap[0]) + ") * get_local_size(0) + get_local_id(0);\n"; } - c += " int Y = (linear_hw / args.task_size_x) * " + - std::to_string(block_size.y) + ";\n"; - c += " int X = (linear_hw % args.task_size_x) * " + - std::to_string(block_size.x) + ";\n"; - if (work_group_launch_order[1] == 1) { - c += " int Z = get_global_id(1) * " + std::to_string(block_size.z) + - ";\n"; - } else { - c += " int Z = (get_group_id(" + std::to_string(launch_remap[1]) + - ") * get_local_size(1) + get_local_id(1)) * " + + if (need_depth) { + c += " int DST_X = (linear_spatial % args.task_size_x) * " + + std::to_string(block_size.x) + ";\n"; + c += " linear_spatial = linear_spatial / args.task_size_x;\n"; + c += " int DST_Y = (linear_spatial % args.task_size_y) * " + + std::to_string(block_size.y) + ";\n"; + c += " int DST_Z = (linear_spatial / args.task_size_y) * " + std::to_string(block_size.z) + ";\n"; - } - } else { - if (work_group_launch_order[0] == 0) { - c += " int X = get_global_id(0) * " + std::to_string(block_size.x) + - ";\n"; } else { - c += " int X = (get_group_id(" + std::to_string(launch_remap[0]) + - ") * get_local_size(0) + get_local_id(0)) * " + + c += " int DST_Y = (linear_spatial / args.task_size_x) * " + + std::to_string(block_size.y) + ";\n"; + c += " int DST_X = (linear_spatial % args.task_size_x) * " + std::to_string(block_size.x) + ";\n"; } if (work_group_launch_order[1] == 1) { - c += " int Y = get_global_id(1) * " + std::to_string(block_size.y) + + c += " int DST_S = get_global_id(1) * " + std::to_string(block_size.w) + ";\n"; } else { - c += " int Y = (get_group_id(" + std::to_string(launch_remap[1]) + + c += " int DST_S = (get_group_id(" + std::to_string(launch_remap[1]) + ") * get_local_size(1) + get_local_id(1)) * " + + std::to_string(block_size.w) + ";\n"; + } + } else { + if (work_group_launch_order[0] == 0) { + c += " int DST_X = get_global_id(0) * " + std::to_string(block_size.x) + + ";\n"; + } else { + c += " int DST_X = (get_group_id(" + std::to_string(launch_remap[0]) + + ") * get_local_size(0) + get_local_id(0)) * " + + std::to_string(block_size.x) + ";\n"; + } + std::string global_id_1; + if (work_group_launch_order[1] == 1) { + global_id_1 = "get_global_id(1)"; + } else { + global_id_1 = 
"(get_group_id(" + std::to_string(launch_remap[1]) + + ") * get_local_size(1) + get_local_id(1))"; + } + if (need_depth) { + c += " int linear_id_1 = " + global_id_1 + ";\n"; + c += " int DST_Z = (linear_id_1 / args.task_size_y) * " + + std::to_string(block_size.z) + ";\n"; + c += " int DST_Y = (linear_id_1 % args.task_size_y) * " + + std::to_string(block_size.y) + ";\n"; + } else { + c += " int DST_Y = " + global_id_1 + " * " + std::to_string(block_size.y) + ";\n"; } if (work_group_launch_order[2] == 2) { - c += " int Z = get_global_id(2) * " + std::to_string(block_size.z) + + c += " int DST_S = get_global_id(2) * " + std::to_string(block_size.w) + ";\n"; } else { - c += " int Z = (get_group_id(" + std::to_string(launch_remap[2]) + + c += " int DST_S = (get_group_id(" + std::to_string(launch_remap[2]) + ") * get_local_size(2) + get_local_id(2)) * " + - std::to_string(block_size.z) + ";\n"; + std::to_string(block_size.w) + ";\n"; } } @@ -132,10 +154,10 @@ ConvPowerVR::ConvPowerVR(const OperationDef& definition, const Convolution2DAttributes& attr, const DeviceInfo& device_info, const BHWC* dst_shape) : GPUOperation(definition), - stride_padding_(attr.strides.w, attr.strides.h, -attr.padding.prepended.w, - -attr.padding.prepended.h), - kernel_dilation_(attr.weights.shape.w, attr.weights.shape.h, - attr.dilations.w, attr.dilations.h), + stride_(attr.strides.w, attr.strides.h, 1, 1), + padding_(-attr.padding.prepended.w, -attr.padding.prepended.h, 0, 0), + kernel_size_(attr.weights.shape.w, attr.weights.shape.h, 1, 1), + dilation_(attr.dilations.w, attr.dilations.h, 1, 1), conv_params_(GuessBestParams(device_info, definition, attr, dst_shape)) {} ConvPowerVR::ConvPowerVR(const OperationDef& definition, @@ -143,10 +165,10 @@ ConvPowerVR::ConvPowerVR(const OperationDef& definition, const BHWC& weights_shape, const DeviceInfo& device_info, const BHWC* dst_shape) : GPUOperation(definition), - stride_padding_(attr.strides.w, attr.strides.h, -attr.padding.prepended.w, - -attr.padding.prepended.h), - kernel_dilation_(weights_shape.w, weights_shape.h, attr.dilations.w, - attr.dilations.h), + stride_(attr.strides.w, attr.strides.h, 1, 1), + padding_(-attr.padding.prepended.w, -attr.padding.prepended.h, 0, 0), + kernel_size_(weights_shape.w, weights_shape.h, 1, 1), + dilation_(attr.dilations.w, attr.dilations.h, 1, 1), conv_params_(GuessBestParams(device_info, definition, attr, weights_shape, dst_shape)) {} @@ -154,25 +176,45 @@ ConvPowerVR::ConvPowerVR(const OperationDef& definition, const FullyConnectedAttributes& attr, const DeviceInfo& device_info, const BHWC* dst_shape) : GPUOperation(definition), - stride_padding_(1, 1, 0, 0), - kernel_dilation_(1, 1, 1, 1), + stride_(1, 1, 1, 1), + padding_(0, 0, 0, 0), + kernel_size_(1, 1, 1, 1), + dilation_(1, 1, 1, 1), conv_params_(GuessBestParams(device_info, definition, attr, dst_shape)) {} ConvPowerVR::ConvPowerVR(const OperationDef& definition) : GPUOperation(definition), - stride_padding_(1, 1, 0, 0), - kernel_dilation_(1, 1, 1, 1) {} + stride_(1, 1, 1, 1), + padding_(0, 0, 0, 0), + kernel_size_(1, 1, 1, 1), + dilation_(1, 1, 1, 1) {} ConvPowerVR::ConvPowerVR(ConvPowerVR&& operation) : GPUOperation(std::move(operation)), - stride_padding_(operation.stride_padding_), - kernel_dilation_(operation.kernel_dilation_), + stride_(operation.stride_), + padding_(operation.padding_), + kernel_size_(operation.kernel_size_), + dilation_(operation.dilation_), conv_params_(operation.conv_params_) {} +ConvPowerVR::ConvPowerVR(const OperationDef& definition, + const 
Convolution3DAttributes& attr, + const DeviceInfo& device_info, const BHWDC* dst_shape) + : GPUOperation(definition), + stride_(attr.strides.w, attr.strides.h, attr.strides.d, 1), + padding_(-attr.padding.prepended.w, -attr.padding.prepended.h, + -attr.padding.prepended.d, 0), + kernel_size_(attr.weights.shape.w, attr.weights.shape.h, + attr.weights.shape.d, 1), + dilation_(attr.dilations.w, attr.dilations.h, attr.dilations.d, 1), + conv_params_(GuessBestParams(device_info, definition, attr, dst_shape)) {} + ConvPowerVR& ConvPowerVR::operator=(ConvPowerVR&& operation) { if (this != &operation) { - std::swap(stride_padding_, operation.stride_padding_); - std::swap(kernel_dilation_, operation.kernel_dilation_); + std::swap(stride_, operation.stride_); + std::swap(padding_, operation.padding_); + std::swap(kernel_size_, operation.kernel_size_); + std::swap(dilation_, operation.dilation_); std::swap(conv_params_, operation.conv_params_); GPUOperation::operator=(std::move(operation)); } @@ -180,63 +222,88 @@ ConvPowerVR& ConvPowerVR::operator=(ConvPowerVR&& operation) { } void ConvPowerVR::GenerateCode(const DeviceInfo& device_info) { + if (conv_params_.linear_spatial) { + grid_dimension_ = 2; + } const bool stride_correction = - definition_.IsBatchSupported() && stride_padding_.x != 1; + definition_.IsBatchSupported() && stride_.x != 1; code_ = GenerateConv(device_info, definition_, stride_correction, conv_params_); if (definition_.precision == CalculationsPrecision::F16 && device_info.IsPowerVR()) { compiler_options_.push_back(CompilerOptions::POWERVR_FP16); } - if (conv_params_.IsPrivateMemBroadcast()) { + if (conv_params_.IsPrivateMemBroadcast() && device_info.IsCL20OrHigher()) { compiler_options_.push_back(CompilerOptions::CL_2_0); } + bool kernel_is_trivial = + conv_params_.x_kernel_is_1 && conv_params_.y_kernel_is_1; + if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) { + kernel_is_trivial = kernel_is_trivial & conv_params_.z_kernel_is_1; + } + if (device_info.IsAdreno3xx() && + definition_.precision == CalculationsPrecision::F16 && + kernel_is_trivial) { + compiler_options_.push_back(CompilerOptions::ADRENO_FULL_SIMD_LINE); + } } -absl::Status ConvPowerVR::BindArguments() { - if (!conv_params_.x_kernel_is_1 || !conv_params_.y_kernel_is_1) { - RETURN_IF_ERROR(args_.SetInt("stride_x", stride_padding_.x)); - RETURN_IF_ERROR(args_.SetInt("stride_y", stride_padding_.y)); - RETURN_IF_ERROR( - args_.SetInt("padding_x", stride_padding_.z * src_[0]->Batch())); - RETURN_IF_ERROR(args_.SetInt("padding_y", stride_padding_.w)); - RETURN_IF_ERROR(args_.SetInt("kernel_size_x", kernel_dilation_.x)); - RETURN_IF_ERROR(args_.SetInt("kernel_size_y", kernel_dilation_.y)); - RETURN_IF_ERROR( - args_.SetInt("dilation_x", kernel_dilation_.z * src_[0]->Batch())); - RETURN_IF_ERROR(args_.SetInt("dilation_y", kernel_dilation_.w)); +absl::Status ConvPowerVR::BindArguments(ArgumentsBinder* args) { + if (!conv_params_.x_kernel_is_1) { + RETURN_IF_ERROR(args->SetInt("stride_x", stride_.x)); + RETURN_IF_ERROR(args->SetInt("padding_x", padding_.x * src_[0]->Batch())); + RETURN_IF_ERROR(args->SetInt("kernel_size_x", kernel_size_.x)); + RETURN_IF_ERROR(args->SetInt("dilation_x", dilation_.x * src_[0]->Batch())); } - if (conv_params_.linear_hw) { + if (!conv_params_.y_kernel_is_1) { + RETURN_IF_ERROR(args->SetInt("stride_y", stride_.y)); + RETURN_IF_ERROR(args->SetInt("padding_y", padding_.y)); + RETURN_IF_ERROR(args->SetInt("kernel_size_y", kernel_size_.y)); + RETURN_IF_ERROR(args->SetInt("dilation_y", dilation_.y)); 
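The task_size_x / task_size_y integers bound here exist so the kernel can undo the spatial flattening performed by GenerateBlockCoords when linear_spatial is enabled. A host-side sketch of that unpacking for the DEPTH case, ignoring the per-axis block_size scaling the generated code applies afterwards; UnpackLinearSpatial is an illustrative name.

void UnpackLinearSpatial(int linear_spatial, int task_size_x, int task_size_y,
                         int* x, int* y, int* z) {
  // Mirrors the generated OpenCL: X varies fastest, then Y, then Z.
  *x = linear_spatial % task_size_x;
  linear_spatial /= task_size_x;
  *y = linear_spatial % task_size_y;
  *z = linear_spatial / task_size_y;
}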
+ } + if (definition_.src_tensors[0].HasAxis(Axis::DEPTH) && + !conv_params_.z_kernel_is_1) { + RETURN_IF_ERROR(args->SetInt("stride_z", stride_.z)); + RETURN_IF_ERROR(args->SetInt("padding_z", padding_.z)); + RETURN_IF_ERROR(args->SetInt("kernel_size_z", kernel_size_.z)); + RETURN_IF_ERROR(args->SetInt("dilation_z", dilation_.z)); + } + if (conv_params_.linear_spatial) { const int grid_x = DivideRoundUp(dst_[0]->Width() * dst_[0]->Batch(), conv_params_.block_size.x); - RETURN_IF_ERROR(args_.SetInt("task_size_x", grid_x)); + RETURN_IF_ERROR(args->SetInt("task_size_x", grid_x)); + } + if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) { + const int task_size_y = + DivideRoundUp(dst_[0]->Height(), conv_params_.block_size.y); + RETURN_IF_ERROR(args->SetInt("task_size_y", task_size_y)); } return absl::OkStatus(); } int3 ConvPowerVR::GetGridSize() const { - const int grid_x = DivideRoundUp(dst_[0]->Width() * dst_[0]->Batch(), - conv_params_.block_size.x); - const int grid_y = + const int task_size_x = DivideRoundUp(dst_[0]->Width() * dst_[0]->Batch(), + conv_params_.block_size.x); + const int task_size_y = DivideRoundUp(dst_[0]->Height(), conv_params_.block_size.y); - const int grid_z = - DivideRoundUp(dst_[0]->Slices(), conv_params_.block_size.z); + const int task_size_z = + DivideRoundUp(dst_[0]->Depth(), conv_params_.block_size.z); + const int task_size_s = + DivideRoundUp(dst_[0]->Slices(), conv_params_.block_size.w); int3 wg; - if (conv_params_.linear_hw) { - wg.x = DivideRoundUp(grid_x * grid_y, work_group_size_.x); - wg.y = DivideRoundUp(grid_z, work_group_size_.y); - return int3( - wg[conv_params_.work_group_launch_order[0]] * work_group_size_.x, - wg[conv_params_.work_group_launch_order[1]] * work_group_size_.y, 1); + if (conv_params_.linear_spatial) { + int grid_x = task_size_x * task_size_y; + if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) { + grid_x *= task_size_z; + } + return int3(grid_x, task_size_s, 1); } else { - wg.x = DivideRoundUp(grid_x, work_group_size_.x); - wg.y = DivideRoundUp(grid_y, work_group_size_.y); - wg.z = DivideRoundUp(grid_z, work_group_size_.z); - return int3( - wg[conv_params_.work_group_launch_order[0]] * work_group_size_.x, - wg[conv_params_.work_group_launch_order[1]] * work_group_size_.y, - wg[conv_params_.work_group_launch_order[2]] * work_group_size_.z); + int grid_y = task_size_y; + if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) { + grid_y *= task_size_z; + } + return int3(task_size_x, grid_y, task_size_s); } } @@ -251,14 +318,8 @@ void ConvPowerVR::GetPossibleKernelWorkGroups( work_groups->push_back(work_group_size_); return; } - if (conv_params_.work_group_launch_order[0] == 0 && - conv_params_.work_group_launch_order[1] == 1 && - conv_params_.work_group_launch_order[2] == 2) { - GetPossibleWorkGroupsConv(tuning_type, device_info, kernel_info, grid_size_, - work_groups); - } else { - work_groups->push_back(work_group_size_); - } + GetPossibleWorkGroupsConv(tuning_type, device_info, kernel_info, grid_size_, + work_groups); } std::string ConvPowerVR::GenerateConv(const DeviceInfo& device_info, @@ -284,31 +345,80 @@ std::string ConvPowerVR::GenerateConv(const DeviceInfo& device_info, AddSrcBuffer("weights", desc); } + const auto& src_def = op_def.src_tensors[0]; + + auto generate_id = [&](const std::string& x, const std::string& y, + const std::string& z) { + std::string id; + if (src_def.HasAxis(Axis::WIDTH)) { + id += "_w" + x; + } + if (src_def.HasAxis(Axis::HEIGHT)) { + id += "_h" + y; + } + if (src_def.HasAxis(Axis::DEPTH)) { + id += 
"_d" + z; + } + return id; + }; + + auto generate_id_full = [&](const std::string& x, const std::string& y, + const std::string& z, const std::string& s) { + return generate_id(x, y, z) + "_s" + s; + }; + + auto generate_check = [&](const std::string& x, const std::string& y, + const std::string& z) { + std::string check; + const std::vector axes{Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH}; + const std::vector names{"in_x", "in_y", "in_z"}; + const std::vector is_1{conv_params_.x_kernel_is_1, + conv_params_.y_kernel_is_1, + conv_params_.z_kernel_is_1}; + const std::vector coords{x, y, z}; + for (int i = 0; i < axes.size(); ++i) { + const auto& axis = axes[i]; + if (src_def.HasAxis(axis) && !src_def.SupportsZeroClamp(axis) && + !is_1[i]) { + if (!check.empty()) { + check += " && "; + } + check += names[i] + coords[i]; + } + } + return check; + }; + auto dst_desc = op_def.dst_tensors[0]; if (op_def.IsBatchSupported()) { dst_desc.SetStateVar("BatchedWidth", "true"); } AddDstTensor("dst_tensor", dst_desc); - const bool is1x1 = conv_params_.x_kernel_is_1 && conv_params_.y_kernel_is_1; - if (!is1x1) { + if (!conv_params_.x_kernel_is_1) { args_.AddInt("stride_x"); - args_.AddInt("stride_y"); args_.AddInt("padding_x"); - args_.AddInt("padding_y"); args_.AddInt("kernel_size_x"); - args_.AddInt("kernel_size_y"); args_.AddInt("dilation_x"); + } + if (!conv_params_.y_kernel_is_1) { + args_.AddInt("stride_y"); + args_.AddInt("padding_y"); + args_.AddInt("kernel_size_y"); args_.AddInt("dilation_y"); } - if (conv_params_.linear_hw) { + if (src_def.HasAxis(Axis::DEPTH) && !conv_params_.z_kernel_is_1) { + args_.AddInt("stride_z"); + args_.AddInt("padding_z"); + args_.AddInt("kernel_size_z"); + args_.AddInt("dilation_z"); + } + if (conv_params_.linear_spatial) { args_.AddInt("task_size_x"); } - - const auto src_tensor_type = op_def.src_tensors[0].storage_type; - const bool buffer_type = src_tensor_type == TensorStorageType::BUFFER || - src_tensor_type == TensorStorageType::IMAGE_BUFFER; - const bool manual_clamp = buffer_type && !is1x1; + if (src_def.HasAxis(Axis::DEPTH)) { + args_.AddInt("task_size_y"); + } const bool need_local_mem = conv_params.weights_upload_type == @@ -317,10 +427,10 @@ std::string ConvPowerVR::GenerateConv(const DeviceInfo& device_info, ConvPowerVR::WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP; const int local_mem_size = - conv_params.block_size.z * 4 * conv_params.src_depth_loop_size; + conv_params.block_size.w * 4 * conv_params.src_depth_loop_size; const bool use_simd_broadcast = conv_params.IsPrivateMemBroadcast(); - const int simd_size = conv_params.GetSimdSize(); + const int simd_size = conv_params.simd_size; const bool late_oob_check = need_local_mem || use_simd_broadcast; @@ -340,9 +450,11 @@ std::string ConvPowerVR::GenerateConv(const DeviceInfo& device_info, if (use_simd_broadcast) { if (device_info.cl_version == OpenCLVersion::CL_2_0) { c += "#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n"; + } else if (device_info.SupportsExtension("cl_intel_subgroups")) { + c += "#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n"; } } - const int3 block_size = conv_params.block_size; + const int4 block_size = conv_params.block_size; if (conv_params.fixed_work_group_size) { c += "__attribute__((reqd_work_group_size(" + std::to_string(work_group_size_.x) + ", " + @@ -353,28 +465,41 @@ std::string ConvPowerVR::GenerateConv(const DeviceInfo& device_info, c += "__attribute__((intel_reqd_sub_group_size(" + std::to_string(simd_size) + ")))\n"; } + std::string dst_oob_check; + if 
(src_def.HasAxis(Axis::DEPTH)) { + if (conv_params.linear_spatial) { + dst_oob_check = + "DST_Z >= args.dst_tensor.Depth() || DST_S >= " + "args.dst_tensor.Slices()"; + } else { + dst_oob_check = + "DST_X >= args.dst_tensor.Width() || DST_Z >= " + "args.dst_tensor.Depth() || DST_S >= args.dst_tensor.Slices()"; + } + } else { + if (conv_params.linear_spatial) { + dst_oob_check = + "DST_Y >= args.dst_tensor.Height() || DST_S >= " + "args.dst_tensor.Slices()"; + } else { + dst_oob_check = + "DST_X >= args.dst_tensor.Width() || DST_Y >= " + "args.dst_tensor.Height() || DST_S >= args.dst_tensor.Slices()"; + } + } c += "__kernel void main_function(\n"; c += "$0) {\n"; - c += GenerateBlockCoords(conv_params.block_size, - conv_params.work_group_launch_order, - conv_params.linear_hw); - std::vector dst_x(conv_params.block_size.x); - for (int x = 0; x < conv_params.block_size.x; ++x) { - dst_x[x] = "(X + " + std::to_string(x) + ")"; - } - std::vector dst_y(conv_params.block_size.y); - for (int y = 0; y < conv_params.block_size.y; ++y) { - dst_y[y] = "(Y + " + std::to_string(y) + ")"; - } + c += GenerateBlockCoords(conv_params.block_size, work_group_launch_order_, + conv_params.linear_spatial, + src_def.HasAxis(Axis::DEPTH)); if (!late_oob_check) { - c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() " - "|| Z >= args.dst_tensor.Slices()) {\n"; + c += " if (" + dst_oob_check + ") {\n"; c += " return;\n"; c += " }\n"; } if (conv_params.weights_upload_type == ConvPowerVR::WeightsUploadType::LOCAL_MEM_BY_THREADS) { - if (conv_params.linear_hw) { + if (conv_params.linear_spatial) { c += " int lid = get_local_id(0);\n"; } else { c += " int lid = get_local_id(1) * " + @@ -384,135 +509,263 @@ std::string ConvPowerVR::GenerateConv(const DeviceInfo& device_info, if (use_simd_broadcast) { c += " int simd_id = get_sub_group_local_id();\n"; } - for (int z = 0; z < block_size.z; ++z) { - for (int y = 0; y < block_size.y; ++y) { - for (int x = 0; x < block_size.x; ++x) { - c += " ACCUM_FLT4 r" + std::to_string(z) + std::to_string(y) + - std::to_string(x) + " = (ACCUM_FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n"; + for (int s = 0; s < block_size.w; ++s) { + const std::string sind = std::to_string(s); + for (int z = 0; z < block_size.z; ++z) { + const std::string zind = std::to_string(z); + for (int y = 0; y < block_size.y; ++y) { + const std::string yind = std::to_string(y); + for (int x = 0; x < block_size.x; ++x) { + const std::string xind = std::to_string(x); + c += " ACCUM_FLT4 r" + generate_id_full(xind, yind, zind, sind) + + " = (ACCUM_FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n"; + } } } } - if (!is1x1) { + if (!conv_params_.x_kernel_is_1) { for (int x = 0; x < block_size.x; ++x) { + const std::string xind = std::to_string(x); + const std::string xc = "(DST_X + " + xind + ")"; if (stride_correction) { - c += " int xc" + std::to_string(x) + " = " + - GetXStrideCorrected(dst_x[x], "args.src_tensor.Batch()", - "args.stride_x", "args.padding_x") + + c += " int xc" + xind + " = " + + GetXStrideCorrected(xc, "args.src_tensor.Batch()", "args.stride_x", + "args.padding_x") + ";\n"; } else { - c += " int xc" + std::to_string(x) + " = " + dst_x[x] + + c += " int xc" + xind + " = " + xc + " * args.stride_x + args.padding_x;\n"; } } + } else { + for (int x = 0; x < block_size.x; ++x) { + const std::string xind = std::to_string(x); + c += " int xc" + xind + " = DST_X + " + xind + ";\n"; + if (!src_def.CanReadOutOfBorder(Axis::WIDTH)) { + c += " xc" + xind + " = clamp(xc" + xind + + ", 0, args.src_tensor.Width() - 1);\n"; 
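// Border-handling summary for the coordinate setup in this block: when the
// kernel size along an axis is 1, the source coordinate is simply the
// destination coordinate and is clamped only if the storage type cannot read
// outside the tensor (CanReadOutOfBorder() is false). Axes with a real kernel
// and a storage that cannot zero-clamp get an in_x/in_y/in_z predicate (see
// generate_check above), which either zeroes the read (a conditional read or
// a multiply by the predicate) or, for linear storages that return zero for a
// -1 address, is folded into the address via select(-1, addr, check).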
+ } + } + } + if (!conv_params_.y_kernel_is_1) { for (int y = 0; y < block_size.y; ++y) { - c += " int yc" + std::to_string(y) + " = " + dst_y[y] + + const std::string yind = std::to_string(y); + const std::string yc = "(DST_Y + " + yind + ")"; + c += " int yc" + yind + " = " + yc + " * args.stride_y + args.padding_y;\n"; } + } else { + for (int y = 0; y < block_size.y; ++y) { + const std::string yind = std::to_string(y); + c += " int yc" + yind + " = DST_Y + " + yind + ";\n"; + if (!src_def.CanReadOutOfBorder(Axis::HEIGHT)) { + c += " yc" + yind + " = clamp(yc" + yind + + ", 0, args.src_tensor.Height() - 1);\n"; + } + } + } + if (src_def.HasAxis(Axis::DEPTH)) { + if (!conv_params_.z_kernel_is_1) { + for (int z = 0; z < block_size.z; ++z) { + const std::string zind = std::to_string(z); + const std::string zc = "(DST_Z + " + zind + ")"; + c += " int zc" + zind + " = " + zc + + " * args.stride_z + args.padding_z;\n"; + } + } else { + for (int z = 0; z < block_size.z; ++z) { + const std::string zind = std::to_string(z); + c += " int zc" + zind + " = DST_Z + " + zind + ";\n"; + if (!src_def.CanReadOutOfBorder(Axis::DEPTH)) { + c += " zc" + zind + " = clamp(zc" + zind + + ", 0, args.src_tensor.Depth() - 1);\n"; + } + } + } + } + bool trivial_kernel_size = + conv_params_.x_kernel_is_1 && conv_params_.y_kernel_is_1; + if (src_def.HasAxis(Axis::DEPTH)) { + trivial_kernel_size = trivial_kernel_size && conv_params_.z_kernel_is_1; } if (need_local_mem) { c += " __local " + weights_data_type + " weights_cache[" + std::to_string(local_mem_size) + "];\n"; - } else { + } else if (conv_params.AreWeightsBuffer()) { c += " " + weights_global_ptr + " weights_cache;\n"; + } else if (!trivial_kernel_size) { + c += " int filter_offset = 0;\n"; } - if (is1x1) { + if (conv_params.AreWeightsBuffer()) { if (conv_params.different_weights_for_height) { c += " " + weights_global_ptr + - " filters_loc = args.weights.GetPtr() + (Z * " - "args.src_tensor.Height() + Y * " + - std::to_string(block_size.z) + ") * 4 * args.src_tensor.Slices();\n"; + " filters_loc = args.weights.GetPtr() + (DST_S * " + "args.src_tensor.Height() + DST_Y * " + + std::to_string(block_size.w) + ") * 4 * args.src_tensor.Slices();\n"; } else { + std::string kernel_spatial_offset = ""; + if (!conv_params_.x_kernel_is_1) { + kernel_spatial_offset += " * args.kernel_size_x"; + } + if (!conv_params_.y_kernel_is_1) { + kernel_spatial_offset += " * args.kernel_size_y"; + } + if (src_def.HasAxis(Axis::DEPTH) && !conv_params_.z_kernel_is_1) { + kernel_spatial_offset += " * args.kernel_size_z"; + } c += " " + weights_global_ptr + - " filters_loc = args.weights.GetPtr() + Z * 4 * " - "args.src_tensor.Slices();\n"; + " filters_loc = args.weights.GetPtr() + DST_S * 4 * " + "args.src_tensor.Slices()" + + kernel_spatial_offset + ";\n"; } - } else { - c += " " + weights_global_ptr + - " filters_loc = args.weights.GetPtr() + Z * 4 * " - "args.src_tensor.Slices() *args.kernel_size_x * args.kernel_size_y;\n"; } - if (buffer_type) { - c += " const int src_layer_offset = args.src_tensor.SliceStride();\n"; + if (src_def.HasAxis(Axis::DEPTH) && !conv_params_.z_kernel_is_1) { + c += " for (int kz = 0; kz < args.kernel_size_z; ++kz) {\n"; + for (int z = 0; z < block_size.z; ++z) { + const std::string zck = "zck" + std::to_string(z); + c += " int zck" + std::to_string(z) + " = kz * args.dilation_z + zc" + + std::to_string(z) + ";\n"; + if (!src_def.SupportsZeroClamp(Axis::DEPTH)) { + c += " bool in_z" + std::to_string(z) + " = " + zck + " >= 0 && " + + zck + " < 
args.src_tensor.Depth();\n"; + if (!src_def.CanReadOutOfBorder(Axis::DEPTH)) { + c += " " + zck + " = clamp(" + zck + + ", 0, args.src_tensor.Depth() - 1);\n"; + } + } + } } - if (!is1x1) { + if (!conv_params_.y_kernel_is_1) { c += " for (int ky = 0; ky < args.kernel_size_y; ++ky) {\n"; for (int y = 0; y < block_size.y; ++y) { const std::string yck = "yck" + std::to_string(y); c += " int " + yck + " = ky * args.dilation_y + yc" + std::to_string(y) + ";\n"; - if (manual_clamp) { - c += " bool my" + std::to_string(y) + " = " + yck + " >= 0 && " + yck + - " < args.src_tensor.Height();\n"; - c += " " + yck + " = clamp(" + yck + - ", 0, args.src_tensor.Height() - 1);\n"; + if (!src_def.SupportsZeroClamp(Axis::HEIGHT)) { + c += " bool in_y" + std::to_string(y) + " = " + yck + " >= 0 && " + + yck + " < args.src_tensor.Height();\n"; + if (!src_def.CanReadOutOfBorder(Axis::HEIGHT)) { + c += " " + yck + " = clamp(" + yck + + ", 0, args.src_tensor.Height() - 1);\n"; + } } } + } + if (!conv_params_.x_kernel_is_1) { c += " for (int kx = 0; kx < args.kernel_size_x; ++kx) {\n"; for (int x = 0; x < block_size.x; ++x) { const std::string xck = "xck" + std::to_string(x); c += " int xck" + std::to_string(x) + " = kx * args.dilation_x + xc" + std::to_string(x) + ";\n"; - if (manual_clamp) { - c += " bool mx" + std::to_string(x) + " = " + xck + " >= 0 && " + xck + - " < args.src_tensor.Width();\n"; - c += " " + xck + " = clamp(" + xck + - ", 0, args.src_tensor.Width() - 1);\n"; + if (!src_def.SupportsZeroClamp(Axis::WIDTH)) { + c += " bool in_x" + std::to_string(x) + " = " + xck + " >= 0 && " + + xck + " < args.src_tensor.Width();\n"; + if (!src_def.CanReadOutOfBorder(Axis::WIDTH)) { + c += " " + xck + " = clamp(" + xck + + ", 0, args.src_tensor.Width() - 1);\n"; + } } } } - if (buffer_type) { + const bool need_multiple_slice_strides = + src_def.ReturnsZeroForNegOneRead() && !trivial_kernel_size; + for (int z = 0; z < block_size.z; ++z) { + const std::string zind = std::to_string(z); for (int y = 0; y < block_size.y; ++y) { - const std::string yck = "yck" + std::to_string(y); + const std::string yind = std::to_string(y); for (int x = 0; x < block_size.x; ++x) { - const std::string xck = "xck" + std::to_string(x); - std::string xc = - is1x1 ? "min(" + dst_x[x] + ", args.src_tensor.Width() - 1)" : xck; - std::string yc = - is1x1 ? "min(" + dst_y[y] + ", args.src_tensor.Height() - 1)" : yck; - std::string id = std::to_string(y) + std::to_string(x); - c += " int src_a_" + id + " = " + yc + - " * args.src_tensor.Width() + " + xc + ";\n"; + const std::string xind = std::to_string(x); + std::string xc = conv_params.x_kernel_is_1 ? "xc" + xind : "xck" + xind; + std::string yc = conv_params.y_kernel_is_1 ? "yc" + yind : "yck" + yind; + const std::string id = generate_id(xind, yind, zind); + std::string coords = "" + xc + ", " + yc; + if (src_def.HasAxis(Axis::DEPTH)) { + std::string zc = + conv_params.z_kernel_is_1 ? 
"zc" + zind : "zck" + zind; + coords += ", " + zc; + } + if (src_def.IsLinear()) { + c += " args.src_tensor.GetAddress(addr" + id + ", " + coords + + ", 0);\n"; + if (need_multiple_slice_strides) { + const std::string check = generate_check(xind, yind, zind); + c += " addr" + id + " = select(-1, addr" + id + ", (" + check + + "));\n"; + c += " int ds" + id + + " = select(0, args.src_tensor.SliceStride(), (" + check + + "));\n"; + } + } } } } + if (src_def.IsLinear() && !need_multiple_slice_strides) { + c += " int ds = args.src_tensor.SliceStride();\n"; + } auto declare_src = [&]() { - for (int y = 0; y < block_size.y; ++y) { - for (int x = 0; x < block_size.x; ++x) { - const std::string id = std::to_string(y) + std::to_string(x); - c += " " + weights_data_type + " src" + id + ";\n"; + for (int z = 0; z < block_size.z; ++z) { + const std::string zind = std::to_string(z); + for (int y = 0; y < block_size.y; ++y) { + const std::string yind = std::to_string(y); + for (int x = 0; x < block_size.x; ++x) { + const std::string xind = std::to_string(x); + const std::string id = generate_id(xind, yind, zind); + c += " " + weights_data_type + " src" + id + ";\n"; + } } } }; const bool conditional_read = device_info.IsMali(); auto read_src = [&]() { const std::string cl_type = ToCLDataType(conv_params.weights_data_type); - for (int y = 0; y < block_size.y; ++y) { - for (int x = 0; x < block_size.x; ++x) { - if (buffer_type) { - std::string id = std::to_string(y) + std::to_string(x); - if (is1x1) { - c += " src" + id + " = args.src_tensor.Read<" + cl_type + - ">(src_a_" + id + ");\n"; + for (int z = 0; z < block_size.z; ++z) { + const std::string zind = std::to_string(z); + for (int y = 0; y < block_size.y; ++y) { + const std::string yind = std::to_string(y); + for (int x = 0; x < block_size.x; ++x) { + const std::string xind = std::to_string(x); + std::string id = generate_id(xind, yind, zind); + const std::string check = generate_check(xind, yind, zind); + std::string address; + if (src_def.IsLinear()) { + address = "addr" + id; } else { - std::string condition = - "mx" + std::to_string(x) + " && my" + std::to_string(y); - if (conditional_read) { - c += " src" + id + " = " + condition + - " ? args.src_tensor.Read<" + cl_type + ">(src_a_" + id + - ") : (FLT4)(0.0f);\n"; + std::string xc = + conv_params.x_kernel_is_1 ? "xc" + xind : "xck" + xind; + std::string yc = + conv_params.y_kernel_is_1 ? "yc" + yind : "yck" + yind; + address = "" + xc + ", " + yc; + if (src_def.HasAxis(Axis::DEPTH)) { + std::string zc = + conv_params.z_kernel_is_1 ? "zc" + zind : "zck" + zind; + address += ", " + zc; + } + address += ", s"; + } + if (src_def.ReturnsZeroForNegOneRead()) { + c += " src" + id + " = args.src_tensor.Read<" + cl_type + ">(" + + address + ");\n"; + const std::string ds = trivial_kernel_size ? "ds" : "ds" + id; + c += " " + address + " += " + ds + ";\n"; + } else { + if (!check.empty()) { + if (conditional_read) { + c += " src" + id + " = " + check + + " ? 
args.src_tensor.Read<" + cl_type + ">(" + address + + ") : (FLT4)(0.0f);\n"; + } else { + c += " src" + id + " = args.src_tensor.Read<" + cl_type + + ">(" + address + ") * (FLT)(" + check + ");\n"; + } } else { c += " src" + id + " = args.src_tensor.Read<" + cl_type + - ">(src_a_" + id + ") * (FLT)(" + condition + ");\n"; + ">(" + address + ");\n"; + } + if (src_def.IsLinear()) { + c += " " + address + " += ds;\n"; } } - c += " src_a_" + id + " += src_layer_offset;\n"; - } else { - std::string id = std::to_string(y) + std::to_string(x); - const std::string xc = is1x1 ? dst_x[x] : "xck" + std::to_string(x); - const std::string yc = is1x1 ? dst_y[y] : "yck" + std::to_string(y); - c += " src" + id + " = args.src_tensor.Read<" + cl_type + ">(" + - xc + ", " + yc + ", s);\n"; } } } @@ -522,59 +775,80 @@ std::string ConvPowerVR::GenerateConv(const DeviceInfo& device_info, conv_params.weights_data_type == DataType::FLOAT16); auto conv_core = [&](int shared_offset) { const std::string channels[] = {"x", "y", "z", "w"}; - for (int z = 0; z < block_size.z; ++z) { + for (int s = 0; s < block_size.w; ++s) { + const std::string sind = std::to_string(s); if (weights_type_as_accum_type) { for (int ch = 0; ch < 4; ++ch) { - for (int y = 0; y < block_size.y; ++y) { - for (int x = 0; x < block_size.x; ++x) { - std::string id = std::to_string(y) + std::to_string(x); - if (use_simd_broadcast) { - int simd_id = (z * 4 + ch + shared_offset) / simd_size; - int thread_id = (z * 4 + ch + shared_offset) % simd_size; - std::string w_val_x = "sub_group_broadcast(simd_w" + - std::to_string(simd_id) + ".x, " + - std::to_string(thread_id) + "u)"; - std::string w_val_y = "sub_group_broadcast(simd_w" + - std::to_string(simd_id) + ".y, " + - std::to_string(thread_id) + "u)"; - std::string w_val_z = "sub_group_broadcast(simd_w" + - std::to_string(simd_id) + ".z, " + - std::to_string(thread_id) + "u)"; - std::string w_val_w = "sub_group_broadcast(simd_w" + - std::to_string(simd_id) + ".w, " + - std::to_string(thread_id) + "u)"; - c += " r" + std::to_string(z) + id + ".x += " + w_val_x + - " * src" + id + "." + channels[ch] + ";\n"; - c += " r" + std::to_string(z) + id + ".y += " + w_val_y + - " * src" + id + "." + channels[ch] + ";\n"; - c += " r" + std::to_string(z) + id + ".z += " + w_val_z + - " * src" + id + "." + channels[ch] + ";\n"; - c += " r" + std::to_string(z) + id + ".w += " + w_val_w + - " * src" + id + "." + channels[ch] + ";\n"; - } else { - std::string w_val = "weights_cache[" + - std::to_string(z * 4 + ch + shared_offset) + - "]"; - c += " r" + std::to_string(z) + id + " += " + w_val + - " * src" + id + "." 
+ channels[ch] + ";\n"; + for (int z = 0; z < block_size.z; ++z) { + const std::string zind = std::to_string(z); + for (int y = 0; y < block_size.y; ++y) { + const std::string yind = std::to_string(y); + for (int x = 0; x < block_size.x; ++x) { + const std::string xind = std::to_string(x); + std::string R = "r" + generate_id_full(xind, yind, zind, sind); + std::string S = "src" + generate_id(xind, yind, zind); + if (use_simd_broadcast) { + int simd_id = (s * 4 + ch + shared_offset) / simd_size; + int thread_id = (s * 4 + ch + shared_offset) % simd_size; + std::string w_val_x = "sub_group_broadcast(simd_w" + + std::to_string(simd_id) + ".x, " + + std::to_string(thread_id) + "u)"; + std::string w_val_y = "sub_group_broadcast(simd_w" + + std::to_string(simd_id) + ".y, " + + std::to_string(thread_id) + "u)"; + std::string w_val_z = "sub_group_broadcast(simd_w" + + std::to_string(simd_id) + ".z, " + + std::to_string(thread_id) + "u)"; + std::string w_val_w = "sub_group_broadcast(simd_w" + + std::to_string(simd_id) + ".w, " + + std::to_string(thread_id) + "u)"; + c += " " + R + ".x += " + w_val_x + " * " + S + "." + + channels[ch] + ";\n"; + c += " " + R + ".y += " + w_val_y + " * " + S + "." + + channels[ch] + ";\n"; + c += " " + R + ".z += " + w_val_z + " * " + S + "." + + channels[ch] + ";\n"; + c += " " + R + ".w += " + w_val_w + " * " + S + "." + + channels[ch] + ";\n"; + } else { + const std::string weight_id = + std::to_string(s * 4 + ch + shared_offset); + std::string w_val; + if (conv_params.AreWeightsBuffer()) { + w_val = "weights_cache[" + weight_id + "]"; + } else { + w_val = "f" + weight_id; + } + c += " " + R + " += " + w_val + " * " + S + "." + + channels[ch] + ";\n"; + } } } } } } else { // F32_F16 precision and weights type is float16 - for (int y = 0; y < block_size.y; ++y) { - for (int x = 0; x < block_size.x; ++x) { - std::string id = std::to_string(y) + std::to_string(x); - std::string R = "r" + std::to_string(z) + id; - std::string S = "src" + id; - const int dz = z * 4 + shared_offset; - std::string f0 = "weights_cache[" + std::to_string(dz + 0) + "]"; - std::string f1 = "weights_cache[" + std::to_string(dz + 1) + "]"; - std::string f2 = "weights_cache[" + std::to_string(dz + 2) + "]"; - std::string f3 = "weights_cache[" + std::to_string(dz + 3) + "]"; - c += " " + R + " += convert_float4(" + S + ".x * " + f0 + " + " + - S + ".y * " + f1 + " + " + S + ".z * " + f2 + " + " + S + - ".w * " + f3 + ");\n"; + for (int z = 0; z < block_size.z; ++z) { + const std::string zind = std::to_string(z); + for (int y = 0; y < block_size.y; ++y) { + const std::string yind = std::to_string(y); + for (int x = 0; x < block_size.x; ++x) { + const std::string xind = std::to_string(x); + std::string R = "r" + generate_id_full(xind, yind, zind, sind); + std::string S = "src" + generate_id(xind, yind, zind); + std::vector F(4); + for (int i = 0; i < 4; ++i) { + std::string weight_id = + std::to_string(s * 4 + i + shared_offset); + if (conv_params.AreWeightsBuffer()) { + F[i] = "weights_cache[" + weight_id + "]"; + } else { + F[i] = "f" + weight_id; + } + } + c += " " + R + " += convert_float4(" + S + ".x * " + F[0] + + " + " + S + ".y * " + F[1] + " + " + S + ".z * " + F[2] + + " + " + S + ".w * " + F[3] + ");\n"; + } } } } @@ -611,8 +885,26 @@ std::string ConvPowerVR::GenerateConv(const DeviceInfo& device_info, "];\n"; c += " }\n"; } - } else { // GLOBAL_MEM/CONSTANT_MEM + } else if (conv_params.AreWeightsBuffer()) { // GLOBAL_MEM/CONSTANT_MEM c += " weights_cache = filters_loc;\n"; + } else { 
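Before the texture-backed branch below, a minimal standalone sketch of the assumed weights layout: the rearranged weights are split across four 2D textures, and texture i supplies the FLT4 that multiplies the i-th component of the source vector in conv_core (read below as f0..f3 per output slice of the block). The names WeightsTextureShape and TexturesMemX4Shape are local to this sketch, not real TFLite GPU helpers; compare with UploadWeights further down in the patch.

// Sketch only: geometry of the four weight textures used by the
// TEXTURES_MEM_X4 path.
struct WeightsTextureShape {
  int width;   // destination slices, aligned to block_size.w
  int height;  // source slices * number of kernel taps
};

inline WeightsTextureShape TexturesMemX4Shape(int dst_slices_aligned,
                                              int src_slices, int kernel_h,
                                              int kernel_w, int kernel_d = 1) {
  // Each texture i is read at (DST_S + dst_s, filter_offset) -- or at
  // (DST_S + dst_s, s) when the kernel is spatially trivial.
  return {dst_slices_aligned, src_slices * kernel_h * kernel_w * kernel_d};
}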
// TEXTURES_MEM + for (int dst_s = 0; dst_s < block_size.w; ++dst_s) { + std::string f_y = trivial_kernel_size ? "s" : "filter_offset"; + if (conv_params.different_weights_for_height) { + f_y = "DST_Y * args.src_tensor.Slices() + s"; + } + c += absl::Substitute( + R"( FLT4 f$2 = args.weights0.Read(DST_S + $0, $1); + FLT4 f$3 = args.weights1.Read(DST_S + $0, $1); + FLT4 f$4 = args.weights2.Read(DST_S + $0, $1); + FLT4 f$5 = args.weights3.Read(DST_S + $0, $1); +)", + dst_s, f_y, dst_s * 4 + 0, dst_s * 4 + 1, dst_s * 4 + 2, + dst_s * 4 + 3); + } + if (!trivial_kernel_size) { + c += " filter_offset++;\n"; + } } read_src(); c += " s += 1;\n"; @@ -623,61 +915,96 @@ std::string ConvPowerVR::GenerateConv(const DeviceInfo& device_info, conv_core(0); for (int i = 1; i < conv_params.src_depth_loop_size; ++i) { read_src(); - conv_core(i * block_size.z * 4); + conv_core(i * block_size.w * 4); c += " s += 1;\n"; } - c += " filters_loc += " + std::to_string(local_mem_size) + ";\n"; + if (conv_params.AreWeightsBuffer()) { + c += " filters_loc += " + std::to_string(local_mem_size) + ";\n"; + } c += " } while (s < args.src_tensor.Slices());\n"; - if (!is1x1) { - c += " };\n"; + if (!conv_params.x_kernel_is_1) { c += " };\n"; } - if (conv_params.weights_upload_type == - ConvPowerVR::WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP) { - c += GenerateAsyncUpload("weights_cache", "args.biases.GetPtr()", "Z", - block_size.z); - } else if (conv_params.weights_upload_type == - ConvPowerVR::WeightsUploadType::LOCAL_MEM_BY_THREADS) { - c += " barrier(CLK_LOCAL_MEM_FENCE);\n"; - c += GenerateUploadByThreads("weights_cache", "args.biases.GetPtr()", "Z", - "lid", total_work_items, block_size.z); - c += " barrier(CLK_LOCAL_MEM_FENCE);\n"; - } else { - c += " weights_cache = args.biases.GetPtr() + Z;\n"; + if (!conv_params.y_kernel_is_1) { + c += " };\n"; + } + if (src_def.HasAxis(Axis::DEPTH) && !conv_params_.z_kernel_is_1) { + c += " };\n"; + } + if (conv_params.AreWeightsBuffer()) { + if (conv_params.weights_upload_type == + ConvPowerVR::WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP) { + c += GenerateAsyncUpload("weights_cache", "args.biases.GetPtr()", "DST_S", + block_size.w); + } else if (conv_params.weights_upload_type == + ConvPowerVR::WeightsUploadType::LOCAL_MEM_BY_THREADS) { + c += " barrier(CLK_LOCAL_MEM_FENCE);\n"; + c += GenerateUploadByThreads("weights_cache", "args.biases.GetPtr()", + "DST_S", "lid", total_work_items, + block_size.w); + c += " barrier(CLK_LOCAL_MEM_FENCE);\n"; + } else { + c += " weights_cache = args.biases.GetPtr() + DST_S;\n"; + } } if (late_oob_check) { - c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() " - "|| Z >= args.dst_tensor.Slices()) {\n"; + c += " if (" + dst_oob_check + ") {\n"; c += " return;\n"; c += " }\n"; } - for (int z = 0; z < block_size.z; ++z) { - const std::string sz = std::to_string(z); - c += " if (Z + " + sz + " >= args.dst_tensor.Slices()) return;\n"; - c += " {\n"; - c += " FLT4 bias_val = TO_FLT4(weights_cache[" + sz + "]);\n"; - for (int y = 0; y < block_size.y; ++y) { - for (int x = 0; x < block_size.x; ++x) { - const std::string xs = dst_x[x]; - const std::string ys = dst_y[y]; - const std::string zs = "Z + " + sz; - const std::string r_id = sz + std::to_string(y) + std::to_string(x); - bool need_x_check = x != 0; - bool need_y_check = y != 0; - if (need_x_check && need_y_check) { - c += " if (" + xs + " < args.dst_tensor.Width() && " + ys + - " < args.dst_tensor.Height()) {\n"; - } else if (need_x_check && !need_y_check) { - c += " if (" + 
xs + " < args.dst_tensor.Width()) {\n"; - } else if (!need_x_check && need_y_check) { - c += " if (" + ys + " < args.dst_tensor.Height()) {\n"; - } else { - c += " {\n"; + + auto generate_dst_check = [&](int x, int y, int z) { + std::string check; + const std::vector axes{Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH}; + const std::vector names{"Width()", "Height()", "Depth()"}; + std::vector coords(3); + coords[0] = "DST_X + " + std::to_string(x); + coords[1] = "DST_Y + " + std::to_string(y); + coords[2] = "DST_Z + " + std::to_string(z); + const std::vector ids{x, y, z}; + for (int i = 0; i < axes.size(); ++i) { + const auto& axis = axes[i]; + if (src_def.HasAxis(axis) && ids[i] != 0) { + if (!check.empty()) { + check += " && "; + } + check += coords[i] + " < args.dst_tensor." + names[i]; + } + } + return check; + }; + + for (int s = 0; s < block_size.w; ++s) { + const std::string sind = std::to_string(s); + c += " if (DST_S + " + sind + " >= args.dst_tensor.Slices()) return;\n"; + c += " {\n"; + if (conv_params.AreWeightsBuffer()) { + c += " FLT4 bias_val = TO_FLT4(weights_cache[" + sind + "]);\n"; + } else { + c += " FLT4 bias_val = args.biases.Read(DST_S + " + sind + ");\n"; + } + for (int z = 0; z < block_size.z; ++z) { + const std::string zind = std::to_string(z); + for (int y = 0; y < block_size.y; ++y) { + const std::string yind = std::to_string(y); + for (int x = 0; x < block_size.x; ++x) { + const std::string xind = std::to_string(x); + const std::string id = generate_id_full(xind, yind, zind, sind); + const std::string check = generate_dst_check(x, y, z); + std::string coords = "DST_X + " + xind + ", DST_Y + " + yind; + if (src_def.HasAxis(Axis::DEPTH)) { + coords += ", DST_Z + " + zind; + } + coords += ", DST_S + " + sind; + if (!check.empty()) { + c += " if (" + check + ") {\n"; + } else { + c += " {\n"; + } + c += " FLT4 res = TO_FLT4(r" + id + ") + bias_val;\n"; + c += " args.dst_tensor.Write(res, " + coords + ");\n"; + c += " }\n"; } - c += " FLT4 res = TO_FLT4(r" + r_id + ") + bias_val;\n"; - c += " args.dst_tensor.Write(res, " + xs + ", " + ys + ", " + zs + - ");\n"; - c += " }\n"; } } c += " }\n"; @@ -691,7 +1018,7 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams( int src_depth, int dst_depth, bool x_kernel_is_1, bool y_kernel_is_1, bool different_weights_for_height, const BHWC* dst_shape) { ConvParams conv_params; - conv_params.linear_hw = false; + conv_params.linear_spatial = false; conv_params.weights_data_type = DeduceDataTypeFromPrecision(definition.precision); conv_params.x_kernel_is_1 = x_kernel_is_1; @@ -700,84 +1027,84 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams( if (device_info.IsNvidia()) { if (different_weights_for_height) { work_group_size_ = int3(32, 1, 1); - conv_params.work_group_launch_order = int3(2, 0, 1); + work_group_launch_order_ = int3(2, 0, 1); conv_params.fixed_work_group_size = true; } else { - conv_params.linear_hw = true; + conv_params.linear_spatial = true; work_group_size_ = int3(32, 1, 1); - conv_params.work_group_launch_order = int3(1, 0, 2); + work_group_launch_order_ = int3(1, 0, 2); conv_params.fixed_work_group_size = true; } - conv_params.block_size = int3(2, 1, 4); + conv_params.block_size = int4(2, 1, 1, 4); conv_params.src_depth_loop_size = 1; conv_params.weights_upload_type = WeightsUploadType::LOCAL_MEM_BY_THREADS; if (dst_depth % 4 == 0 || dst_depth >= 8) { - conv_params.block_size.z = 4; + conv_params.block_size.w = 4; } else if (dst_depth % 2 == 0 || dst_depth >= 4) { - conv_params.block_size.z = 2; + 
conv_params.block_size.w = 2; } else { - conv_params.block_size.z = dst_depth; + conv_params.block_size.w = dst_depth; } if (dst_shape) { int task_size = dst_shape->w * dst_shape->b * dst_shape->h * dst_depth; float task_size_per_cu = static_cast(task_size) / device_info.compute_units_count; int block_size = conv_params.block_size.x * conv_params.block_size.y * - conv_params.block_size.z; + conv_params.block_size.w; float threads_per_cu = task_size_per_cu / block_size; float warps_per_cu = threads_per_cu / 32 /*warp_size*/; if (warps_per_cu < 8.0f) { conv_params.block_size.x = 1; } - if (warps_per_cu < 4.0f && conv_params.block_size.z >= 4) { - conv_params.block_size.z /= 2; + if (warps_per_cu < 4.0f && conv_params.block_size.w >= 4) { + conv_params.block_size.w /= 2; } - if (warps_per_cu < 2.0f && conv_params.block_size.z >= 2) { - conv_params.block_size.z /= 2; + if (warps_per_cu < 2.0f && conv_params.block_size.w >= 2) { + conv_params.block_size.w /= 2; } } if (src_depth % 2 == 0) { conv_params.src_depth_loop_size = 2; } - if (src_depth % 4 == 0 && conv_params.block_size.z <= 2) { + if (src_depth % 4 == 0 && conv_params.block_size.w <= 2) { conv_params.src_depth_loop_size = 4; } } else if (device_info.IsPowerVR()) { if (different_weights_for_height) { work_group_size_ = int3(32, 1, 1); - conv_params.work_group_launch_order = int3(2, 0, 1); + work_group_launch_order_ = int3(2, 0, 1); conv_params.fixed_work_group_size = true; } else { - conv_params.linear_hw = true; + conv_params.linear_spatial = true; work_group_size_ = int3(32, 1, 1); - conv_params.work_group_launch_order = int3(1, 0, 2); + work_group_launch_order_ = int3(1, 0, 2); conv_params.fixed_work_group_size = true; } conv_params.weights_data_type = definition.precision == CalculationsPrecision::F16 ? 
DataType::FLOAT16 : DataType::FLOAT32; - conv_params.block_size = int3(1, 1, 4); + conv_params.block_size = int4(1, 1, 1, 4); conv_params.src_depth_loop_size = 1; conv_params.weights_upload_type = WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP; if (dst_depth % 8 == 0 || dst_depth >= 32) { - conv_params.block_size.z = 8; + conv_params.block_size.w = 8; } else if (dst_depth % 4 == 0 || dst_depth >= 8) { - conv_params.block_size.z = 4; + conv_params.block_size.w = 4; } else if (dst_depth % 2 == 0 || dst_depth >= 4) { - conv_params.block_size.z = 2; + conv_params.block_size.w = 2; } else { - conv_params.block_size.z = dst_depth; + conv_params.block_size.w = dst_depth; } if (definition.precision == CalculationsPrecision::F16) { - conv_params.block_size.z = std::min(4, conv_params.block_size.z); + conv_params.block_size.w = std::min(4, conv_params.block_size.w); if (src_depth % 2 == 0) { conv_params.src_depth_loop_size = 2; } - if (src_depth % 4 == 0 && conv_params.block_size.z <= 2) { + if (src_depth % 4 == 0 && conv_params.block_size.w <= 2) { conv_params.src_depth_loop_size = 4; } - if (conv_params.block_size.z == 1) { + if (conv_params.block_size.w == 1) { if (src_depth % 2 == 0) { conv_params.src_depth_loop_size = 2; } @@ -793,28 +1120,28 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams( } else if (device_info.IsAMD()) { if (different_weights_for_height) { work_group_size_ = int3(32, 1, 1); - conv_params.work_group_launch_order = int3(2, 0, 1); + work_group_launch_order_ = int3(2, 0, 1); conv_params.fixed_work_group_size = true; } else { work_group_size_ = int3(8, 4, 1); - conv_params.work_group_launch_order = int3(2, 0, 1); + work_group_launch_order_ = int3(2, 0, 1); conv_params.fixed_work_group_size = true; } - conv_params.block_size = int3(2, 1, 1); + conv_params.block_size = int4(2, 1, 1, 1); if (x_kernel_is_1 && y_kernel_is_1) { conv_params.block_size.y = 2; } conv_params.src_depth_loop_size = 1; conv_params.weights_upload_type = WeightsUploadType::CONSTANT_MEM; if (dst_depth % 8 == 0 || dst_depth >= 32) { - conv_params.block_size.z = 8; + conv_params.block_size.w = 8; } else if (dst_depth % 4 == 0 || dst_depth >= 8) { - conv_params.block_size.z = 4; + conv_params.block_size.w = 4; } else if (dst_depth % 2 == 0 || dst_depth >= 4) { - conv_params.block_size.z = 2; + conv_params.block_size.w = 2; } else { - conv_params.block_size.z = 1; + conv_params.block_size.w = 1; } if (src_depth % 2 == 0 && src_depth >= 16) { conv_params.src_depth_loop_size = 2; @@ -831,20 +1158,20 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams( } if (block_size == 8) { if (dst_depth == 1 || dst_depth == 3) { - conv_params.block_size = int3(2, 2, 1); + conv_params.block_size = int4(2, 2, 1, 1); } else { - conv_params.block_size = int3(2, 2, 2); + conv_params.block_size = int4(2, 2, 1, 2); } } else if (block_size == 4) { if (dst_depth == 1 || dst_depth == 3) { - conv_params.block_size = int3(2, 2, 1); + conv_params.block_size = int4(2, 2, 1, 1); } else { - conv_params.block_size = int3(2, 1, 2); + conv_params.block_size = int4(2, 1, 1, 2); } } else if (block_size == 2) { - conv_params.block_size = int3(2, 1, 1); + conv_params.block_size = int4(2, 1, 1, 1); } else { - conv_params.block_size = int3(1, 1, 1); + conv_params.block_size = int4(1, 1, 1, 1); } conv_params.src_depth_loop_size = 1; MaliInfo mali_info = device_info.mali_info; @@ -856,70 +1183,88 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams( conv_params.src_depth_loop_size = 4; } work_group_size_ = int3(4, 4, 1); - 
conv_params.work_group_launch_order = int3(0, 1, 2); + work_group_launch_order_ = int3(0, 1, 2); conv_params.fixed_work_group_size = false; conv_params.weights_upload_type = WeightsUploadType::GLOBAL_MEM; } else if (device_info.IsAdreno()) { - conv_params.block_size = int3(2, 2, 1); + conv_params.block_size = int4(2, 2, 1, 2); + if (device_info.IsAdreno3xx()) { + if (definition.precision == CalculationsPrecision::F16) { + conv_params.block_size = int4(2, 2, 1, 2); + } else if (definition.precision == CalculationsPrecision::F32_F16) { + conv_params.block_size = int4(2, 1, 1, 2); + } else { // F32 + conv_params.block_size = int4(2, 2, 1, 1); + } + } work_group_size_ = int3(8, 2, 1); - conv_params.work_group_launch_order = int3(0, 1, 2); + work_group_launch_order_ = int3(0, 1, 2); conv_params.fixed_work_group_size = false; conv_params.src_depth_loop_size = 1; - conv_params.weights_upload_type = WeightsUploadType::GLOBAL_MEM; + if (definition.src_tensors.size() == 2) { + // dynamic weights supported only with buffers. + conv_params.weights_upload_type = WeightsUploadType::GLOBAL_MEM; + } else { + conv_params.weights_upload_type = WeightsUploadType::TEXTURES_MEM_X4; + } } else if (device_info.IsIntel()) { if (different_weights_for_height) { work_group_size_ = int3(16, 1, 1); - conv_params.work_group_launch_order = int3(0, 1, 2); + work_group_launch_order_ = int3(0, 1, 2); conv_params.fixed_work_group_size = true; } else { - conv_params.linear_hw = true; + conv_params.linear_spatial = true; work_group_size_ = int3(16, 1, 1); - conv_params.work_group_launch_order = int3(0, 1, 2); + work_group_launch_order_ = int3(0, 1, 2); conv_params.fixed_work_group_size = true; } - conv_params.block_size = int3(1, 1, 4); + conv_params.block_size = int4(1, 1, 1, 4); conv_params.src_depth_loop_size = 1; + int sub_group_size = 16; + const bool supports_subgroups = + device_info.SupportsExtension("cl_khr_subgroups") || + device_info.SupportsExtension("cl_intel_subgroups"); if (definition.precision != CalculationsPrecision::F32_F16 && - device_info.SupportsExtension("cl_khr_subgroups") && + supports_subgroups && device_info.SupportsExtension("cl_intel_required_subgroup_size") && - device_info.IsCL20OrHigher() && - device_info.SupportsSubGroupWithSize(16)) { + device_info.SupportsSubGroupWithSize(sub_group_size)) { conv_params.weights_upload_type = - WeightsUploadType::PRIVATE_MEM_SIMD16_BROADCAST; + WeightsUploadType::PRIVATE_MEM_SIMD_BROADCAST; + conv_params.simd_size = sub_group_size; } else { conv_params.weights_upload_type = WeightsUploadType::LOCAL_MEM_BY_THREADS; } if (dst_depth % 4 == 0 || dst_depth >= 8) { - conv_params.block_size.z = 4; + conv_params.block_size.w = 4; } else if (dst_depth % 2 == 0 || dst_depth >= 4) { - conv_params.block_size.z = 2; + conv_params.block_size.w = 2; } else { - conv_params.block_size.z = dst_depth; + conv_params.block_size.w = dst_depth; } if (src_depth % 2 == 0) { conv_params.src_depth_loop_size = 2; } - if (src_depth % 4 == 0 && conv_params.block_size.z <= 2) { + if (src_depth % 4 == 0 && conv_params.block_size.w <= 2) { conv_params.src_depth_loop_size = 4; } } else { - conv_params.block_size = int3(1, 1, 4); + conv_params.block_size = int4(1, 1, 1, 4); work_group_size_ = int3(8, 2, 1); - conv_params.work_group_launch_order = int3(0, 1, 2); + work_group_launch_order_ = int3(0, 1, 2); conv_params.fixed_work_group_size = false; conv_params.src_depth_loop_size = 1; conv_params.weights_upload_type = WeightsUploadType::GLOBAL_MEM; if (dst_depth % 4 == 0 || dst_depth >= 8) { - 
conv_params.block_size.z = 4; + conv_params.block_size.w = 4; } else if (dst_depth % 2 == 0 || dst_depth >= 4) { - conv_params.block_size.z = 2; + conv_params.block_size.w = 2; } else { - conv_params.block_size.z = dst_depth; + conv_params.block_size.w = dst_depth; } if (src_depth % 2 == 0) { conv_params.src_depth_loop_size = 2; } - if (src_depth % 4 == 0 && conv_params.block_size.z <= 2) { + if (src_depth % 4 == 0 && conv_params.block_size.w <= 2) { conv_params.src_depth_loop_size = 4; } } @@ -944,6 +1289,41 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams( x_kernel_is_1, y_kernel_is_1, false, dst_shape); } +ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams( + const DeviceInfo& device_info, const OperationDef& definition, + const Convolution3DAttributes& attr, const BHWDC* dst_shape) { + const int dst_depth = DivideRoundUp(attr.weights.shape.o, 4); + const int src_depth = DivideRoundUp(attr.weights.shape.i, 4); + const bool x_kernel_is_1 = attr.weights.shape.w == 1 && attr.strides.w == 1 && + attr.dilations.w == 1 && + attr.padding.prepended.w == 0 && + attr.padding.appended.w == 0; + const bool y_kernel_is_1 = attr.weights.shape.h == 1 && attr.strides.h == 1 && + attr.dilations.h == 1 && + attr.padding.prepended.h == 0 && + attr.padding.appended.h == 0; + const bool z_kernel_is_1 = attr.weights.shape.d == 1 && attr.strides.d == 1 && + attr.dilations.d == 1 && + attr.padding.prepended.d == 0 && + attr.padding.appended.d == 0; + + ConvPowerVR::ConvParams result; + BHWC shape; + if (dst_shape) { + shape.b = dst_shape->b; + shape.h = dst_shape->h * dst_shape->d; + shape.w = dst_shape->w; + shape.c = dst_shape->c; + result = GuessBestParams(device_info, definition, src_depth, dst_depth, + x_kernel_is_1, y_kernel_is_1, false, &shape); + } else { + result = GuessBestParams(device_info, definition, src_depth, dst_depth, + x_kernel_is_1, y_kernel_is_1, false, nullptr); + } + result.z_kernel_is_1 = z_kernel_is_1; + return result; +} + ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams( const DeviceInfo& device_info, const OperationDef& definition, const Convolution2DAttributes& attr, const BHWC& weights_shape, @@ -1031,6 +1411,17 @@ ConvPowerVR CreateConvPowerVRWino4x4To6x6(const DeviceInfo& device_info, return result; } +ConvPowerVR CreateConvPowerVR3D(const DeviceInfo& device_info, + const OperationDef& definition, + const Convolution3DAttributes& attr, + const BHWDC* dst_shape) { + ConvPowerVR result(definition, attr, device_info, dst_shape); + result.GenerateCode(device_info); + result.UploadWeights(attr.weights); + result.UploadBias(attr.bias); + return result; +} + } // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h b/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h index bceb25044f7..30e412cd923 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_CONV_POWERVR_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_CONV_POWERVR_H_ +#include #include #include "tensorflow/lite/delegates/gpu/cl/buffer.h" @@ -25,6 +26,7 @@ limitations under the License. 
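The Convolution3DAttributes overload of GuessBestParams above reuses the 2D per-vendor heuristics by folding the depth of the destination shape into its height and then attaching z_kernel_is_1 afterwards. A minimal standalone sketch of that flattening follows; BHWC/BHWDC here are simplified stand-ins and FlattenDepthIntoHeight is a name local to this sketch, not the real tflite::gpu types.

// Sketch only: how the 3D destination shape is presented to the 2D heuristics.
struct BHWC  { int b, h, w, c; };
struct BHWDC { int b, h, w, d, c; };

inline BHWC FlattenDepthIntoHeight(const BHWDC& dst) {
  BHWC shape;
  shape.b = dst.b;
  shape.h = dst.h * dst.d;  // depth is folded into height
  shape.w = dst.w;
  shape.c = dst.c;
  return shape;
}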
#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h" #include "tensorflow/lite/delegates/gpu/cl/tensor.h" +#include "tensorflow/lite/delegates/gpu/cl/texture2d.h" #include "tensorflow/lite/delegates/gpu/cl/util.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -45,13 +47,13 @@ class ConvPowerVR : public GPUOperation { TuningType tuning_type, const DeviceInfo& device_info, const KernelInfo& kernel_info, std::vector* work_groups) const override; - absl::Status BindArguments() override; + absl::Status BindArguments(ArgumentsBinder* args) override; int3 GetGridSize() const override; ConvWeightsDescription GetConvWeightsDescription() const { ConvWeightsDescription desc; desc.layout = ConvWeightsLayout::kOHWIOGroupI4O4; - desc.output_group_size = conv_params_.block_size.z; + desc.output_group_size = conv_params_.block_size.w; return desc; } @@ -67,11 +69,8 @@ class ConvPowerVR : public GPUOperation { LOCAL_MEM_BY_THREADS, GLOBAL_MEM, CONSTANT_MEM, - PRIVATE_MEM_SIMD8_BROADCAST, - PRIVATE_MEM_SIMD16_BROADCAST, - PRIVATE_MEM_SIMD32_BROADCAST, - PRIVATE_MEM_SIMD64_BROADCAST, - PRIVATE_MEM_SIMD128_BROADCAST, + PRIVATE_MEM_SIMD_BROADCAST, + TEXTURES_MEM_X4, // 4 textures for weights }; struct ConvParams { @@ -83,47 +82,26 @@ class ConvPowerVR : public GPUOperation { // weights, so for PowerVR in this kernel we have F32 weights for // F32_F16 precision mode DataType weights_data_type; // used for weights and biases - int3 block_size; - int3 work_group_launch_order; + int4 block_size; // WHDS bool fixed_work_group_size; - bool linear_hw; + bool linear_spatial; // spatial dimensions are Width/Height/Depth bool different_weights_for_height; int src_depth_loop_size; WeightsUploadType weights_upload_type; bool x_kernel_is_1; bool y_kernel_is_1; + bool z_kernel_is_1; + + // used only with PRIVATE_MEM_SIMD_BROADCAST + int simd_size = 1; + + bool AreWeightsBuffer() const { + return weights_upload_type != WeightsUploadType::TEXTURES_MEM_X4; + } bool IsPrivateMemBroadcast() const { return weights_upload_type == - WeightsUploadType::PRIVATE_MEM_SIMD8_BROADCAST || - weights_upload_type == - WeightsUploadType::PRIVATE_MEM_SIMD16_BROADCAST || - weights_upload_type == - WeightsUploadType::PRIVATE_MEM_SIMD32_BROADCAST || - weights_upload_type == - WeightsUploadType::PRIVATE_MEM_SIMD64_BROADCAST || - weights_upload_type == - WeightsUploadType::PRIVATE_MEM_SIMD128_BROADCAST; - } - - int GetSimdSize() const { - if (weights_upload_type == - WeightsUploadType::PRIVATE_MEM_SIMD8_BROADCAST) { - return 8; - } else if (weights_upload_type == - WeightsUploadType::PRIVATE_MEM_SIMD16_BROADCAST) { - return 16; - } else if (weights_upload_type == - WeightsUploadType::PRIVATE_MEM_SIMD32_BROADCAST) { - return 32; - } else if (weights_upload_type == - WeightsUploadType::PRIVATE_MEM_SIMD64_BROADCAST) { - return 64; - } else if (weights_upload_type == - WeightsUploadType::PRIVATE_MEM_SIMD128_BROADCAST) { - return 128; - } - return 1; + WeightsUploadType::PRIVATE_MEM_SIMD_BROADCAST; } }; @@ -137,6 +115,9 @@ class ConvPowerVR : public GPUOperation { const FullyConnectedAttributes& attr, const DeviceInfo& device_info, const BHWC* dst_shape = nullptr); explicit ConvPowerVR(const OperationDef& definition); + ConvPowerVR(const OperationDef& definition, + const Convolution3DAttributes& attr, + const DeviceInfo& device_info, const BHWDC* dst_shape = nullptr); void GenerateCode(const DeviceInfo& 
device_info); @@ -150,6 +131,9 @@ class ConvPowerVR : public GPUOperation { template void UploadWeights(const tflite::gpu::Tensor& weights); + template + void UploadWeights(const tflite::gpu::Tensor& weights); + template void UploadBias(const tflite::gpu::Tensor& bias); @@ -172,6 +156,11 @@ class ConvPowerVR : public GPUOperation { const DeviceInfo& device_info, const OperationDef& definition, const Convolution2DAttributes& attr, const BHWC* dst_shape); + friend ConvPowerVR CreateConvPowerVR3D(const DeviceInfo& device_info, + const OperationDef& definition, + const Convolution3DAttributes& attr, + const BHWDC* dst_shape); + ConvParams GuessBestParams(const DeviceInfo& device_info, const OperationDef& definition, const Convolution2DAttributes& attr, @@ -189,6 +178,10 @@ class ConvPowerVR : public GPUOperation { const OperationDef& definition, const Convolution2DAttributes& attr, const BHWC* dst_shape = nullptr); + ConvParams GuessBestParams(const DeviceInfo& device_info, + const OperationDef& definition, + const Convolution3DAttributes& attr, + const BHWDC* dst_shape = nullptr); ConvParams GuessBestParams(const DeviceInfo& device_info, const OperationDef& definition, int src_depth, int dst_depth, bool x_kernel_is_1, @@ -200,8 +193,10 @@ class ConvPowerVR : public GPUOperation { const OperationDef& op_def, bool stride_correction, const ConvParams& conv_params); - int4 stride_padding_; - int4 kernel_dilation_; + int4 stride_; + int4 padding_; + int4 kernel_size_; + int4 dilation_; ConvParams conv_params_; }; @@ -236,7 +231,7 @@ void ConvPowerVR::UploadBias(const tflite::gpu::Tensor& bias) { const int float_size = conv_params_.weights_data_type == DataType::FLOAT32 ? sizeof(float) : sizeof(half); - int aligned_channels = AlignByN(bias.shape.v, 4 * conv_params_.block_size.z); + int aligned_channels = AlignByN(bias.shape.v, 4 * conv_params_.block_size.w); desc.size = float_size * aligned_channels; desc.data.resize(desc.size); if (conv_params_.weights_data_type == DataType::FLOAT32) { @@ -256,37 +251,125 @@ void ConvPowerVR::UploadBias(const tflite::gpu::Tensor& bias) { template void ConvPowerVR::UploadWeights(const tflite::gpu::Tensor& weights) { - const int dst_depth = DivideRoundUp(weights.shape.o, 4); - const int src_depth = DivideRoundUp(weights.shape.i, 4); + const int dst_slices = + AlignByN(DivideRoundUp(weights.shape.o, 4), conv_params_.block_size.w); + const int src_slices = DivideRoundUp(weights.shape.i, 4); const bool f32_weights = conv_params_.weights_data_type == DataType::FLOAT32; const int float4_size = f32_weights ? sizeof(float4) : sizeof(half4); - const int dst_depth_aligned = AlignByN(dst_depth, conv_params_.block_size.z); const int elements_count = - weights.shape.h * weights.shape.w * src_depth * dst_depth_aligned * 4; + weights.shape.h * weights.shape.w * src_slices * dst_slices * 4; - BufferDescriptor desc; - desc.element_type = conv_params_.weights_data_type; - desc.element_size = 4; - desc.memory_type = conv_params_.weights_upload_type == - ConvPowerVR::WeightsUploadType::CONSTANT_MEM - ? 
MemoryType::CONSTANT - : MemoryType::GLOBAL; - desc.size = float4_size * elements_count; - desc.data.resize(desc.size); + std::vector data(float4_size * elements_count); if (f32_weights) { - float4* ptr = reinterpret_cast(desc.data.data()); - RearrangeWeightsToOHWIOGroupI4O4(weights, conv_params_.block_size.z, - absl::MakeSpan(ptr, elements_count)); + float4* ptr = reinterpret_cast(data.data()); + if (conv_params_.AreWeightsBuffer()) { + RearrangeWeightsToOHWIOGroupI4O4(weights, conv_params_.block_size.w, + absl::MakeSpan(ptr, elements_count)); + } else { + RearrangeWeightsToI4HWIOOGroupO4(weights, conv_params_.block_size.w, + absl::MakeSpan(ptr, elements_count)); + } } else { - half4* ptr = reinterpret_cast(desc.data.data()); - RearrangeWeightsToOHWIOGroupI4O4(weights, conv_params_.block_size.z, - absl::MakeSpan(ptr, elements_count)); + half4* ptr = reinterpret_cast(data.data()); + if (conv_params_.AreWeightsBuffer()) { + RearrangeWeightsToOHWIOGroupI4O4(weights, conv_params_.block_size.w, + absl::MakeSpan(ptr, elements_count)); + } else { + RearrangeWeightsToI4HWIOOGroupO4(weights, conv_params_.block_size.w, + absl::MakeSpan(ptr, elements_count)); + } + } + if (conv_params_.AreWeightsBuffer()) { + BufferDescriptor desc; + desc.element_type = conv_params_.weights_data_type; + desc.element_size = 4; + desc.memory_type = conv_params_.weights_upload_type == + ConvPowerVR::WeightsUploadType::CONSTANT_MEM + ? MemoryType::CONSTANT + : MemoryType::GLOBAL; + desc.size = float4_size * elements_count; + desc.data = std::move(data); + args_.AddObject("weights", + absl::make_unique(std::move(desc))); + } else { + const int texture_width = dst_slices; + const int texture_height = src_slices * weights.shape.h * weights.shape.w; + const int sub_size = float4_size * texture_width * texture_height; + for (int i = 0; i < 4; ++i) { + Texture2DDescriptor desc; + desc.element_type = conv_params_.weights_data_type; + desc.size = int2(texture_width, texture_height); + desc.data.resize(sub_size); + std::memcpy(desc.data.data(), data.data() + sub_size * i, sub_size); + const std::string name = "weights" + std::to_string(i); + args_.AddObject(name, + absl::make_unique(std::move(desc))); + } + } +} + +template +void ConvPowerVR::UploadWeights(const tflite::gpu::Tensor& weights) { + const int block_size = conv_params_.block_size.w; + const int dst_slices = + AlignByN(DivideRoundUp(weights.shape.o, 4), block_size); + const int src_slices = DivideRoundUp(weights.shape.i, 4); + + const int elements_count = weights.shape.d * weights.shape.h * + weights.shape.w * src_slices * dst_slices * 4; + const bool f32_weights = definition_.precision == CalculationsPrecision::F32; + + const int float4_size = f32_weights ? 
16 : 8; + + std::vector data(float4_size * elements_count); + + if (f32_weights) { + float4* ptr = reinterpret_cast(data.data()); + if (conv_params_.AreWeightsBuffer()) { + RearrangeWeightsToODHWIOGroupI4O4(weights, conv_params_.block_size.w, + absl::MakeSpan(ptr, elements_count)); + } else { + RearrangeWeightsToI4DHWIOOGroupO4(weights, conv_params_.block_size.w, + absl::MakeSpan(ptr, elements_count)); + } + } else { + half4* ptr = reinterpret_cast(data.data()); + if (conv_params_.AreWeightsBuffer()) { + RearrangeWeightsToODHWIOGroupI4O4(weights, conv_params_.block_size.w, + absl::MakeSpan(ptr, elements_count)); + } else { + RearrangeWeightsToI4DHWIOOGroupO4(weights, conv_params_.block_size.w, + absl::MakeSpan(ptr, elements_count)); + } + } + + if (conv_params_.AreWeightsBuffer()) { + BufferDescriptor desc; + desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16; + desc.element_size = 4; + desc.size = float4_size * elements_count; + desc.data = std::move(data); + args_.AddObject("weights", + absl::make_unique(std::move(desc))); + } else { + const int texture_width = dst_slices; + const int texture_height = + src_slices * weights.shape.d * weights.shape.h * weights.shape.w; + int sub_size = float4_size * texture_width * texture_height; + for (int i = 0; i < 4; ++i) { + Texture2DDescriptor desc; + desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16; + desc.size = int2(texture_width, texture_height); + desc.data.resize(sub_size); + memcpy(desc.data.data(), data.data() + sub_size * i, sub_size); + const std::string name = "weights" + std::to_string(i); + args_.AddObject(name, + absl::make_unique(std::move(desc))); + } } - args_.AddObject("weights", - absl::make_unique(std::move(desc))); } ConvPowerVR CreateConvPowerVR(const DeviceInfo& device_info, @@ -310,6 +393,11 @@ ConvPowerVR CreateConvPowerVRWino4x4To6x6(const DeviceInfo& device_info, const Convolution2DAttributes& attr, const BHWC* dst_shape = nullptr); +ConvPowerVR CreateConvPowerVR3D(const DeviceInfo& device_info, + const OperationDef& definition, + const Convolution3DAttributes& attr, + const BHWDC* dst_shape = nullptr); + } // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.cc b/tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.cc deleted file mode 100644 index bff328772d7..00000000000 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.cc +++ /dev/null @@ -1,461 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.h" - -#include -#include -#include - -#include "absl/strings/substitute.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" -#include "tensorflow/lite/delegates/gpu/cl/linear_storage.h" -#include "tensorflow/lite/delegates/gpu/cl/precision.h" -#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h" - -namespace tflite { -namespace gpu { -namespace cl { -namespace { -bool UseFP16SIMD(const DeviceInfo& device_info, CalculationsPrecision precision, - bool kernel1x1) { - if (!device_info.IsAdreno()) { - return false; - } - switch (precision) { - case CalculationsPrecision::F32: - case CalculationsPrecision::F32_F16: - return false; - case CalculationsPrecision::F16: - return device_info.IsAdreno3xx() && kernel1x1; - } -} -} // namespace - -ConvTexture::ConvTexture(const OperationDef& definition, - const Convolution2DAttributes& attr) - : GPUOperation(definition), - kernel_size_(attr.weights.shape.w, attr.weights.shape.h), - stride_(attr.strides.w, attr.strides.h), - padding_(-attr.padding.prepended.w, -attr.padding.prepended.h), - dilation_(attr.dilations.w, attr.dilations.h), - different_weights_for_height_(false), - block_size_(2, 2, 2) { - work_group_size_ = int3(4, 4, 2); -} - -ConvTexture::ConvTexture(const OperationDef& definition) - : GPUOperation(definition), - kernel_size_(1, 1), - stride_(1, 1), - padding_(0, 0), - dilation_(1, 1), - different_weights_for_height_(false), - block_size_(4, 1, 2) { - work_group_size_ = int3(16, 1, 2); -} - -ConvTexture::ConvTexture(ConvTexture&& operation) - : GPUOperation(std::move(operation)), - kernel_size_(operation.kernel_size_), - stride_(operation.stride_), - padding_(operation.padding_), - dilation_(operation.dilation_), - different_weights_for_height_(operation.different_weights_for_height_), - block_size_(operation.block_size_) {} - -ConvTexture& ConvTexture::operator=(ConvTexture&& operation) { - if (this != &operation) { - std::swap(kernel_size_, operation.kernel_size_); - std::swap(stride_, operation.stride_); - std::swap(padding_, operation.padding_); - std::swap(dilation_, operation.dilation_); - std::swap(different_weights_for_height_, - operation.different_weights_for_height_); - std::swap(block_size_, operation.block_size_); - GPUOperation::operator=(std::move(operation)); - } - return *this; -} - -std::string ConvTexture::GenerateConvCode(const OperationDef& op_def, - const int3& block_size, bool is1x1, - bool adreno4xx_optimization, - bool stride_correction, - bool different_weights_for_height) { - auto src_desc = op_def.src_tensors[0]; - src_desc.SetTextureAddressMode(TextureAddressMode::ZERO); - if (op_def.IsBatchSupported()) { - src_desc.SetStateVar("BatchedWidth", "true"); - } - AddSrcTensor("src_tensor", src_desc); - - auto dst_desc = op_def.dst_tensors[0]; - if (op_def.IsBatchSupported()) { - dst_desc.SetStateVar("BatchedWidth", "true"); - } - AddDstTensor("dst_tensor", dst_desc); - - if (!is1x1) { - args_.AddInt("kernel_size_x"); - args_.AddInt("kernel_size_y"); - args_.AddInt("dilation_x"); - args_.AddInt("dilation_y"); - } - args_.AddInt("stride_x"); - args_.AddInt("stride_y"); - args_.AddInt("padding_x"); - args_.AddInt("padding_y"); - - const auto src_tensor_type = op_def.src_tensors[0].storage_type; - const bool is_buffer = src_tensor_type == TensorStorageType::IMAGE_BUFFER || - 
src_tensor_type == TensorStorageType::BUFFER; - - std::vector xs(block_size.x); - for (int x = 0; x < block_size.x; ++x) { - xs[x] = std::to_string(x); - } - - std::vector ys(block_size.y); - for (int y = 0; y < block_size.y; ++y) { - ys[y] = std::to_string(y); - } - - std::vector zs(block_size.z); - for (int z = 0; z < block_size.z; ++z) { - zs[z] = std::to_string(z); - } - - std::string c = GetCommonDefines(op_def.precision); - for (int z = 0; z < block_size.z; ++z) { - const std::string f0 = std::to_string(z * 4 + 0); - const std::string f1 = std::to_string(z * 4 + 1); - const std::string f2 = std::to_string(z * 4 + 2); - const std::string f3 = std::to_string(z * 4 + 3); - switch (op_def.precision) { - case CalculationsPrecision::F32: - case CalculationsPrecision::F16: - c += "#define CONV" + zs[z] + "(R, S) \\\n"; - c += "R += S.x * f" + f0 + "; \\\n"; - c += "R += S.y * f" + f1 + "; \\\n"; - c += "R += S.z * f" + f2 + "; \\\n"; - c += "R += S.w * f" + f3 + "; \n"; - break; - case CalculationsPrecision::F32_F16: - c += "#define CONV" + zs[z] + "(R, S) \\\n"; - c += "R += convert_float4(S.x * f" + f0 + " + S.y * f" + f1 + - " + S.z * f" + f2 + " + S.w * f" + f3 + ");\n"; - break; - } - } - - c += "__kernel void main_function(\n"; - c += "$0) {\n"; - c += " int X = get_global_id(0) * " + std::to_string(block_size.x) + ";\n"; - c += " int Y = get_global_id(1) * " + std::to_string(block_size.y) + ";\n"; - c += " int Z = get_global_id(2) * " + std::to_string(block_size.z) + ";\n"; - c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() " - "|| Z >= args.dst_tensor.Slices()) return;\n"; - std::vector s_x(block_size.x); - std::vector s_y(block_size.y); - for (int x = 0; x < block_size.x; ++x) { - if (stride_correction) { - c += " int xc" + xs[x] + " = " + - GetXStrideCorrected("X + " + xs[x], "args.src_tensor.Batch()", - "args.stride_x", "args.padding_x") + - ";\n"; - } else { - c += " int xc" + xs[x] + " = (X +" + xs[x] + - ") * args.stride_x + args.padding_x;\n"; - } - s_x[x] = is1x1 ? "xc" + xs[x] : "cx" + xs[x]; - } - for (int y = 0; y < block_size.y; ++y) { - c += " int yc" + ys[y] + " = (Y +" + ys[y] + - ") * args.stride_y + args.padding_y;\n"; - s_y[y] = is1x1 ? "yc" + ys[y] : "cy" + ys[y]; - } - for (int i = 0; i < block_size.x * block_size.y * block_size.z; ++i) { - c += " ACCUM_FLT4 r" + std::to_string(i) + - " = (ACCUM_FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n"; - } - std::string f_y = is1x1 ? 
"s" : "filter_offset"; - if (different_weights_for_height) { - f_y = "Y * args.src_tensor.Slices() + s"; - } - if (!is1x1) { - for (int x = 0; x < block_size.x; ++x) { - c += " int cx" + xs[x] + ";\n"; - } - for (int y = 0; y < block_size.y; ++y) { - c += " int cy" + ys[y] + ";\n"; - } - c += " int filter_offset = 0;\n"; - c += " for (int y = 0; y < args.kernel_size_y; ++y) {\n"; - for (int y = 0; y < block_size.y; ++y) { - c += " cy" + ys[y] + " = y * args.dilation_y + yc" + ys[y] + ";\n"; - } - if (is_buffer) { - for (int y = 0; y < block_size.y; ++y) { - c += " bool in_y" + ys[y] + " = cy" + ys[y] + " >= 0 && cy" + ys[y] + - " < args.src_tensor.Height();\n"; - if (src_tensor_type == TensorStorageType::BUFFER) { - c += " cy" + ys[y] + " = clamp(cy" + ys[y] + - ", 0, args.src_tensor.Height() - 1);\n"; - } - } - } - c += " for (int x = 0; x < args.kernel_size_x; ++x) {\n"; - for (int x = 0; x < block_size.x; ++x) { - c += " cx" + xs[x] + " = x * args.dilation_x + xc" + xs[x] + ";\n"; - } - if (is_buffer) { - for (int x = 0; x < block_size.x; ++x) { - c += " bool in_x" + xs[x] + " = cx" + xs[x] + " >= 0 && cx" + xs[x] + - " < args.src_tensor.Width();\n"; - if (src_tensor_type == TensorStorageType::BUFFER) { - c += " cx" + xs[x] + " = clamp(cx" + xs[x] + - ", 0, args.src_tensor.Width() - 1);\n"; - } - } - for (int x = 0; x < block_size.x; ++x) { - for (int y = 0; y < block_size.y; ++y) { - const std::string id = std::to_string(y * block_size.x + x); - if (src_tensor_type == TensorStorageType::IMAGE_BUFFER) { - c += absl::Substitute( - " int addr_$0 = select(-1, cy$2 * args.src_tensor.Width() + " - "cx$1, (in_x$1 " - "&& " - "in_y$2));\n", - y * block_size.x + x, x, y); - c += absl::Substitute( - " int dz_$0 = select(0, args.src_tensor.Width() * " - "args.src_tensor.Height(), (in_x$1 && " - "in_y$2));\n", - y * block_size.x + x, x, y); - } else { - c += absl::Substitute( - " int addr_$0 = cy$2 * args.src_tensor.Width() + cx$1;\n", - y * block_size.x + x, x, y); - } - } - } - if (src_tensor_type == TensorStorageType::BUFFER) { - c += " int dz = args.src_tensor.Width() * args.src_tensor.Height();\n"; - } - } - } else if (is_buffer) { - for (int y = 0; y < block_size.y; ++y) { - c += " bool in_y" + ys[y] + " = yc" + ys[y] + " >= 0 && yc" + ys[y] + - " < args.src_tensor.Height();\n"; - } - for (int x = 0; x < block_size.x; ++x) { - c += " bool in_x" + xs[x] + " = xc" + xs[x] + " >= 0 && xc" + xs[x] + - " < args.src_tensor.Width();\n"; - } - for (int x = 0; x < block_size.x; ++x) { - for (int y = 0; y < block_size.y; ++y) { - const std::string id = std::to_string(y * block_size.x + x); - if (src_tensor_type == TensorStorageType::IMAGE_BUFFER) { - c += absl::Substitute( - " int addr_$0 = select(-1, yc$2 * args.src_tensor.Width() + " - "xc$1, (in_x$1 && " - "in_y$2));\n", - y * block_size.x + x, x, y); - c += absl::Substitute( - " int dz_$0 = select(0, args.src_tensor.Width() * " - "args.src_tensor.Height(), (in_x$1 && " - "in_y$2));\n", - y * block_size.x + x, x, y); - } else { - c += absl::Substitute( - " int addr_$0 = yc$2 * args.src_tensor.Width() + xc$1;\n", - y * block_size.x + x, x, y); - } - } - } - if (src_tensor_type == TensorStorageType::BUFFER) { - c += " int dz = args.src_tensor.Width() * args.src_tensor.Height();\n"; - } - } - c += " for (int s = 0; s < args.src_tensor.Slices(); ++s) {\n"; - if (is_buffer) { - if (src_tensor_type == TensorStorageType::IMAGE_BUFFER) { - for (int index = 0; index < block_size.x * block_size.y; ++index) { - const std::string id = std::to_string(index); - 
c += - " FLT4 src" + id + " = args.src_tensor.Read(addr_" + id + ");\n"; - } - } else { - for (int x = 0; x < block_size.x; ++x) { - for (int y = 0; y < block_size.y; ++y) { - const std::string id = std::to_string(y * block_size.x + x); - c += " FLT4 src" + id + " = args.src_tensor.Read(addr_" + id + - ") * (FLT)(in_x" + xs[x] + " && in_y" + ys[y] + "); addr_" + id + - " += dz;\n"; - } - } - } - } - for (int z = 0; z < block_size.z; ++z) { - c += absl::Substitute(R"( FLT4 f$2 = args.weights0.Read($0, $1); - FLT4 f$3 = args.weights1.Read($0, $1); - FLT4 f$4 = args.weights2.Read($0, $1); - FLT4 f$5 = args.weights3.Read($0, $1); -)", - "Z + " + zs[z], f_y, z * 4 + 0, z * 4 + 1, z * 4 + 2, - z * 4 + 3); - } - if (!is_buffer) { - for (int x = 0; x < block_size.x; ++x) { - for (int y = 0; y < block_size.y; ++y) { - const std::string id = std::to_string(y * block_size.x + x); - c += " FLT4 src" + id + " = args.src_tensor.Read(" + s_x[x] + ", " + - s_y[y] + ", s);\n"; - } - } - } - for (int z = 0; z < block_size.z; ++z) { - for (int i = 0; i < block_size.x * block_size.y; ++i) { - c += " CONV" + zs[z] + "(r" + - std::to_string(i + z * block_size.x * block_size.y) + ", src" + - std::to_string(i) + ");\n"; - } - } - if (!is1x1) { - c += " filter_offset++;\n"; - } - if (is_buffer) { - if (src_tensor_type == TensorStorageType::IMAGE_BUFFER) { - for (int index = 0; index < block_size.x * block_size.y; ++index) { - const std::string id = std::to_string(index); - c += " addr_" + id + " += dz_" + id + ";\n"; - } - } - } - c += " }\n"; // args.src_tensor.Slices() - if (!is1x1) { - c += " }\n"; // kernel_size_x - c += " }\n"; // kernel_size_y - } - // when is1x1 && adreno4xx_optimization is true, xc0 == X and yc0 == Y - std::string dst_x = is1x1 && adreno4xx_optimization ? "xc0" : "X"; - std::string dst_y = is1x1 && adreno4xx_optimization ? 
"yc0" : "Y"; - for (int z = 0; z < block_size.z; ++z) { - c += " if (Z < args.dst_tensor.Slices()) {\n"; - c += " FLT4 bias_val = args.biases.Read(Z);\n"; - for (int y = 0; y < block_size.y; ++y) { - for (int x = 0; x < block_size.x; ++x) { - const std::string id = - std::to_string((z * block_size.y + y) * block_size.x + x); - c += " {\n"; - c += " int xc = " + dst_x + " + " + xs[x] + ";\n"; - c += " int yc = " + dst_y + " + " + ys[y] + ";\n"; - c += " if (xc < args.dst_tensor.Width() && yc < " - "args.dst_tensor.Height()) {\n"; - c += " FLT4 res = TO_FLT4(r" + id + ") + bias_val;\n"; - c += " args.dst_tensor.Write(res, xc, yc, Z);\n"; - c += " }\n"; - c += " }\n"; - } - } - c += " }\n"; - c += " Z++;\n"; - } - c += "}\n"; - return c; -} - -void ConvTexture::GenerateCode(const DeviceInfo& device_info) { - auto storage_type = definition_.GetPrimaryStorageType(); - bool is1x1 = kernel_size_.x == 1 && kernel_size_.y == 1; - bool adreno4xx_optimization = - stride_.x == 1 && stride_.y == 1 && padding_.x == 0 && padding_.y == 0 && - device_info.IsAdreno4xx() && - storage_type == TensorStorageType::TEXTURE_ARRAY && - definition_.precision == CalculationsPrecision::F16; - const bool stride_correction = - definition_.IsBatchSupported() && stride_.x != 1; - code_ = - GenerateConvCode(definition_, block_size_, is1x1, adreno4xx_optimization, - stride_correction, different_weights_for_height_); - - if (UseFP16SIMD(device_info, definition_.precision, is1x1)) { - compiler_options_.push_back(CompilerOptions::ADRENO_FULL_SIMD_LINE); - } -} - -absl::Status ConvTexture::BindArguments() { - if (!(kernel_size_.x == 1 && kernel_size_.y == 1)) { - RETURN_IF_ERROR(args_.SetInt("kernel_size_x", kernel_size_.x)); - RETURN_IF_ERROR(args_.SetInt("kernel_size_y", kernel_size_.y)); - RETURN_IF_ERROR(args_.SetInt("dilation_x", dilation_.x * src_[0]->Batch())); - RETURN_IF_ERROR(args_.SetInt("dilation_y", dilation_.y)); - } - RETURN_IF_ERROR(args_.SetInt("stride_x", stride_.x)); - RETURN_IF_ERROR(args_.SetInt("stride_y", stride_.y)); - RETURN_IF_ERROR(args_.SetInt("padding_x", padding_.x * src_[0]->Batch())); - RETURN_IF_ERROR(args_.SetInt("padding_y", padding_.y)); - return absl::OkStatus(); -} - -int3 ConvTexture::GetGridSize() const { - const int grid_x = - DivideRoundUp(dst_[0]->Width() * dst_[0]->Batch(), block_size_.x); - const int grid_y = DivideRoundUp(dst_[0]->Height(), block_size_.y); - const int grid_z = DivideRoundUp(dst_[0]->Slices(), block_size_.z); - return int3(grid_x, grid_y, grid_z); -} - -void ConvTexture::GetPossibleKernelWorkGroups( - TuningType tuning_type, const DeviceInfo& device_info, - const KernelInfo& kernel_info, std::vector* work_groups) const { - GetPossibleWorkGroupsConv(tuning_type, device_info, kernel_info, grid_size_, - work_groups); -} - -ConvTexture CreateConvTexture(const DeviceInfo& device_info, - const OperationDef& definition, - const Convolution2DAttributes& attr) { - ConvTexture result(definition, attr); - result.GenerateCode(device_info); - result.UploadData(attr.weights, attr.bias); - return result; -} - -ConvTexture CreateConvTexture(const DeviceInfo& device_info, - const OperationDef& definition, - const FullyConnectedAttributes& attr) { - ConvTexture result(definition); - result.GenerateCode(device_info); - result.UploadData(attr.weights, attr.bias); - return result; -} - -ConvTexture CreateConvTextureWino4x4To6x6(const DeviceInfo& device_info, - const OperationDef& definition, - const Convolution2DAttributes& attr) { - ConvTexture result(definition); - 
result.different_weights_for_height_ = true; - result.block_size_ = {4, 1, 2}; - result.GenerateCode(device_info); - result.UploadDataForWinograd4x4To6x6(attr.weights); - return result; -} - -} // namespace cl -} // namespace gpu -} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.h b/tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.h deleted file mode 100644 index 3ebd43bf32b..00000000000 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.h +++ /dev/null @@ -1,269 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_CONV_TEXTURE_H_ -#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_CONV_TEXTURE_H_ - -#include - -#include "tensorflow/lite/delegates/gpu/cl/cl_command_queue.h" -#include "tensorflow/lite/delegates/gpu/cl/cl_context.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/linear_storage.h" -#include "tensorflow/lite/delegates/gpu/cl/tensor.h" -#include "tensorflow/lite/delegates/gpu/cl/texture2d.h" -#include "tensorflow/lite/delegates/gpu/cl/util.h" -#include "tensorflow/lite/delegates/gpu/common/data_type.h" -#include "tensorflow/lite/delegates/gpu/common/operations.h" -#include "tensorflow/lite/delegates/gpu/common/shape.h" -#include "tensorflow/lite/delegates/gpu/common/status.h" -#include "tensorflow/lite/delegates/gpu/common/tensor.h" -#include "tensorflow/lite/delegates/gpu/common/types.h" -#include "tensorflow/lite/delegates/gpu/common/winograd_util.h" - -namespace tflite { -namespace gpu { -namespace cl { - -// This convolution process BLOCK_SIZE(XxYxZ) of FLT4 values per thread. 
-class ConvTexture : public GPUOperation { - public: - ConvTexture() = default; - void GetPossibleKernelWorkGroups( - TuningType tuning_type, const DeviceInfo& device_info, - const KernelInfo& kernel_info, - std::vector* work_groups) const override; - absl::Status BindArguments() override; - int3 GetGridSize() const override; - - // Move only - ConvTexture(ConvTexture&& operation); - ConvTexture& operator=(ConvTexture&& operation); - ConvTexture(const ConvTexture&) = delete; - ConvTexture& operator=(const ConvTexture&) = delete; - - private: - friend ConvTexture CreateConvTexture(const DeviceInfo& device_info, - const OperationDef& definition, - const Convolution2DAttributes& attr); - friend ConvTexture CreateConvTexture(const DeviceInfo& device_info, - const OperationDef& definition, - const FullyConnectedAttributes& attr); - - friend ConvTexture CreateConvTextureWino4x4To6x6( - const DeviceInfo& device_info, const OperationDef& definition, - const Convolution2DAttributes& attr); - - ConvTexture(const OperationDef& definition, - const Convolution2DAttributes& attr); - explicit ConvTexture(const OperationDef& definition); - template - void UploadData(const tflite::gpu::Tensor& weights, - const tflite::gpu::Tensor& biases); - - template - void UploadDataForWinograd4x4To6x6( - const tflite::gpu::Tensor& weights); - - template - void UploadWeights(const tflite::gpu::Tensor& weights); - - template - void RearrangeWeightsData(const tflite::gpu::Tensor& weights, - absl::Span dst_0, absl::Span dst_1, - absl::Span dst_2, absl::Span dst_3); - - void GenerateCode(const DeviceInfo& device_info); - - std::string GenerateConvCode(const OperationDef& op_def, - const int3& block_size, bool is1x1, - bool adreno4xx_optimization, - bool stride_correction, - bool different_weights_for_height); - - int2 kernel_size_; - int2 stride_; - int2 padding_; - int2 dilation_; - - // By default in 2d convolution we have the same weights for WH dims, but in - // some cases we need separate weights for H dimension and convolution kernel - // requires very small modifications to support it. 
- bool different_weights_for_height_; - - int3 block_size_ = int3(2, 2, 2); -}; - -template -void ConvTexture::UploadData(const tflite::gpu::Tensor& weights, - const tflite::gpu::Tensor& biases) { - UploadWeights(weights); - - TensorLinearDescriptor desc; - desc.storage_type = LinearStorageType::TEXTURE_2D; - desc.element_type = definition_.GetDataType(); - desc.UploadLinearData(biases); - args_.AddObject("biases", - absl::make_unique(std::move(desc))); -} - -template -void ConvTexture::UploadDataForWinograd4x4To6x6( - const tflite::gpu::Tensor& weights) { - tflite::gpu::Tensor wino_weights; - RearrangeWeightsToWinograd4x4To6x6Weights(weights, &wino_weights); - UploadWeights(wino_weights); - - tflite::gpu::Tensor bias; - bias.shape = Linear(1); - bias.data = {0.0f}; - TensorLinearDescriptor desc; - desc.storage_type = LinearStorageType::TEXTURE_2D; - desc.element_type = definition_.GetDataType(); - desc.UploadLinearData(bias); - args_.AddObject("biases", - absl::make_unique(std::move(desc))); -} - -template -void ConvTexture::UploadWeights(const tflite::gpu::Tensor& weights) { - int dst_depth = DivideRoundUp(weights.shape.o, 4); - dst_depth = AlignByN(dst_depth, block_size_.z); - const int src_depth = DivideRoundUp(weights.shape.i, 4); - const int kernel_x = weights.shape.w; - const int kernel_y = weights.shape.h; - - int texture_width = dst_depth; - int texture_height = src_depth * kernel_x * kernel_y; - - const bool f32_weights = definition_.precision == CalculationsPrecision::F32; - DataType data_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16; - - const int elements_count = texture_width * texture_height; - const int float4_size = f32_weights ? sizeof(float4) : sizeof(half4); - - Texture2DDescriptor desc0; - desc0.element_type = data_type; - desc0.size = int2(texture_width, texture_height); - desc0.data.resize(elements_count * float4_size); - - Texture2DDescriptor desc1; - desc1.element_type = data_type; - desc1.size = int2(texture_width, texture_height); - desc1.data.resize(elements_count * float4_size); - - Texture2DDescriptor desc2; - desc2.element_type = data_type; - desc2.size = int2(texture_width, texture_height); - desc2.data.resize(elements_count * float4_size); - - Texture2DDescriptor desc3; - desc3.element_type = data_type; - desc3.size = int2(texture_width, texture_height); - desc3.data.resize(elements_count * float4_size); - - if (f32_weights) { - float4* ptr0 = reinterpret_cast(desc0.data.data()); - float4* ptr1 = reinterpret_cast(desc1.data.data()); - float4* ptr2 = reinterpret_cast(desc2.data.data()); - float4* ptr3 = reinterpret_cast(desc3.data.data()); - RearrangeWeightsData(weights, absl::MakeSpan(ptr0, elements_count), - absl::MakeSpan(ptr1, elements_count), - absl::MakeSpan(ptr2, elements_count), - absl::MakeSpan(ptr3, elements_count)); - } else { - half4* ptr0 = reinterpret_cast(desc0.data.data()); - half4* ptr1 = reinterpret_cast(desc1.data.data()); - half4* ptr2 = reinterpret_cast(desc2.data.data()); - half4* ptr3 = reinterpret_cast(desc3.data.data()); - RearrangeWeightsData(weights, absl::MakeSpan(ptr0, elements_count), - absl::MakeSpan(ptr1, elements_count), - absl::MakeSpan(ptr2, elements_count), - absl::MakeSpan(ptr3, elements_count)); - } - - args_.AddObject("weights0", - absl::make_unique(std::move(desc0))); - args_.AddObject("weights1", - absl::make_unique(std::move(desc1))); - args_.AddObject("weights2", - absl::make_unique(std::move(desc2))); - args_.AddObject("weights3", - absl::make_unique(std::move(desc3))); -} - -template -void 
ConvTexture::RearrangeWeightsData( - const tflite::gpu::Tensor& weights, absl::Span dst_0, - absl::Span dst_1, absl::Span dst_2, absl::Span dst_3) { - int dst_depth = DivideRoundUp(weights.shape.o, 4); - dst_depth = AlignByN(dst_depth, block_size_.z); - const int src_depth = DivideRoundUp(weights.shape.i, 4); - const int kernel_x = weights.shape.w; - const int kernel_y = weights.shape.h; - - int texture_width = dst_depth; - - for (int d = 0; d < dst_depth / block_size_.z; ++d) { - for (int y = 0; y < kernel_y; ++y) { - for (int x = 0; x < kernel_x; ++x) { - for (int s = 0; s < src_depth; ++s) { - for (int sub_d = 0; sub_d < block_size_.z; ++sub_d) { - T filters[4]; - for (int i = 0; i < 4; ++i) { - for (int j = 0; j < 4; ++j) { - const int s_ch = s * 4 + j; - const int d_ch = (d * block_size_.z + sub_d) * 4 + i; - if (s_ch < weights.shape.i && d_ch < weights.shape.o) { - const int f_index = - weights.shape.LinearIndex({d_ch, y, x, s_ch}); - filters[j][i] = weights.data[f_index]; - } else { - filters[j][i] = 0.0f; - } - } - } - int x_coord = d * block_size_.z + sub_d; - int y_coord = (y * kernel_x + x) * src_depth + s; - int offset = y_coord * texture_width + x_coord; - dst_0[offset] = filters[0]; - dst_1[offset] = filters[1]; - dst_2[offset] = filters[2]; - dst_3[offset] = filters[3]; - } - } - } - } - } -} - -ConvTexture CreateConvTexture(const DeviceInfo& device_info, - const OperationDef& definition, - const Convolution2DAttributes& attr); - -ConvTexture CreateConvTexture(const DeviceInfo& device_info, - const OperationDef& definition, - const FullyConnectedAttributes& attr); - -ConvTexture CreateConvTextureWino4x4To6x6(const DeviceInfo& device_info, - const OperationDef& definition, - const Convolution2DAttributes& attr); - -} // namespace cl -} // namespace gpu -} // namespace tflite - -#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_CONV_TEXTURE_H_ diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_texture_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/conv_texture_test.cc deleted file mode 100644 index 2a92573b689..00000000000 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_texture_test.cc +++ /dev/null @@ -1,107 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.h" - -#include - -#include -#include -#include "tensorflow/lite/delegates/gpu/cl/kernels/cl_test.h" -#include "tensorflow/lite/delegates/gpu/common/operations.h" -#include "tensorflow/lite/delegates/gpu/common/status.h" - -using ::testing::FloatNear; -using ::testing::Pointwise; - -namespace tflite { -namespace gpu { -namespace cl { -namespace { - -TEST_F(OpenCLOperationTest, ConvTextureSimpleWeights) { - TensorFloat32 src_tensor; - src_tensor.shape = BHWC(1, 2, 2, 2); - src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; - - Convolution2DAttributes attr; - attr.padding.prepended = HW(0, 0); - attr.padding.appended = HW(1, 1); - attr.strides = HW(1, 1); - attr.dilations = HW(1, 1); - attr.weights.shape = OHWI(1, 2, 2, 2); - attr.weights.data = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; - attr.bias.shape = Linear(1); - attr.bias.data = {0.0f}; - - for (auto storage : env_.GetSupportedStorages()) { - for (auto precision : env_.GetSupportedPrecisions()) { - const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f; - OperationDef op_def; - op_def.precision = precision; - auto data_type = DeduceDataTypeFromPrecision(precision); - op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); - op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); - TensorFloat32 dst_tensor; - ConvTexture operation = - CreateConvTexture(creation_context_.GetDeviceInfo(), op_def, attr); - ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation, - BHWC(1, 2, 2, 1), &dst_tensor)); - EXPECT_THAT(dst_tensor.data, - Pointwise(FloatNear(eps), {28.0f, 18.0f, 22.0f, 13.0f})); - } - } -} - -TEST_F(OpenCLOperationTest, ConvTexture) { - TensorFloat32 src_tensor; - src_tensor.shape = BHWC(1, 2, 2, 2); - src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; - - Convolution2DAttributes attr; - attr.padding.prepended = HW(0, 0); - attr.padding.appended = HW(1, 1); - attr.strides = HW(1, 1); - attr.dilations = HW(1, 1); - attr.weights.shape = OHWI(2, 2, 2, 2); - attr.weights.data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, - 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f}; - attr.bias.shape = Linear(2); - attr.bias.data = {0.5f, -0.5f}; - - for (auto storage : env_.GetSupportedStorages()) { - for (auto precision : env_.GetSupportedPrecisions()) { - const float eps = precision == CalculationsPrecision::F32 ? 
1e-6f : 1e-3f; - OperationDef op_def; - op_def.precision = precision; - auto data_type = DeduceDataTypeFromPrecision(precision); - op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); - op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); - TensorFloat32 dst_tensor; - ConvTexture operation = - CreateConvTexture(creation_context_.GetDeviceInfo(), op_def, attr); - ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation, - BHWC(1, 2, 2, 2), &dst_tensor)); - EXPECT_THAT(dst_tensor.data, - Pointwise(FloatNear(eps), {168.5f, 391.5f, 80.5f, 223.5f, - 60.5f, 235.5f, 20.5f, 123.5f})); - } - } -} - -} // namespace -} // namespace cl -} // namespace gpu -} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.cc b/tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.cc index d6e17ce2a86..521cbefd885 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.cc @@ -110,12 +110,12 @@ std::string ConverterToConvWeights::GetConverterToConvWeightsCode( return c; } -absl::Status ConverterToConvWeights::BindArguments() { +absl::Status ConverterToConvWeights::BindArguments(ArgumentsBinder* args) { float4 mask = GetMaskForLastPlane(src_[0]->Channels()); - RETURN_IF_ERROR(args_.SetFloat("mask_x", mask.x)); - RETURN_IF_ERROR(args_.SetFloat("mask_y", mask.y)); - RETURN_IF_ERROR(args_.SetFloat("mask_z", mask.z)); - return args_.SetFloat("mask_w", mask.w); + RETURN_IF_ERROR(args->SetFloat("mask_x", mask.x)); + RETURN_IF_ERROR(args->SetFloat("mask_y", mask.y)); + RETURN_IF_ERROR(args->SetFloat("mask_z", mask.z)); + return args->SetFloat("mask_w", mask.w); } int3 ConverterToConvWeights::GetGridSize() const { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.h b/tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.h index fe814d296fa..3c7314ea6c9 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.h @@ -31,7 +31,7 @@ class ConverterToConvWeights : public GPUOperation { public: ConverterToConvWeights(const OperationDef& definition, const ConvWeightsDescription& conv_weights_desc); - absl::Status BindArguments() override; + absl::Status BindArguments(ArgumentsBinder* args) override; int3 GetGridSize() const override; // Move only diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc b/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc index d52efb43a08..77ac946637d 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc @@ -46,9 +46,11 @@ class OpenClConverterImpl : public TensorObjectConverter { RETURN_IF_ERROR(kernel_.SetMemoryAuto(buffer_mem)); RETURN_IF_ERROR(args_.SetObjectRef("tensor", tensor)); RETURN_IF_ERROR(args_.Bind(kernel_.kernel(), kernel_.GetBindingCounter())); - int3 grid = int3(tensor->Width() * tensor->Batch(), tensor->Height(), - tensor->Slices()); - return queue_->DispatchImplicit(kernel_, grid, {16, 8, 1}); + const int3 grid = int3(tensor->Width() * tensor->Batch(), tensor->Height(), + tensor->Slices()); + const int3 work_group_size = {16, 8, 1}; + const int3 work_groups_count = GetWorkGroupsCount(grid, work_group_size); + return queue_->Dispatch(kernel_, work_groups_count, work_group_size); } Arguments args_; @@ -63,47 +65,168 @@ bool IsSupportedDataType(DataType type) { return type == 
DataType::FLOAT16 || type == DataType::FLOAT32; } -// Implements conversion from OpenCL-specific tensor layout to BHWC. -class FromTensorConverter : public OpenClConverterImpl { +bool IsBHWCOpenCLBuffer(const ObjectDef& def) { + return IsSupportedDataType(def.data_type) && + def.object_type == ObjectType::OPENCL_BUFFER && + def.data_layout == DataLayout::BHWC; +} + +bool IsOpenCLTensor(const ObjectDef& def) { + const bool is_buffer_tensor = def.object_type == ObjectType::OPENCL_BUFFER && + def.data_layout == DataLayout::DHWC4; + const bool is_image2d_tensor = + def.object_type == ObjectType::OPENCL_TEXTURE && + def.data_layout == DataLayout::HDWC4; + const bool is_image2d_array_tensor = + def.object_type == ObjectType::OPENCL_TEXTURE && + def.data_layout == DataLayout::DHWC4; + const bool is_single_image_tensor = + def.object_type == ObjectType::OPENCL_TEXTURE && + def.data_layout == DataLayout::BHWC; + return IsSupportedDataType(def.data_type) && + (is_buffer_tensor || is_image2d_tensor || is_image2d_array_tensor || + is_single_image_tensor); +} + +absl::Status GetOpenCLMemory(const TensorObject& obj, cl_mem* memory) { + auto texture = absl::get_if(&obj); + auto buffer = absl::get_if(&obj); + if (texture && texture->memobj) { + *memory = texture->memobj; + } else if (buffer && buffer->memobj) { + *memory = buffer->memobj; + } else { + return absl::InvalidArgumentError("Missing OpenCL object."); + } + return absl::OkStatus(); +} + +// Implements conversion from OpenCL tensor to another OpenCL tensor. +class TensorToTensorConverter : public OpenClConverterImpl { public: static bool IsSupported(const ObjectDef& input, const ObjectDef& output) { - return IsSupportedDataType(input.data_type) && - IsSupportedDataType(output.data_type) && - // Output is always Buffer/(BHWC|DHWC4) - output.object_type == ObjectType::OPENCL_BUFFER && - (output.data_layout == DataLayout::BHWC || - output.data_layout == DataLayout::DHWC4) && - // Texture2D/HDWC4 -> - ((input.object_type == ObjectType::OPENCL_TEXTURE && - input.data_layout == DataLayout::HDWC4) || - // SingleTextureArray/BHWC -> - (input.object_type == ObjectType::OPENCL_TEXTURE && - input.data_layout == DataLayout::BHWC) || - // TextureArray/DHWC4 -> - (input.object_type == ObjectType::OPENCL_TEXTURE && - input.data_layout == DataLayout::DHWC4) || - // Buffer/DHWC4 -> - (input.object_type == ObjectType::OPENCL_BUFFER && - input.data_layout == DataLayout::DHWC4)); + return IsOpenCLTensor(input) && IsOpenCLTensor(output); } - std::pair GetToDhwc4Kernel( - const TensorObjectDef& input_def, - const TensorObjectDef& output_def) const { - return std::make_pair("__global " + - ToCLDataType(output_def.object_def.data_type, 4) + - "* dst", - "dst[(d * args.tensor.Height() + y) * " - "args.tensor.Width() + x] = input;"); + absl::Status Init(const TensorObjectDef& input_def, + const TensorObjectDef& output_def, + Environment* environment) final { + src_tensor_descriptor_.layout = Layout::BHWC; + src_tensor_descriptor_.storage_type = ToTensorStorageType( + input_def.object_def.object_type, input_def.object_def.data_layout); + src_tensor_descriptor_.data_type = input_def.object_def.data_type; + args_.AddObjectRef( + "src_tensor", AccessType::READ, + absl::make_unique(src_tensor_descriptor_)); + + dst_tensor_descriptor_.layout = Layout::BHWC; + dst_tensor_descriptor_.storage_type = ToTensorStorageType( + output_def.object_def.object_type, output_def.object_def.data_layout); + dst_tensor_descriptor_.data_type = output_def.object_def.data_type; + 
args_.AddObjectRef( + "dst_tensor", AccessType::WRITE, + absl::make_unique(dst_tensor_descriptor_)); + + const bool need_fp16_support = + input_def.object_def.data_type == DataType::FLOAT16 || + output_def.object_def.data_type == DataType::FLOAT16; + const std::string out_data_type = + ToCLDataType(output_def.object_def.data_type); + std::string shader_src; + if (need_fp16_support) { + shader_src += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"; + } + shader_src += + R"(__kernel void tensor_to_tensor($0) { + int linear_id = get_global_id(0); + int x = linear_id / args.dst_tensor.Batch(); + int b = linear_id % args.dst_tensor.Batch(); + int y = get_global_id(1); + int d = get_global_id(2); + if (x >= args.dst_tensor.Width() || y >= args.dst_tensor.Height() || d >= args.dst_tensor.Slices()) return; +)"; + shader_src += " " + out_data_type + "4 input = args.src_tensor.Read<" + + out_data_type + ">(x, y, d, b);\n"; + shader_src += " args.dst_tensor.Write(input, x, y, d, b);\n}"; + queue_ = environment->queue(); + context_ = &environment->context(); + shape_ = BHWC(input_def.dimensions.b, input_def.dimensions.h, + input_def.dimensions.w, input_def.dimensions.c); + RETURN_IF_ERROR( + args_.TransformToCLCode(environment->device().info_, {}, &shader_src)); + return environment->program_cache()->GetOrCreateCLKernel( + shader_src, "tensor_to_tensor", environment->context(), + environment->device(), &kernel_); } - std::pair GetToBhwcKernel( - const TensorObjectDef& input_def, - const TensorObjectDef& output_def) const { - return std::make_pair( - "__global " + ToCLDataType(output_def.object_def.data_type) + "* dst", - R"( - int c = d * 4; + absl::Status Convert(const TensorObject& input_obj, + const TensorObject& output_obj) override { + cl_mem in_memory; + RETURN_IF_ERROR(GetOpenCLMemory(input_obj, &in_memory)); + cl_mem out_memory; + RETURN_IF_ERROR(GetOpenCLMemory(output_obj, &out_memory)); + + Tensor src_tensor; + RETURN_IF_ERROR(CreateSharedTensor(*context_, in_memory, shape_, + src_tensor_descriptor_, &src_tensor)); + Tensor dst_tensor; + RETURN_IF_ERROR(CreateSharedTensor(*context_, out_memory, shape_, + dst_tensor_descriptor_, &dst_tensor)); + RETURN_IF_ERROR(args_.SetObjectRef("src_tensor", &src_tensor)); + RETURN_IF_ERROR(args_.SetObjectRef("dst_tensor", &dst_tensor)); + RETURN_IF_ERROR(args_.Bind(kernel_.kernel())); + const int3 grid = int3(dst_tensor.Width() * dst_tensor.Batch(), + dst_tensor.Height(), dst_tensor.Slices()); + const int3 work_group_size = {16, 8, 1}; + const int3 work_groups_count = GetWorkGroupsCount(grid, work_group_size); + return queue_->Dispatch(kernel_, work_groups_count, work_group_size); + } + + private: + TensorDescriptor src_tensor_descriptor_; + TensorDescriptor dst_tensor_descriptor_; +}; + +// Implements conversion from OpenCL-specific tensor layout to BHWC OpenCL +// buffer. 
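+// The generated tensor_to_bhwc kernel reads one FLT4 value per (x, y, slice)
+// position and scatters its channels into the destination BHWC buffer,
+// guarding the tail so channels beyond args.tensor.Channels() are not written.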
+class TensorToBHWCBufferConverter : public OpenClConverterImpl { + public: + static bool IsSupported(const ObjectDef& input, const ObjectDef& output) { + return IsOpenCLTensor(input) && IsBHWCOpenCLBuffer(output); + } + + absl::Status Init(const TensorObjectDef& input_def, + const TensorObjectDef& output_def, + Environment* environment) final { + TensorStorageType src_tensor_type = ToTensorStorageType( + input_def.object_def.object_type, input_def.object_def.data_layout); + tensor_descriptor_.layout = Layout::BHWC; + tensor_descriptor_.storage_type = src_tensor_type; + tensor_descriptor_.data_type = input_def.object_def.data_type; + args_.AddObjectRef("tensor", AccessType::READ, + absl::make_unique(tensor_descriptor_)); + + const bool need_fp16_support = + input_def.object_def.data_type == DataType::FLOAT16 || + output_def.object_def.data_type == DataType::FLOAT16; + std::string shader_src; + if (need_fp16_support) { + shader_src += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"; + } + const std::string out_data_type = + ToCLDataType(output_def.object_def.data_type); + shader_src += "__kernel void tensor_to_bhwc("; + shader_src += "__global " + out_data_type + "* dst, $0) {\n"; + shader_src += R"( int linear_id = get_global_id(0); + int x = linear_id / args.tensor.Batch(); + int b = linear_id % args.tensor.Batch(); + int y = get_global_id(1); + int d = get_global_id(2); + if (x >= args.tensor.Width() || y >= args.tensor.Height() || d >= args.tensor.Slices()) return; +)"; + shader_src += " " + out_data_type + "4 input = args.tensor.Read<" + + out_data_type + ">(x, y, d, b);\n"; + shader_src += R"( int c = d * 4; int index = ((b * args.tensor.Height() + y) * args.tensor.Width() + x) * args.tensor.Channels() + c; dst[index] = input.x; @@ -115,39 +238,8 @@ class FromTensorConverter : public OpenClConverterImpl { } if (c + 3 < args.tensor.Channels()) { dst[index + 3] = input.w; - })"); } - - absl::Status Init(const TensorObjectDef& input_def, - const TensorObjectDef& output_def, - Environment* environment) final { - auto params_kernel = output_def.object_def.data_layout == DataLayout::BHWC - ? 
GetToBhwcKernel(input_def, output_def) - : GetToDhwc4Kernel(input_def, output_def); - - TensorStorageType src_tensor_type = ToTensorStorageType( - input_def.object_def.object_type, input_def.object_def.data_layout); - tensor_descriptor_.layout = Layout::BHWC; - tensor_descriptor_.storage_type = src_tensor_type; - tensor_descriptor_.data_type = input_def.object_def.data_type; - args_.AddObjectRef("tensor", AccessType::READ, - absl::make_unique(tensor_descriptor_)); - std::string shader_src = - R"( -#pragma OPENCL EXTENSION cl_khr_fp16 : enable - -__kernel void from_tensor()" + - params_kernel.first + R"(, $0) { - int linear_id = get_global_id(0); - int x = linear_id / args.tensor.Batch(); - int b = linear_id % args.tensor.Batch(); - int y = get_global_id(1); - int d = get_global_id(2); - if (x >= args.tensor.Width() || y >= args.tensor.Height() || d >= args.tensor.Slices()) return; - )" + ToCLDataType(output_def.object_def.data_type, 4) + - " input = args.tensor.Read<" + - ToCLDataType(output_def.object_def.data_type) + ">(x, y, d, b);\n" + - params_kernel.second + "\n}"; +})"; queue_ = environment->queue(); context_ = &environment->context(); shape_ = BHWC(input_def.dimensions.b, input_def.dimensions.h, @@ -155,7 +247,7 @@ __kernel void from_tensor()" + RETURN_IF_ERROR( args_.TransformToCLCode(environment->device().info_, {}, &shader_src)); return environment->program_cache()->GetOrCreateCLKernel( - shader_src, "from_tensor", environment->context(), + shader_src, "tensor_to_bhwc", environment->context(), environment->device(), &kernel_); } @@ -164,64 +256,24 @@ __kernel void from_tensor()" + auto output = absl::get_if(&output_obj); if (!output || !output->memobj) { return absl::InvalidArgumentError( - "Missing output in from_tensor converter"); - } - cl_mem memory = nullptr; - auto input_texture = absl::get_if(&input_obj); - if (input_texture && input_texture->memobj) { - memory = input_texture->memobj; - } - auto input_buffer = absl::get_if(&input_obj); - if (input_buffer && input_buffer->memobj) { - memory = input_buffer->memobj; - } - if (!memory) { - return absl::InvalidArgumentError( - "Missing input in from_tensor converter"); + "Missing output in tensor_to_bhwc converter"); } + + cl_mem in_memory; + RETURN_IF_ERROR(GetOpenCLMemory(input_obj, &in_memory)); Tensor tensor; - RETURN_IF_ERROR(CreateSharedTensor(*context_, memory, shape_, + RETURN_IF_ERROR(CreateSharedTensor(*context_, in_memory, shape_, tensor_descriptor_, &tensor)); return DispatchKernel(output->memobj, &tensor); } }; -// Implements conversion from BHWC to OpenCL-specific tensor layout. -class ToTensorConverter : public OpenClConverterImpl { +// Implements conversion from BHWC OpenCL buffer to OpenCL-specific tensor +// layout. 
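+// The generated bhwc_to_tensor kernel gathers up to four consecutive channels
+// from the source BHWC buffer into one FLT4 value and writes it to the
+// destination tensor at (x, y, slice, batch); missing tail channels are padded.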
+class BHWCBufferToTensorConverter : public OpenClConverterImpl { public: static bool IsSupported(const ObjectDef& input, const ObjectDef& output) { - return IsSupportedDataType(input.data_type) && - IsSupportedDataType(output.data_type) && - // Input is always Buffer/BHWC - input.object_type == ObjectType::OPENCL_BUFFER && - (input.data_layout == DataLayout::BHWC || - input.data_layout == DataLayout::DHWC4) && - // -> Texture2D/HDWC4 - ((output.object_type == ObjectType::OPENCL_TEXTURE && - output.data_layout == DataLayout::HDWC4) || - // -> TextureArray/DHWC4 - (output.object_type == ObjectType::OPENCL_TEXTURE && - output.data_layout == DataLayout::DHWC4) || - // -> SingleTextureArray/BHWC - (output.object_type == ObjectType::OPENCL_TEXTURE && - output.data_layout == DataLayout::BHWC) || - // -> Buffer/DHWC4 - (output.object_type == ObjectType::OPENCL_BUFFER && - output.data_layout == DataLayout::DHWC4)); - } - - std::pair GetFromDhwc4Kernel( - const TensorObjectDef& input_def, - const TensorObjectDef& output_def) const { - return std::make_pair( - "__global " + ToCLDataType(input_def.object_def.data_type, 4) + "* src", - output_def.object_def.data_type == input_def.object_def.data_type - ? "result = src[(d * args.tensor.Height() + y) * " - "args.tensor.Width() + x];" - : "result = convert_" + - ToCLDataType(output_def.object_def.data_type, 4) + - "(src[(d * args.tensor.Height() + y) * args.tensor.Width() + " - "x]);"); + return IsBHWCOpenCLBuffer(input) && IsOpenCLTensor(output); } std::pair GetFromBhwcKernel( @@ -241,9 +293,8 @@ class ToTensorConverter : public OpenClConverterImpl { absl::Status Init(const TensorObjectDef& input_def, const TensorObjectDef& output_def, Environment* environment) final { - auto params_kernel = input_def.object_def.data_layout == DataLayout::BHWC - ? 
GetFromBhwcKernel(input_def, output_def) - : GetFromDhwc4Kernel(input_def, output_def); + auto params_kernel = GetFromBhwcKernel(input_def, output_def); + TensorStorageType dst_tensor_type = ToTensorStorageType( output_def.object_def.object_type, output_def.object_def.data_layout); tensor_descriptor_.layout = Layout::BHWC; @@ -251,23 +302,38 @@ class ToTensorConverter : public OpenClConverterImpl { tensor_descriptor_.data_type = output_def.object_def.data_type; args_.AddObjectRef("tensor", AccessType::WRITE, absl::make_unique(tensor_descriptor_)); - std::string shader_src = - R"( -#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable -#pragma OPENCL EXTENSION cl_khr_fp16 : enable -__kernel void to_tensor()" + - params_kernel.first + R"(, $0) { - int linear_id = get_global_id(0); + const bool need_fp16_support = + input_def.object_def.data_type == DataType::FLOAT16 || + output_def.object_def.data_type == DataType::FLOAT16; + std::string shader_src; + if (need_fp16_support) { + shader_src += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"; + } + const std::string in_data_type = + ToCLDataType(input_def.object_def.data_type); + const std::string out_data_type = + ToCLDataType(output_def.object_def.data_type); + shader_src += "__kernel void bhwc_to_tensor("; + shader_src += "__global " + in_data_type + "* src, $0) {\n"; + + shader_src += R"( int linear_id = get_global_id(0); int x = linear_id / args.tensor.Batch(); int b = linear_id % args.tensor.Batch(); int y = get_global_id(1); int d = get_global_id(2); if (x >= args.tensor.Width() || y >= args.tensor.Height() || d >= args.tensor.Slices()) return; - )" + ToCLDataType(output_def.object_def.data_type, 4) + - " result;\n" + params_kernel.second + "\n " + - "args.tensor.Write(result, x, y, d, b);\n}"; +)"; + shader_src += " " + out_data_type + "4 result;\n"; + shader_src += R"( int c = d * 4; + int index = ((b * args.tensor.Height() + y) * args.tensor.Width() + x) * args.tensor.Channels() + c; + result.x = src[index]; + result.y = c + 1 < args.tensor.Channels() ? src[index + 1] : 1; + result.z = c + 2 < args.tensor.Channels() ? src[index + 2] : 2; + result.w = c + 3 < args.tensor.Channels() ? 
src[index + 3] : 3; +)"; + shader_src += " args.tensor.Write(result, x, y, d, b);\n}"; queue_ = environment->queue(); context_ = &environment->context(); shape_ = BHWC(output_def.dimensions.b, output_def.dimensions.h, @@ -275,31 +341,21 @@ __kernel void to_tensor()" + RETURN_IF_ERROR( args_.TransformToCLCode(environment->device().info_, {}, &shader_src)); return environment->program_cache()->GetOrCreateCLKernel( - shader_src, "to_tensor", environment->context(), environment->device(), - &kernel_); + shader_src, "bhwc_to_tensor", environment->context(), + environment->device(), &kernel_); } absl::Status Convert(const TensorObject& input_obj, const TensorObject& output_obj) override { auto input = absl::get_if(&input_obj); if (!input || !input->memobj) { - return absl::InvalidArgumentError("Missing input in to_tensor converter"); - } - cl_mem memory = nullptr; - auto output_texture = absl::get_if(&output_obj); - if (output_texture && output_texture->memobj) { - memory = output_texture->memobj; - } - auto output_buffer = absl::get_if(&output_obj); - if (output_buffer && output_buffer->memobj) { - memory = output_buffer->memobj; - } - if (!memory) { return absl::InvalidArgumentError( - "Missing output in to_tensor converter"); + "Missing input in bhwc_to_tensor converter"); } + cl_mem out_memory; + RETURN_IF_ERROR(GetOpenCLMemory(output_obj, &out_memory)); Tensor tensor; - RETURN_IF_ERROR(CreateSharedTensor(*context_, memory, shape_, + RETURN_IF_ERROR(CreateSharedTensor(*context_, out_memory, shape_, tensor_descriptor_, &tensor)); return DispatchKernel(input->memobj, &tensor); } @@ -465,9 +521,10 @@ class OpenClTensorConverterBuilder : public TensorObjectConverterBuilder { const auto& output_def = output.object_def; return input.dimensions == output.dimensions && (TrivialCopier::IsSupported(input_def, output_def) || + TensorToTensorConverter::IsSupported(input_def, output_def) || CpuCopier::IsSupported(input_def, output_def) || - FromTensorConverter::IsSupported(input_def, output_def) || - ToTensorConverter::IsSupported(input_def, output_def)); + TensorToBHWCBufferConverter::IsSupported(input_def, output_def) || + BHWCBufferToTensorConverter::IsSupported(input_def, output_def)); } absl::Status MakeConverter( @@ -478,12 +535,16 @@ class OpenClTensorConverterBuilder : public TensorObjectConverterBuilder { const auto& output_def = output.object_def; if (TrivialCopier::IsSupported(input_def, output_def)) { impl = absl::make_unique(); + } else if (TensorToTensorConverter::IsSupported(input_def, output_def)) { + impl = absl::make_unique(); } else if (CpuCopier::IsSupported(input_def, output_def)) { impl = absl::make_unique(); - } else if (FromTensorConverter::IsSupported(input_def, output_def)) { - impl = absl::make_unique(); - } else if (ToTensorConverter::IsSupported(input_def, output_def)) { - impl = absl::make_unique(); + } else if (TensorToBHWCBufferConverter::IsSupported(input_def, + output_def)) { + impl = absl::make_unique(); + } else if (BHWCBufferToTensorConverter::IsSupported(input_def, + output_def)) { + impl = absl::make_unique(); } else { return absl::UnimplementedError("Unsupported conversion"); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.cc b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.cc index 18522239a47..b2bf5216f8e 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.cc @@ -17,12 +17,14 @@ limitations under the License. 
#include #include +#include #include "absl/strings/substitute.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" #include "tensorflow/lite/delegates/gpu/cl/precision.h" #include "tensorflow/lite/delegates/gpu/cl/tensor_type.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" namespace tflite { @@ -33,23 +35,23 @@ ConvolutionTransposed::ConvolutionTransposed( const OperationDef& definition, const ConvolutionTransposedAttributes& attr, const DeviceInfo& device_info) : GPUOperation(definition), - stride_(attr.stride.w, attr.stride.h), - block_size_(2, 2, 2) { + stride_(attr.stride.w, attr.stride.h, 1, 1), + block_size_(2, 2, 1, 2) { const bool weights_are_buffer = device_info.IsMali(); const bool is_f16 = definition.precision == CalculationsPrecision::F16; if (device_info.IsMali()) { if (device_info.mali_info.IsMidgard()) { - block_size_ = is_f16 ? int3(2, 1, 2) : int3(2, 1, 1); + block_size_ = is_f16 ? int4(2, 1, 1, 2) : int4(2, 1, 1, 1); } else { - block_size_ = is_f16 ? int3(2, 2, 2) : int3(2, 2, 1); + block_size_ = is_f16 ? int4(2, 2, 1, 2) : int4(2, 2, 1, 1); } } const int dst_depth = DivideRoundUp(attr.weights.shape.o, 4); if (dst_depth == 1 || dst_depth == 3) { if (!device_info.IsMali()) { - block_size_.y *= block_size_.z; + block_size_.y *= block_size_.w; } - block_size_.z = 1; + block_size_.w = 1; } args_.AddInt("stride_x", stride_.x); @@ -63,6 +65,45 @@ ConvolutionTransposed::ConvolutionTransposed( UploadWeights(attr.weights, weights_are_buffer); } +ConvolutionTransposed::ConvolutionTransposed( + const OperationDef& definition, + const ConvolutionTransposed3DAttributes& attr, + const DeviceInfo& device_info) + : GPUOperation(definition), + stride_(attr.stride.w, attr.stride.h, attr.stride.d, 1), + block_size_(2, 2, 1, 2) { + const bool weights_are_buffer = device_info.IsMali(); + const bool is_f16 = definition.precision == CalculationsPrecision::F16; + if (device_info.IsMali()) { + if (device_info.mali_info.IsMidgard()) { + block_size_ = is_f16 ? int4(2, 1, 1, 2) : int4(2, 1, 1, 1); + } else { + block_size_ = is_f16 ? 
int4(2, 2, 1, 2) : int4(2, 2, 1, 1); + } + } + const int dst_depth = DivideRoundUp(attr.weights.shape.o, 4); + if (dst_depth == 1 || dst_depth == 3) { + if (!device_info.IsMali()) { + block_size_.y *= block_size_.w; + } + block_size_.w = 1; + } + + args_.AddInt("stride_x", stride_.x); + args_.AddInt("stride_y", stride_.y); + args_.AddInt("stride_z", stride_.z); + args_.AddInt("padding_x", attr.padding.prepended.w); + args_.AddInt("padding_y", attr.padding.prepended.h); + args_.AddInt("padding_z", attr.padding.prepended.d); + args_.AddInt("kernel_size_x", attr.weights.shape.w); + args_.AddInt("kernel_size_y", attr.weights.shape.h); + args_.AddInt("kernel_size_z", attr.weights.shape.d); + args_.AddInt("grid_size_y"); + code_ = GenerateConvolutionTransposedCode(definition_, device_info, + weights_are_buffer, block_size_); + UploadWeights(attr.weights, weights_are_buffer); +} + ConvolutionTransposed::ConvolutionTransposed(ConvolutionTransposed&& operation) : GPUOperation(std::move(operation)), stride_(operation.stride_), @@ -80,50 +121,85 @@ ConvolutionTransposed& ConvolutionTransposed::operator=( std::string ConvolutionTransposed::GenerateConvolutionTransposedCode( const OperationDef& op_def, const DeviceInfo& device_info, - bool weights_are_buffer, const int3& block_size) { + bool weights_are_buffer, const int4& block_size) { auto src_desc = op_def.src_tensors[0]; src_desc.SetTextureAddressMode(TextureAddressMode::ZERO); AddSrcTensor("src_tensor", src_desc); - AddDstTensor("dst_tensor", op_def.dst_tensors[0]); - const auto src_tensor_type = op_def.src_tensors[0].storage_type; - bool image_buffer = src_tensor_type == TensorStorageType::IMAGE_BUFFER; - bool manual_clamp = - image_buffer || src_tensor_type == TensorStorageType::BUFFER; + const auto& src_def = op_def.src_tensors[0]; std::string c = GetCommonDefines(op_def.precision); - for (int z = 0; z < block_size.z; ++z) { + for (int s = 0; s < block_size.w; ++s) { const std::string f0 = - weights_are_buffer ? "weights_cache[" + std::to_string(z) + "].s0123" - : "f" + std::to_string(z * 4 + 0); + weights_are_buffer ? "weights_cache[" + std::to_string(s) + "].s0123" + : "f" + std::to_string(s * 4 + 0); const std::string f1 = - weights_are_buffer ? "weights_cache[" + std::to_string(z) + "].s4567" - : "f" + std::to_string(z * 4 + 1); + weights_are_buffer ? "weights_cache[" + std::to_string(s) + "].s4567" + : "f" + std::to_string(s * 4 + 1); const std::string f2 = - weights_are_buffer ? "weights_cache[" + std::to_string(z) + "].s89ab" - : "f" + std::to_string(z * 4 + 2); + weights_are_buffer ? "weights_cache[" + std::to_string(s) + "].s89ab" + : "f" + std::to_string(s * 4 + 2); const std::string f3 = - weights_are_buffer ? "weights_cache[" + std::to_string(z) + "].scdef" - : "f" + std::to_string(z * 4 + 3); + weights_are_buffer ? 
"weights_cache[" + std::to_string(s) + "].scdef" + : "f" + std::to_string(s * 4 + 3); switch (op_def.precision) { case CalculationsPrecision::F32: case CalculationsPrecision::F16: - c += "#define CONV" + std::to_string(z) + "(R, S) \\\n"; + c += "#define CONV" + std::to_string(s) + "(R, S) \\\n"; c += "R += S.x * " + f0 + "; \\\n"; c += "R += S.y * " + f1 + "; \\\n"; c += "R += S.z * " + f2 + "; \\\n"; c += "R += S.w * " + f3 + "; \n"; break; case CalculationsPrecision::F32_F16: - c += "#define CONV" + std::to_string(z) + "(R, S) \\\n"; + c += "#define CONV" + std::to_string(s) + "(R, S) \\\n"; c += "R += convert_float4(S.x * " + f0 + " + S.y * " + f1 + " + S.z * " + f2 + " + S.w * " + f3 + ");\n"; break; } } + auto generate_id = [&](const std::string& x, const std::string& y, + const std::string& z) { + std::string id; + if (src_def.HasAxis(Axis::WIDTH)) { + id += "_w" + x; + } + if (src_def.HasAxis(Axis::HEIGHT)) { + id += "_h" + y; + } + if (src_def.HasAxis(Axis::DEPTH)) { + id += "_d" + z; + } + return id; + }; + + auto generate_id_full = [&](const std::string& x, const std::string& y, + const std::string& z, const std::string& s) { + return generate_id(x, y, z) + "_s" + s; + }; + + auto generate_check = [&](const std::string& x, const std::string& y, + const std::string& z) { + std::string check; + const std::vector axes{Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH}; + const std::vector names{"in_x", "in_y", "in_z"}; + const std::vector coords{x, y, z}; + for (int i = 0; i < axes.size(); ++i) { + const auto& axis = axes[i]; + if (src_def.HasAxis(axis) && !src_def.SupportsZeroClamp(axis) && + block_size[i] != 1) { + if (!check.empty()) { + check += " && "; + } + check += names[i] + coords[i]; + } + } + return check; + }; + switch (op_def.precision) { case CalculationsPrecision::F32: c += "#define FLT16 float16\n"; @@ -149,23 +225,48 @@ std::string ConvolutionTransposed::GenerateConvolutionTransposedCode( c += " int ceil_x = dst_x / args.stride_x;\n"; c += " dst_x = ceil_x * args.stride_x * " + std::to_string(block_size.x) + " + rem_x;\n"; - c += " int dst_y = get_global_id(1);\n"; + if (src_def.HasAxis(Axis::DEPTH)) { + c += " int linear_id_y = get_global_id(1);\n"; + c += " int dst_y = linear_id_y % args.grid_size_y;\n"; + c += " int dst_z = linear_id_y / args.grid_size_y;\n"; + c += " int rem_z = dst_z % args.stride_z;\n"; + c += " int ceil_z = dst_z / args.stride_z;\n"; + c += " dst_z = ceil_z * args.stride_z * " + std::to_string(block_size.z) + + " + rem_z;\n"; + c += " if (dst_z >= args.dst_tensor.Depth()) return;\n"; + } else { + c += " int dst_y = get_global_id(1);\n"; + } c += " int rem_y = dst_y % args.stride_y;\n"; c += " int ceil_y = dst_y / args.stride_y;\n"; c += " dst_y = ceil_y * args.stride_y * " + std::to_string(block_size.y) + " + rem_y;\n"; - c += " int dst_z = get_global_id(2) * " + std::to_string(block_size.z) + + c += " int dst_s = get_global_id(2) * " + std::to_string(block_size.w) + ";\n"; c += " if (dst_x >= args.dst_tensor.Width() || dst_y >= " - "args.dst_tensor.Height() || dst_z >= " + "args.dst_tensor.Height() || dst_s >= " "args.dst_tensor.Slices()) return;\n"; if (weights_are_buffer) { - c += " int f_base = dst_z * args.src_tensor.Slices() * args.kernel_size_x " - "* args.kernel_size_y;\n"; + c += " int f_base = dst_s * args.src_tensor.Slices() * args.kernel_size_x " + "* args.kernel_size_y"; + if (src_def.HasAxis(Axis::DEPTH)) { + c += " * args.kernel_size_z"; + } + c += ";\n"; } - for (int i = 0; i < block_size.x * block_size.y * block_size.z; ++i) { - c += " 
ACCUM_FLT4 r" + std::to_string(i) + - " = (ACCUM_FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n"; + for (int s = 0; s < block_size.w; ++s) { + const std::string sind = std::to_string(s); + for (int z = 0; z < block_size.z; ++z) { + const std::string zind = std::to_string(z); + for (int y = 0; y < block_size.y; ++y) { + const std::string yind = std::to_string(y); + for (int x = 0; x < block_size.x; ++x) { + const std::string xind = std::to_string(x); + c += " ACCUM_FLT4 r" + generate_id_full(xind, yind, zind, sind) + + " = (ACCUM_FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n"; + } + } + } } c += " int kernel_first_dst_x = dst_x + args.padding_x;\n"; c += " int kernel_first_dst_y = dst_y + args.padding_y;\n"; @@ -181,21 +282,59 @@ std::string ConvolutionTransposed::GenerateConvolutionTransposedCode( c += " int src_y = (kernel_first_dst_y + offset_y_strided) / args.stride_y - " "offset_y;\n"; - c += " int src_as_dst_y = src_y * args.stride_y;\n"; - c += " for (;src_as_dst_y > kernel_last_dst_y; src_y -= 1, src_as_dst_y -= " - "args.stride_y) {\n"; + if (src_def.HasAxis(Axis::DEPTH)) { + c += " int kernel_first_dst_z = dst_z + args.padding_z;\n"; + c += " int kernel_last_dst_z = kernel_first_dst_z - args.kernel_size_z;\n"; + c += " int offset_z = abs(args.padding_z);\n"; + c += " int offset_z_strided = offset_z * args.stride_z;\n"; + c += " int src_z = (kernel_first_dst_z + offset_z_strided) / " + "args.stride_z - offset_z;\n"; + c += " int src_as_dst_z = src_z * args.stride_z;\n"; + c += + " for (;src_as_dst_z > kernel_last_dst_z; src_z -= 1, src_as_dst_z -= " + "args.stride_z) {\n"; + for (int z = 0; z < block_size.z; ++z) { + const std::string zindex = std::to_string(z); + c += " int sz" + zindex + " = src_z + " + zindex + ";\n"; + if (!src_def.SupportsZeroClamp(Axis::DEPTH)) { + c += " bool in_z" + zindex + " = sz" + zindex + " >= 0 && sz" + + zindex + " < args.src_tensor.Depth();\n"; + if (!src_def.CanReadOutOfBorder(Axis::DEPTH)) { + c += " sz" + zindex + " = clamp(sz" + zindex + + ", 0, args.src_tensor.Depth() - 1);\n"; + } + } + } + if (block_size.z == 1 && !src_def.SupportsZeroClamp(Axis::DEPTH)) { + c += " if (!in_z0) continue;\n"; + } + c += " int kernel_z = kernel_first_dst_z - src_as_dst_z;\n"; + c += " int src_as_dst_y = src_y * args.stride_y;\n"; + c += " int src_y_copy = src_y;\n"; + c += " for (;src_as_dst_y > kernel_last_dst_y; src_y_copy -= 1, " + "src_as_dst_y -= args.stride_y) {\n"; + } else { + c += " int src_as_dst_y = src_y * args.stride_y;\n"; + c += " for (;src_as_dst_y > kernel_last_dst_y; src_y -= 1, src_as_dst_y " + "-= args.stride_y) {\n"; + } for (int y = 0; y < block_size.y; ++y) { const std::string yindex = std::to_string(y); - c += " int sy" + yindex + " = src_y + " + yindex + ";\n"; - if (manual_clamp) { + const std::string src_y = + src_def.HasAxis(Axis::DEPTH) ? 
"src_y_copy" : "src_y"; + c += " int sy" + yindex + " = " + src_y + " + " + yindex + ";\n"; + if (!src_def.SupportsZeroClamp(Axis::HEIGHT)) { c += " bool in_y" + yindex + " = sy" + yindex + " >= 0 && sy" + yindex + " < args.src_tensor.Height();\n"; - if (!image_buffer) { + if (!src_def.CanReadOutOfBorder(Axis::HEIGHT)) { c += " sy" + yindex + " = clamp(sy" + yindex + ", 0, args.src_tensor.Height() - 1);\n"; } } } + if (block_size.y == 1 && !src_def.SupportsZeroClamp(Axis::HEIGHT)) { + c += " if (!in_y0) continue;\n"; + } c += " int kernel_y = kernel_first_dst_y - src_as_dst_y;\n"; c += " int src_as_dst_x = src_x * args.stride_x;\n"; c += " int src_x_copy = src_x;\n"; @@ -205,132 +344,196 @@ std::string ConvolutionTransposed::GenerateConvolutionTransposedCode( for (int x = 0; x < block_size.x; ++x) { const std::string xindex = std::to_string(x); c += " int sx" + xindex + " = src_x_copy + " + xindex + ";\n"; - if (manual_clamp) { + if (!src_def.SupportsZeroClamp(Axis::WIDTH)) { c += " bool in_x" + xindex + " = sx" + xindex + " >= 0 && sx" + xindex + " < args.src_tensor.Width();\n"; - if (!image_buffer) { + if (!src_def.CanReadOutOfBorder(Axis::WIDTH)) { c += " sx" + xindex + " = clamp(sx" + xindex + ", 0, args.src_tensor.Width() - 1);\n"; } } } - for (int y = 0; y < block_size.y; ++y) { - const std::string yindex = std::to_string(y); - for (int x = 0; x < block_size.x; ++x) { - const std::string xindex = std::to_string(x); - const std::string id = std::to_string(y * block_size.x + x); - c += " args.src_tensor.GetAddress(addr_" + id + ", sx" + xindex + - ", sy" + yindex + ", 0);\n"; - if (image_buffer) { - c += " addr_" + id + " = select(-1, addr_" + id + ", (in_x" + - xindex + " && in_y" + yindex + "));\n"; - c += absl::Substitute( - " int dz_$0 = select(0, args.src_tensor.SliceStride(), " - "(in_x$1 && in_y$2));\n", - y * block_size.x + x, x, y); + if (block_size.x == 1 && !src_def.SupportsZeroClamp(Axis::WIDTH)) { + c += " if (!in_x0) continue;\n"; + } + for (int z = 0; z < block_size.z; ++z) { + const std::string zind = std::to_string(z); + for (int y = 0; y < block_size.y; ++y) { + const std::string yind = std::to_string(y); + for (int x = 0; x < block_size.x; ++x) { + const std::string xind = std::to_string(x); + const std::string id = generate_id(xind, yind, zind); + const std::string check = generate_check(xind, yind, zind); + std::string coords = "sx" + xind + ", sy" + yind; + if (src_def.HasAxis(Axis::DEPTH)) { + coords += ", sz" + zind; + } + if (src_def.IsLinear()) { + c += " args.src_tensor.GetAddress(addr" + id + ", " + coords + + ", 0);\n"; + } + if (src_def.ReturnsZeroForNegOneRead()) { + c += " addr" + id + " = select(-1, addr" + id + ", (" + check + + "));\n"; + c += " int ds" + id + + " = select(0, args.src_tensor.SliceStride(), (" + check + + "));\n"; + } } } } - if (src_tensor_type == TensorStorageType::BUFFER) { - c += " int dz = args.src_tensor.SliceStride();\n"; - } - if (block_size.x == 1 && block_size.y == 1 && manual_clamp) { - c += " if (!in_x0 || !in_y0) continue;\n"; + if (src_def.storage_type == TensorStorageType::BUFFER) { + c += " int ds = args.src_tensor.SliceStride();\n"; } c += " int kernel_x = kernel_first_dst_x - src_as_dst_x;\n"; - c += " int kernel_index = kernel_y * args.kernel_size_x + kernel_x;\n"; + if (src_def.HasAxis(Axis::DEPTH)) { + c += " int kernel_index = (kernel_z * args.kernel_size_y + kernel_y) " + "* args.kernel_size_x + kernel_x;\n"; + } else { + c += " int kernel_index = kernel_y * args.kernel_size_x + kernel_x;\n"; + } if 
(weights_are_buffer) { c += " int f_offset = f_base + kernel_index * " "args.src_tensor.Slices() * " + - std::to_string(block_size.z) + ";\n"; + std::to_string(block_size.w) + ";\n"; } else { c += " int x_c = kernel_index * args.src_tensor.Slices();\n"; } c += " for (int s = 0; s < args.src_tensor.Slices(); ++s) {\n"; const bool conditional_read = device_info.IsMali(); - for (int y = 0; y < block_size.y; ++y) { - const std::string yindex = std::to_string(y); - for (int x = 0; x < block_size.x; ++x) { - const std::string xindex = std::to_string(x); - const std::string id = std::to_string(y * block_size.x + x); - if (image_buffer) { - c += " FLT4 src" + id + " = args.src_tensor.Read(addr_" + id + - "); addr_" + id + " += dz_" + id + ";\n"; - } else if (manual_clamp) { - if (conditional_read) { - c += " FLT4 src" + id + " = in_x" + xindex + " && in_y" + - yindex + " ? args.src_tensor.Read(addr_" + id + - ") : (FLT4)(0.0f); addr_" + id + " += dz;\n"; + for (int z = 0; z < block_size.z; ++z) { + const std::string zind = std::to_string(z); + for (int y = 0; y < block_size.y; ++y) { + const std::string yind = std::to_string(y); + for (int x = 0; x < block_size.x; ++x) { + const std::string xind = std::to_string(x); + const std::string id = generate_id(xind, yind, zind); + std::string address; + if (src_def.IsLinear()) { + address = "addr" + id; } else { - c += " FLT4 src" + id + " = args.src_tensor.Read(addr_" + id + - ") * (FLT)(in_x" + xindex + " && in_y" + yindex + "); addr_" + - id + " += dz;\n"; + address = "sx" + xind + ", sy" + yind; + if (src_def.HasAxis(Axis::DEPTH)) { + address += ", sz" + zind; + } + address += ", s"; + } + if (src_def.ReturnsZeroForNegOneRead()) { + c += " FLT4 src" + id + " = args.src_tensor.Read(" + address + + "); " + address + " += ds" + id + ";\n"; + } else { + const std::string check = generate_check(xind, yind, zind); + if (!check.empty()) { + if (conditional_read) { + c += " FLT4 src" + id + " = " + check + + " ? 
args.src_tensor.Read(" + address + ") : (FLT4)(0.0f);\n"; + } else { + c += " FLT4 src" + id + " = args.src_tensor.Read(" + + address + ") * (FLT)(" + check + ");\n"; + } + } else { + c += " FLT4 src" + id + " = args.src_tensor.Read(" + + address + ");\n"; + } + if (src_def.IsLinear()) { + c += " addr" + id + " += ds;\n"; + } } - } else { - c += " FLT4 src" + id + " = args.src_tensor.Read(sx" + xindex + - ", sy" + yindex + ", s);\n"; } } } if (weights_are_buffer) { c += " __global FLT16* weights_cache = " "args.weights.GetPtr(f_offset);\n"; - c += " f_offset += " + std::to_string(block_size.z) + ";\n"; + c += " f_offset += " + std::to_string(block_size.w) + ";\n"; } else { - for (int z = 0; z < block_size.z; ++z) { + for (int s = 0; s < block_size.w; ++s) { c += absl::Substitute( - R"( FLT4 f$1 = args.weights0.Read(dst_z + $0, x_c); - FLT4 f$2 = args.weights1.Read(dst_z + $0, x_c); - FLT4 f$3 = args.weights2.Read(dst_z + $0, x_c); - FLT4 f$4 = args.weights3.Read(dst_z + $0, x_c); + R"( FLT4 f$1 = args.weights0.Read(dst_s + $0, x_c); + FLT4 f$2 = args.weights1.Read(dst_s + $0, x_c); + FLT4 f$3 = args.weights2.Read(dst_s + $0, x_c); + FLT4 f$4 = args.weights3.Read(dst_s + $0, x_c); )", - z, z * 4 + 0, z * 4 + 1, z * 4 + 2, z * 4 + 3); + s, s * 4 + 0, s * 4 + 1, s * 4 + 2, s * 4 + 3); } c += " x_c++;\n"; } - for (int z = 0; z < block_size.z; ++z) { - for (int i = 0; i < block_size.x * block_size.y; ++i) { - c += " CONV" + std::to_string(z) + "(r" + - std::to_string(i + z * block_size.x * block_size.y) + ", src" + - std::to_string(i) + ");\n"; + for (int s = 0; s < block_size.w; ++s) { + const std::string sind = std::to_string(s); + for (int z = 0; z < block_size.z; ++z) { + const std::string zind = std::to_string(z); + for (int y = 0; y < block_size.y; ++y) { + const std::string yind = std::to_string(y); + for (int x = 0; x < block_size.x; ++x) { + const std::string xind = std::to_string(x); + const std::string id = generate_id(xind, yind, zind); + const std::string full_id = generate_id_full(xind, yind, zind, sind); + c += " CONV" + sind + "(r" + full_id + ", src" + id + ");\n"; + } + } } } c += " }\n"; c += " }\n"; c += " }\n"; - for (int z = 0; z < block_size.z; ++z) { - c += " if (dst_z < args.dst_tensor.Slices()) {\n"; - c += " FLT4 bias_val = args.biases.Read(dst_z);\n"; - for (int y = 0; y < block_size.y; ++y) { - for (int x = 0; x < block_size.x; ++x) { - const std::string id = - std::to_string((z * block_size.y + y) * block_size.x + x); - c += " {\n"; - c += " int xc = dst_x + args.stride_x * " + std::to_string(x) + - ";\n"; - c += " int yc = dst_y + args.stride_y * " + std::to_string(y) + - ";\n"; - c += " if (xc < args.dst_tensor.Width() && yc < " - "args.dst_tensor.Height()) {\n"; - c += " FLT4 res = TO_FLT4(r" + id + ") + bias_val;\n"; - c += " args.dst_tensor.Write(res, xc, yc, dst_z);\n"; - c += " }\n"; - c += " }\n"; + if (src_def.HasAxis(Axis::DEPTH)) { + c += " }\n"; + } + for (int s = 0; s < block_size.w; ++s) { + const std::string sind = std::to_string(s); + c += " if (dst_s < args.dst_tensor.Slices()) {\n"; + c += " FLT4 bias_val = args.biases.Read(dst_s);\n"; + for (int z = 0; z < block_size.z; ++z) { + const std::string zind = std::to_string(z); + for (int y = 0; y < block_size.y; ++y) { + const std::string yind = std::to_string(y); + for (int x = 0; x < block_size.x; ++x) { + const std::string xind = std::to_string(x); + const std::string id = generate_id_full(xind, yind, zind, sind); + std::string checks = + "xc < args.dst_tensor.Width() && yc < 
args.dst_tensor.Height()"; + std::string coords = "xc, yc"; + c += " {\n"; + c += " int xc = dst_x + args.stride_x * " + xind + ";\n"; + c += " int yc = dst_y + args.stride_y * " + yind + ";\n"; + if (src_def.HasAxis(Axis::DEPTH)) { + c += " int zc = dst_z + args.stride_z * " + zind + ";\n"; + checks += " && zc < args.dst_tensor.Depth()"; + coords += ", zc"; + } + c += " if (" + checks + ") {\n"; + c += " FLT4 res = TO_FLT4(r" + id + ") + bias_val;\n"; + c += " args.dst_tensor.Write(res, " + coords + ", dst_s);\n"; + c += " }\n"; + c += " }\n"; + } } } c += " }\n"; - c += " dst_z++;\n"; + c += " dst_s++;\n"; } c += "}\n"; return c; } +absl::Status ConvolutionTransposed::BindArguments(ArgumentsBinder* args) { + if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) { + const int aligned_h = + AlignByN(dst_[0]->Height(), stride_.y * block_size_.y); + RETURN_IF_ERROR( + args->SetInt("grid_size_y", DivideRoundUp(aligned_h, block_size_.y))); + } + return absl::OkStatus(); +} + int3 ConvolutionTransposed::GetGridSize() const { const int aligned_w = AlignByN(dst_[0]->Width(), stride_.x * block_size_.x); const int aligned_h = AlignByN(dst_[0]->Height(), stride_.y * block_size_.y); + const int aligned_d = AlignByN(dst_[0]->Depth(), stride_.z * block_size_.z); const int grid_x = DivideRoundUp(aligned_w, block_size_.x) * dst_[0]->Batch(); - const int grid_y = DivideRoundUp(aligned_h, block_size_.y); - const int grid_z = DivideRoundUp(dst_[0]->Slices(), block_size_.z); + const int grid_y = DivideRoundUp(aligned_h, block_size_.y) * + DivideRoundUp(aligned_d, block_size_.z); + const int grid_z = DivideRoundUp(dst_[0]->Slices(), block_size_.w); return int3(grid_x, grid_y, grid_z); } @@ -356,6 +559,21 @@ ConvolutionTransposed CreateConvolutionTransposed( return result; } +ConvolutionTransposed CreateConvolutionTransposed3D( + const DeviceInfo& device_info, const OperationDef& definition, + const ConvolutionTransposed3DAttributes& attr) { + ConvolutionTransposed result(definition, attr, device_info); + + TensorLinearDescriptor desc; + desc.storage_type = + DeduceLinearStorageType(definition.GetPrimaryStorageType()); + desc.element_type = definition.GetDataType(); + desc.UploadLinearData(attr.bias); + result.args_.AddObject( + "biases", absl::make_unique(std::move(desc))); + return result; +} + } // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.h b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.h index 7939236409e..5aa86f33e5a 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.h @@ -16,10 +16,12 @@ limitations under the License. 
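// Editor's note: a minimal host-side sketch (not part of the patch, invented
// sizes) of the grid folding introduced above. When the source has a DEPTH
// axis, GetGridSize() multiplies the aligned-height tile count by the
// aligned-depth tile count, and BindArguments() publishes the height tile
// count as "grid_size_y" so the kernel can split get_global_id(1) back into
// dst_y and dst_z. DivUp/AlignUp below are local stand-ins for the real
// DivideRoundUp/AlignByN helpers.
#include <cstdio>

namespace {
int DivUp(int n, int d) { return (n + d - 1) / d; }
int AlignUp(int n, int a) { return DivUp(n, a) * a; }
}  // namespace

int main() {
  // Hypothetical destination shape and tuning parameters.
  const int dst_h = 17, dst_d = 5;
  const int stride_y = 2, stride_z = 2;
  const int block_y = 2, block_z = 1;

  const int aligned_h = AlignUp(dst_h, stride_y * block_y);    // 20
  const int aligned_d = AlignUp(dst_d, stride_z * block_z);    // 6
  const int grid_size_y = DivUp(aligned_h, block_y);           // 10 (bound as "grid_size_y")
  const int grid_y = grid_size_y * DivUp(aligned_d, block_z);  // 10 * 6 = 60 items in dim 1

  // Device side, per work item: undo the folding (the real kernel then applies
  // the usual rem/ceil stride expansion to dst_y and dst_z).
  for (int linear_id_y = 0; linear_id_y < grid_y; ++linear_id_y) {
    const int dst_y = linear_id_y % grid_size_y;
    const int dst_z = linear_id_y / grid_size_y;
    if (linear_id_y < 3) std::printf("id=%d -> y=%d z=%d\n", linear_id_y, dst_y, dst_z);
  }
  return 0;
}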
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_CONVOLUTION_TRANSPOSED_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_CONVOLUTION_TRANSPOSED_H_ +#include #include #include "tensorflow/lite/delegates/gpu/cl/buffer.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" +#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h" #include "tensorflow/lite/delegates/gpu/cl/tensor.h" #include "tensorflow/lite/delegates/gpu/cl/texture2d.h" @@ -42,6 +44,7 @@ class ConvolutionTransposed : public GPUOperation { TuningType tuning_type, const DeviceInfo& device_info, const KernelInfo& kernel_info, std::vector* work_groups) const override; + absl::Status BindArguments(ArgumentsBinder* args) override; int3 GetGridSize() const override; // Move only @@ -54,30 +57,37 @@ class ConvolutionTransposed : public GPUOperation { friend ConvolutionTransposed CreateConvolutionTransposed( const DeviceInfo& device_info, const OperationDef& definition, const ConvolutionTransposedAttributes& attr); - explicit ConvolutionTransposed(const OperationDef& definition, - const ConvolutionTransposedAttributes& attr, - const DeviceInfo& device_info); + friend ConvolutionTransposed CreateConvolutionTransposed3D( + const DeviceInfo& device_info, const OperationDef& definition, + const ConvolutionTransposed3DAttributes& attr); + ConvolutionTransposed(const OperationDef& definition, + const ConvolutionTransposedAttributes& attr, + const DeviceInfo& device_info); + ConvolutionTransposed(const OperationDef& definition, + const ConvolutionTransposed3DAttributes& attr, + const DeviceInfo& device_info); + template void UploadWeights(const tflite::gpu::Tensor& weights, bool weights_are_buffer); - template - void RearrangeWeightsData(const tflite::gpu::Tensor& weights, - absl::Span dst, bool weights_are_buffer); + template + void UploadWeights(const tflite::gpu::Tensor& weights, + bool weights_are_buffer); std::string GenerateConvolutionTransposedCode(const OperationDef& op_def, const DeviceInfo& device_info, bool weights_are_buffer, - const int3& block_size); - int2 stride_; - int3 block_size_ = int3(1, 1, 1); + const int4& block_size); + int4 stride_; + int4 block_size_ = int4(1, 1, 1, 1); // WHDS }; template void ConvolutionTransposed::UploadWeights( const tflite::gpu::Tensor& weights, bool weights_are_buffer) { const int dst_depth = - AlignByN(DivideRoundUp(weights.shape.o, 4), block_size_.z); + AlignByN(DivideRoundUp(weights.shape.o, 4), block_size_.w); const int src_depth = DivideRoundUp(weights.shape.i, 4); const int kernel_x = weights.shape.w; const int kernel_y = weights.shape.h; @@ -90,12 +100,22 @@ void ConvolutionTransposed::UploadWeights( if (f32_weights) { float4* ptr = reinterpret_cast(data.data()); - RearrangeWeightsData(weights, absl::MakeSpan(ptr, elements_count), - weights_are_buffer); + if (weights_are_buffer) { + RearrangeWeightsToOHWIOGroupI4O4(weights, block_size_.w, + absl::MakeSpan(ptr, elements_count)); + } else { + RearrangeWeightsToI4HWIOOGroupO4(weights, block_size_.w, + absl::MakeSpan(ptr, elements_count)); + } } else { half4* ptr = reinterpret_cast(data.data()); - RearrangeWeightsData(weights, absl::MakeSpan(ptr, elements_count), - weights_are_buffer); + if (weights_are_buffer) { + RearrangeWeightsToOHWIOGroupI4O4(weights, block_size_.w, + absl::MakeSpan(ptr, elements_count)); + } else { + RearrangeWeightsToI4HWIOOGroupO4(weights, block_size_.w, + absl::MakeSpan(ptr, elements_count)); + } } if (weights_are_buffer) { @@ -107,90 
+127,80 @@ void ConvolutionTransposed::UploadWeights( args_.AddObject("weights", absl::make_unique(std::move(desc))); } else { - int sub_size = float4_size * elements_count / 4; - Texture2DDescriptor desc0; - desc0.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16; - desc0.size = int2(dst_depth, src_depth * kernel_x * kernel_y); - desc0.data.resize(sub_size); - memcpy(desc0.data.data(), data.data(), sub_size); - args_.AddObject("weights0", - absl::make_unique(std::move(desc0))); - - Texture2DDescriptor desc1; - desc1.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16; - desc1.size = int2(dst_depth, src_depth * kernel_x * kernel_y); - desc1.data.resize(sub_size); - memcpy(desc1.data.data(), data.data() + sub_size, sub_size); - args_.AddObject("weights1", - absl::make_unique(std::move(desc1))); - - Texture2DDescriptor desc2; - desc2.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16; - desc2.size = int2(dst_depth, src_depth * kernel_x * kernel_y); - desc2.data.resize(sub_size); - memcpy(desc2.data.data(), data.data() + sub_size * 2, sub_size); - args_.AddObject("weights2", - absl::make_unique(std::move(desc2))); - - Texture2DDescriptor desc3; - desc3.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16; - desc3.size = int2(dst_depth, src_depth * kernel_x * kernel_y); - desc3.data.resize(sub_size); - memcpy(desc3.data.data(), data.data() + sub_size * 3, sub_size); - args_.AddObject("weights3", - absl::make_unique(std::move(desc3))); + int texture_width = dst_depth; + int texture_height = src_depth * kernel_x * kernel_y; + int sub_size = float4_size * texture_width * texture_height; + for (int i = 0; i < 4; ++i) { + Texture2DDescriptor desc; + desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16; + desc.size = int2(texture_width, texture_height); + desc.data.resize(sub_size); + memcpy(desc.data.data(), data.data() + sub_size * i, sub_size); + const std::string name = "weights" + std::to_string(i); + args_.AddObject(name, + absl::make_unique(std::move(desc))); + } } } -template -void ConvolutionTransposed::RearrangeWeightsData( - const tflite::gpu::Tensor& weights, absl::Span dst, - bool weights_are_buffer) { +template +void ConvolutionTransposed::UploadWeights( + const tflite::gpu::Tensor& weights, bool weights_are_buffer) { const int dst_depth = - AlignByN(DivideRoundUp(weights.shape.o, 4), block_size_.z); + AlignByN(DivideRoundUp(weights.shape.o, 4), block_size_.w); const int src_depth = DivideRoundUp(weights.shape.i, 4); const int kernel_x = weights.shape.w; const int kernel_y = weights.shape.h; - int texture_width = dst_depth; - int texture_height = src_depth * kernel_x * kernel_y; + const int kernel_z = weights.shape.d; - int counter = 0; - for (int d = 0; d < dst_depth / block_size_.z; ++d) { - for (int y = 0; y < kernel_y; ++y) { - for (int x = 0; x < kernel_x; ++x) { - for (int s = 0; s < src_depth; ++s) { - for (int sub_d = 0; sub_d < block_size_.z; ++sub_d) { - T filters[4]; - for (int i = 0; i < 4; ++i) { - for (int j = 0; j < 4; ++j) { - const int s_ch = s * 4 + j; - const int d_ch = (d * block_size_.z + sub_d) * 4 + i; - if (s_ch < weights.shape.i && d_ch < weights.shape.o) { - const int f_index = - weights.shape.LinearIndex({d_ch, y, x, s_ch}); - filters[j][i] = weights.data[f_index]; - } else { - filters[j][i] = 0.0f; - } - } - } - if (weights_are_buffer) { - dst[counter++] = filters[0]; - dst[counter++] = filters[1]; - dst[counter++] = filters[2]; - dst[counter++] = filters[3]; - } else { - int 
x_coord = d * block_size_.z + sub_d; - int y_coord = (y * kernel_x + x) * src_depth + s; - int offset = y_coord * dst_depth + x_coord; - dst[offset + texture_width * texture_height * 0] = filters[0]; - dst[offset + texture_width * texture_height * 1] = filters[1]; - dst[offset + texture_width * texture_height * 2] = filters[2]; - dst[offset + texture_width * texture_height * 3] = filters[3]; - } - } - } - } + const int elements_count = + kernel_x * kernel_y * kernel_z * src_depth * dst_depth * 4; + const bool f32_weights = definition_.precision == CalculationsPrecision::F32; + + const int float4_size = f32_weights ? 16 : 8; + std::vector data(float4_size * elements_count); + + if (f32_weights) { + float4* ptr = reinterpret_cast(data.data()); + if (weights_are_buffer) { + RearrangeWeightsToODHWIOGroupI4O4(weights, block_size_.w, + absl::MakeSpan(ptr, elements_count)); + } else { + RearrangeWeightsToI4DHWIOOGroupO4(weights, block_size_.w, + absl::MakeSpan(ptr, elements_count)); + } + } else { + half4* ptr = reinterpret_cast(data.data()); + if (weights_are_buffer) { + RearrangeWeightsToODHWIOGroupI4O4(weights, block_size_.w, + absl::MakeSpan(ptr, elements_count)); + } else { + RearrangeWeightsToI4DHWIOOGroupO4(weights, block_size_.w, + absl::MakeSpan(ptr, elements_count)); + } + } + + if (weights_are_buffer) { + BufferDescriptor desc; + desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16; + desc.element_size = 16; + desc.size = float4_size * elements_count; + desc.data = std::move(data); + args_.AddObject("weights", + absl::make_unique(std::move(desc))); + } else { + int texture_width = dst_depth; + int texture_height = src_depth * kernel_x * kernel_y * kernel_z; + int sub_size = float4_size * texture_width * texture_height; + for (int i = 0; i < 4; ++i) { + Texture2DDescriptor desc; + desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16; + desc.size = int2(texture_width, texture_height); + desc.data.resize(sub_size); + memcpy(desc.data.data(), data.data() + sub_size * i, sub_size); + const std::string name = "weights" + std::to_string(i); + args_.AddObject(name, + absl::make_unique(std::move(desc))); } } } @@ -199,6 +209,10 @@ ConvolutionTransposed CreateConvolutionTransposed( const DeviceInfo& device_info, const OperationDef& definition, const ConvolutionTransposedAttributes& attr); +ConvolutionTransposed CreateConvolutionTransposed3D( + const DeviceInfo& device_info, const OperationDef& definition, + const ConvolutionTransposed3DAttributes& attr); + } // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3d.cc b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3d.cc deleted file mode 100644 index b2a85a89ef0..00000000000 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3d.cc +++ /dev/null @@ -1,402 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
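// Editor's note: an illustrative sketch (not part of the patch) of the texture
// upload path in UploadWeights() above. The rearranged weight stream stores
// four planes back to back -- one per lane of the O4 output group -- and each
// plane becomes its own dst_depth x (src_depth * kernel_x * kernel_y [* kernel_z])
// texture registered as weights0..weights3. Types and sizes here are invented;
// only the slicing arithmetic mirrors the patch.
#include <cstring>
#include <vector>

struct PlaneDesc {  // hypothetical stand-in for Texture2DDescriptor
  int width = 0, height = 0;
  std::vector<unsigned char> data;
};

std::vector<PlaneDesc> SplitIntoFourPlanes(const std::vector<unsigned char>& packed,
                                           int texture_width, int texture_height,
                                           int float4_size) {
  const int sub_size = float4_size * texture_width * texture_height;
  std::vector<PlaneDesc> planes(4);
  for (int i = 0; i < 4; ++i) {
    planes[i].width = texture_width;
    planes[i].height = texture_height;
    planes[i].data.resize(sub_size);
    // Plane i starts at byte offset sub_size * i, exactly as in the memcpy above.
    std::memcpy(planes[i].data.data(), packed.data() + sub_size * i, sub_size);
  }
  return planes;
}

int main() {
  const int tw = 8, th = 36, float4_size = 16;  // invented texture geometry
  std::vector<unsigned char> packed(4 * float4_size * tw * th, 0);
  return SplitIntoFourPlanes(packed, tw, th, float4_size).size() == 4 ? 0 : 1;
}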
-==============================================================================*/ - -#include "tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3d.h" - -#include -#include - -#include "absl/strings/substitute.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" -#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h" -#include "tensorflow/lite/delegates/gpu/common/status.h" - -namespace tflite { -namespace gpu { -namespace cl { - -ConvolutionTransposed3D::ConvolutionTransposed3D( - const OperationDef& definition, - const ConvolutionTransposed3DAttributes& attr, - const DeviceInfo& device_info) - : GPUOperation(definition), - stride_(attr.stride.w, attr.stride.h, attr.stride.d), - block_size_(2, 2, 1, 2) { - bool weights_are_buffer = device_info.IsMali(); - args_.AddInt("stride_x", stride_.x); - args_.AddInt("stride_y", stride_.y); - args_.AddInt("stride_z", stride_.z); - args_.AddInt("padding_x", attr.padding.prepended.w); - args_.AddInt("padding_y", attr.padding.prepended.h); - args_.AddInt("padding_z", attr.padding.prepended.d); - args_.AddInt("kernel_size_x", attr.weights.shape.w); - args_.AddInt("kernel_size_y", attr.weights.shape.h); - args_.AddInt("kernel_size_z", attr.weights.shape.d); - args_.AddInt("grid_size_s"); - code_ = GenerateConvolutionTransposed3DCode(definition_, weights_are_buffer, - block_size_); - UploadWeights(attr.weights, weights_are_buffer); - if (device_info.IsPowerVR() && block_size_.y != 1) { - bool is_texture3d = definition_.src_tensors[0].storage_type == - TensorStorageType::TEXTURE_3D; - bool is_texture_array = definition_.src_tensors[0].storage_type == - TensorStorageType::TEXTURE_ARRAY; - if (is_texture3d || is_texture_array) { - compiler_options_.push_back(CompilerOptions::CL_OPT_DISABLE); - } - } -} - -ConvolutionTransposed3D::ConvolutionTransposed3D( - ConvolutionTransposed3D&& operation) - : GPUOperation(std::move(operation)), - stride_(operation.stride_), - block_size_(operation.block_size_) {} - -ConvolutionTransposed3D& ConvolutionTransposed3D::operator=( - ConvolutionTransposed3D&& operation) { - if (this != &operation) { - std::swap(stride_, operation.stride_); - std::swap(block_size_, operation.block_size_); - GPUOperation::operator=(std::move(operation)); - } - return *this; -} - -std::string ConvolutionTransposed3D::GenerateConvolutionTransposed3DCode( - const OperationDef& op_def, bool weights_are_buffer, - const int4& block_size) { - auto src_desc = op_def.src_tensors[0]; - src_desc.SetTextureAddressMode(TextureAddressMode::ZERO); - AddSrcTensor("src_tensor", src_desc); - - AddDstTensor("dst_tensor", op_def.dst_tensors[0]); - - const auto src_tensor_type = op_def.src_tensors[0].storage_type; - bool image_buffer = src_tensor_type == TensorStorageType::IMAGE_BUFFER; - bool manual_clamp = - image_buffer || src_tensor_type == TensorStorageType::BUFFER; - - std::string c = GetCommonDefines(op_def.precision); - - for (int s = 0; s < block_size.w; ++s) { - const std::string f0 = - weights_are_buffer ? "weights_cache[" + std::to_string(s) + "].s0123" - : "f" + std::to_string(s * 4 + 0); - const std::string f1 = - weights_are_buffer ? "weights_cache[" + std::to_string(s) + "].s4567" - : "f" + std::to_string(s * 4 + 1); - const std::string f2 = - weights_are_buffer ? "weights_cache[" + std::to_string(s) + "].s89ab" - : "f" + std::to_string(s * 4 + 2); - const std::string f3 = - weights_are_buffer ? 
"weights_cache[" + std::to_string(s) + "].scdef" - : "f" + std::to_string(s * 4 + 3); - switch (op_def.precision) { - case CalculationsPrecision::F32: - case CalculationsPrecision::F16: - c += "#define CONV" + std::to_string(s) + "(R, S) \\\n"; - c += "R += S.x * " + f0 + "; \\\n"; - c += "R += S.y * " + f1 + "; \\\n"; - c += "R += S.z * " + f2 + "; \\\n"; - c += "R += S.w * " + f3 + "; \n"; - break; - case CalculationsPrecision::F32_F16: - c += "#define CONV" + std::to_string(s) + "(R, S) \\\n"; - c += "R += convert_float4(S.x * " + f0 + " + S.y * " + f1 + - " + S.z * " + f2 + " + S.w * " + f3 + ");\n"; - break; - } - } - - switch (op_def.precision) { - case CalculationsPrecision::F32: - c += "#define FLT16 float16\n"; - break; - case CalculationsPrecision::F32_F16: - case CalculationsPrecision::F16: - c += "#define FLT16 half16\n"; - break; - } - - c += "__kernel void main_function(\n"; - c += "$0) {\n"; - if (op_def.IsBatchSupported()) { - c += " int linear_id = get_global_id(0);\n"; - c += " int dst_x = (linear_id / args.dst_tensor.Batch());\n"; - c += " int B = linear_id % args.dst_tensor.Batch();\n"; - c += " args.dst_tensor.SetBatchRef(B);\n"; - c += " args.src_tensor.SetBatchRef(B);\n"; - } else { - c += " int dst_x = get_global_id(0);\n"; - } - c += " int rem_x = dst_x % args.stride_x;\n"; - c += " int ceil_x = dst_x / args.stride_x;\n"; - c += " dst_x = ceil_x * args.stride_x * " + std::to_string(block_size.x) + - " + rem_x;\n"; - c += " int dst_y = get_global_id(1);\n"; - c += " int rem_y = dst_y % args.stride_y;\n"; - c += " int ceil_y = dst_y / args.stride_y;\n"; - c += " dst_y = ceil_y * args.stride_y * " + std::to_string(block_size.y) + - " + rem_y;\n"; - c += " int linear_id_z = get_global_id(2);\n"; - c += " int S = (linear_id_z % args.grid_size_s) * " + - std::to_string(block_size.w) + ";\n"; - c += " int dst_z = linear_id_z / args.grid_size_s;\n"; - c += " int rem_z = dst_z % args.stride_z;\n"; - c += " int ceil_z = dst_z / args.stride_z;\n"; - c += " dst_z = ceil_z * args.stride_z * " + std::to_string(block_size.z) + - " + rem_z;\n"; - c += " if (dst_x >= args.dst_tensor.Width() || dst_y >= " - "args.dst_tensor.Height() || dst_z >= " - "args.dst_tensor.Depth()) return;\n"; - if (weights_are_buffer) { - c += " int f_base = S * args.src_tensor.Slices() * args.kernel_size_x * " - "args.kernel_size_y * " - "args.kernel_size_z;\n"; - } - for (int i = 0; i < block_size.x * block_size.y * block_size.z * block_size.w; - ++i) { - c += " ACCUM_FLT4 r" + std::to_string(i) + - " = (ACCUM_FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n"; - } - c += " int kernel_first_dst_x = dst_x + args.padding_x;\n"; - c += " int kernel_first_dst_y = dst_y + args.padding_y;\n"; - c += " int kernel_first_dst_z = dst_z + args.padding_z;\n"; - c += " int kernel_last_dst_x = kernel_first_dst_x - args.kernel_size_x;\n"; - c += " int kernel_last_dst_y = kernel_first_dst_y - args.kernel_size_y;\n"; - c += " int kernel_last_dst_z = kernel_first_dst_z - args.kernel_size_z;\n"; - c += " int offset_x = abs(args.padding_x);\n"; - c += " int offset_x_strided = offset_x * args.stride_x;\n"; - c += - " int src_x = (kernel_first_dst_x + offset_x_strided) / args.stride_x - " - "offset_x;\n"; - c += " int offset_y = abs(args.padding_y);\n"; - c += " int offset_y_strided = offset_y * args.stride_y;\n"; - c += - " int src_y = (kernel_first_dst_y + offset_y_strided) / args.stride_y - " - "offset_y;\n"; - c += " int offset_z = abs(args.padding_z);\n"; - c += " int offset_z_strided = offset_z * args.stride_z;\n"; - c += - " int src_z 
= (kernel_first_dst_z + offset_z_strided) / args.stride_z - " - "offset_z;\n"; - c += " int src_as_dst_z = src_z * args.stride_z;\n"; - c += " for (;src_as_dst_z > kernel_last_dst_z; src_z -= 1, src_as_dst_z -= " - "args.stride_z) {\n"; - for (int z = 0; z < block_size.z; ++z) { - const std::string zindex = std::to_string(z); - c += " int sz" + zindex + " = src_z + " + zindex + ";\n"; - if (src_tensor_type != TensorStorageType::TEXTURE_3D) { - c += " bool in_z" + zindex + " = sz" + zindex + " >= 0 && sz" + - zindex + " < args.src_tensor.Depth();\n"; - } - } - if (block_size.z == 1 && (src_tensor_type != TensorStorageType::TEXTURE_3D)) { - c += " if (!in_z0) continue;\n"; - } - c += " int kernel_z = kernel_first_dst_z - src_as_dst_z;\n"; - c += " int src_as_dst_y = src_y * args.stride_y;\n"; - c += " int src_y_copy = src_y;\n"; - c += " for (;src_as_dst_y > kernel_last_dst_y; src_y_copy -= 1, " - "src_as_dst_y -= " - "args.stride_y) {\n"; - for (int y = 0; y < block_size.y; ++y) { - const std::string yindex = std::to_string(y); - c += " int sy" + yindex + " = src_y_copy + " + yindex + ";\n"; - if (manual_clamp) { - c += " bool in_y" + yindex + " = sy" + yindex + " >= 0 && sy" + - yindex + " < args.src_tensor.Height();\n"; - if (!image_buffer) { - c += " sy" + yindex + " = clamp(sy" + yindex + - ", 0, args.src_tensor.Height() - 1);\n"; - } - } - } - c += " int kernel_y = kernel_first_dst_y - src_as_dst_y;\n"; - c += " int src_as_dst_x = src_x * args.stride_x;\n"; - c += " int src_x_copy = src_x;\n"; - c += " for (;src_as_dst_x > kernel_last_dst_x; src_x_copy -= 1, " - "src_as_dst_x " - "-= args.stride_x) {\n"; - for (int x = 0; x < block_size.x; ++x) { - const std::string xindex = std::to_string(x); - c += " int sx" + xindex + " = src_x_copy + " + xindex + ";\n"; - if (manual_clamp) { - c += " bool in_x" + xindex + " = sx" + xindex + " >= 0 && sx" + - xindex + " < args.src_tensor.Width();\n"; - if (!image_buffer) { - c += " sx" + xindex + " = clamp(sx" + xindex + - ", 0, args.src_tensor.Width() - 1);\n"; - } - } - } - const std::string layer_offset = "args.src_tensor.SliceStride()"; - for (int z = 0; z < block_size.z; ++z) { - const std::string zindex = std::to_string(z); - for (int y = 0; y < block_size.y; ++y) { - const std::string yindex = std::to_string(y); - for (int x = 0; x < block_size.x; ++x) { - const std::string xindex = std::to_string(x); - const std::string id = - std::to_string((z * block_size.y + y) * block_size.x + x); - c += " args.src_tensor.GetAddress(addr_" + id + ", sx" + xindex + - ", sy" + yindex + ", sz" + zindex + ", 0);"; - if (image_buffer) { - c += " addr_" + id + " = select(-1, addr_" + id + ", (in_x" + - xindex + " && in_y" + yindex + "));\n"; - c += absl::Substitute( - " int dz_$0 = select(0, $3, (in_x$1 && " - "in_y$2));\n", - id, x, y, layer_offset); - } - } - } - } - if (src_tensor_type == TensorStorageType::BUFFER) { - c += " int dz = " + layer_offset + ";\n"; - } - if (block_size.x == 1 && block_size.y == 1 && manual_clamp) { - c += " if (!in_x0 || !in_y0) continue;\n"; - } - c += " int kernel_x = kernel_first_dst_x - src_as_dst_x;\n"; - c += " int kernel_index =(kernel_z * args.kernel_size_y + kernel_y) * " - "args.kernel_size_x + kernel_x;\n"; - if (weights_are_buffer) { - c += " int f_offset = f_base + kernel_index * " - "args.src_tensor.Slices() * " + - std::to_string(block_size.w) + ";\n"; - } else { - c += " int x_c = kernel_index * args.src_tensor.Slices();\n"; - } - c += " for (int s = 0; s < args.src_tensor.Slices(); ++s) {\n"; - for (int y = 0; y 
< block_size.y; ++y) { - const std::string yindex = std::to_string(y); - for (int x = 0; x < block_size.x; ++x) { - const std::string xindex = std::to_string(x); - const std::string id = std::to_string(y * block_size.x + x); - if (image_buffer) { - c += " FLT4 src" + id + " = args.src_tensor.Read(addr_" + id + - "); addr_" + id + " += dz_" + id + ";\n"; - } else if (manual_clamp) { - c += " FLT4 src" + id + " = args.src_tensor.Read(addr_" + id + - ") * (FLT)(in_x" + xindex + " && in_y" + yindex + "); addr_" + id + - " += dz;\n"; - } else { - c += " FLT4 src" + id + " = args.src_tensor.Read(sx" + xindex + - ", sy" + yindex + ", sz0, s);\n"; - } - } - } - if (weights_are_buffer) { - c += " __global FLT16* weights_cache = " - "args.weights.GetPtr(f_offset);\n"; - c += " f_offset += " + std::to_string(block_size.w) + ";\n"; - } else { - for (int z = 0; z < block_size.w; ++z) { - c += absl::Substitute( - R"( FLT4 f$1 = args.weights0.Read(S + $0, x_c); - FLT4 f$2 = args.weights1.Read(S + $0, x_c); - FLT4 f$3 = args.weights2.Read(S + $0, x_c); - FLT4 f$4 = args.weights3.Read(S + $0, x_c); -)", - z, z * 4 + 0, z * 4 + 1, z * 4 + 2, z * 4 + 3); - } - c += " x_c++;\n"; - } - for (int z = 0; z < block_size.w; ++z) { - for (int i = 0; i < block_size.x * block_size.y * block_size.z; ++i) { - c += " CONV" + std::to_string(z) + "(r" + - std::to_string(i + z * block_size.x * block_size.y * block_size.z) + - ", src" + std::to_string(i) + ");\n"; - } - } - c += " }\n"; - c += " }\n"; - c += " }\n"; - c += " }\n"; - for (int s = 0; s < block_size.w; ++s) { - c += " if (S < args.dst_tensor.Slices()) {\n"; - c += " FLT4 bias_val = args.biases.Read(S);\n"; - for (int z = 0; z < block_size.z; ++z) { - for (int y = 0; y < block_size.y; ++y) { - for (int x = 0; x < block_size.x; ++x) { - const std::string id = std::to_string( - ((s * block_size.z + z) * block_size.y + y) * block_size.x + x); - c += " {\n"; - c += " int xc = dst_x + args.stride_x * " + std::to_string(x) + - ";\n"; - c += " int yc = dst_y + args.stride_y * " + std::to_string(y) + - ";\n"; - c += " int zc = dst_z + args.stride_z * " + std::to_string(z) + - ";\n"; - c += " if (xc < args.dst_tensor.Width() && yc < " - "args.dst_tensor.Height() && zc < args.dst_tensor.Depth()) {\n"; - c += " FLT4 res = TO_FLT4(r" + id + ") + bias_val;\n"; - c += " args.dst_tensor.Write(res, xc, yc, zc, S)\n"; - c += " }\n"; - c += " }\n"; - } - } - } - c += " }\n"; - c += " S++;\n"; - } - c += "}\n"; - return c; -} - -absl::Status ConvolutionTransposed3D::BindArguments() { - return args_.SetInt("grid_size_s", - DivideRoundUp(dst_[0]->Slices(), block_size_.w)); -} - -int3 ConvolutionTransposed3D::GetGridSize() const { - const int aligned_w = AlignByN(dst_[0]->Width(), stride_.x * block_size_.x); - const int aligned_h = AlignByN(dst_[0]->Height(), stride_.y * block_size_.y); - const int aligned_d = AlignByN(dst_[0]->Depth(), stride_.z * block_size_.z); - const int grid_x = DivideRoundUp(aligned_w, block_size_.x) * dst_[0]->Batch(); - const int grid_y = DivideRoundUp(aligned_h, block_size_.y); - const int grid_z = DivideRoundUp(dst_[0]->Slices(), block_size_.w) * - DivideRoundUp(aligned_d, block_size_.z); - return int3(grid_x, grid_y, grid_z); -} - -void ConvolutionTransposed3D::GetPossibleKernelWorkGroups( - TuningType tuning_type, const DeviceInfo& device_info, - const KernelInfo& kernel_info, std::vector* work_groups) const { - GetPossibleWorkGroupsConv(tuning_type, device_info, kernel_info, grid_size_, - work_groups); -} - -ConvolutionTransposed3D 
CreateConvolutionTransposed3D( - const DeviceInfo& device_info, const OperationDef& definition, - const ConvolutionTransposed3DAttributes& attr) { - ConvolutionTransposed3D result(definition, attr, device_info); - - TensorLinearDescriptor desc; - desc.storage_type = - DeduceLinearStorageType(definition.GetPrimaryStorageType()); - desc.element_type = definition.GetDataType(); - desc.UploadLinearData(attr.bias); - result.args_.AddObject( - "biases", absl::make_unique(std::move(desc))); - return result; -} - -} // namespace cl -} // namespace gpu -} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3d.h b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3d.h deleted file mode 100644 index ebd674d612b..00000000000 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3d.h +++ /dev/null @@ -1,215 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_CONVOLUTION_TRANSPOSED_3D_H_ -#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_CONVOLUTION_TRANSPOSED_3D_H_ - -#include - -#include "tensorflow/lite/delegates/gpu/cl/buffer.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" -#include "tensorflow/lite/delegates/gpu/cl/linear_storage.h" -#include "tensorflow/lite/delegates/gpu/cl/tensor.h" -#include "tensorflow/lite/delegates/gpu/cl/texture2d.h" -#include "tensorflow/lite/delegates/gpu/cl/util.h" -#include "tensorflow/lite/delegates/gpu/common/data_type.h" -#include "tensorflow/lite/delegates/gpu/common/operations.h" -#include "tensorflow/lite/delegates/gpu/common/shape.h" -#include "tensorflow/lite/delegates/gpu/common/status.h" -#include "tensorflow/lite/delegates/gpu/common/tensor.h" -#include "tensorflow/lite/delegates/gpu/common/types.h" - -namespace tflite { -namespace gpu { -namespace cl { - -class ConvolutionTransposed3D : public GPUOperation { - public: - ConvolutionTransposed3D() = default; - void GetPossibleKernelWorkGroups( - TuningType tuning_type, const DeviceInfo& device_info, - const KernelInfo& kernel_info, - std::vector* work_groups) const override; - absl::Status BindArguments() override; - int3 GetGridSize() const override; - - // Move only - ConvolutionTransposed3D(ConvolutionTransposed3D&& operation); - ConvolutionTransposed3D& operator=(ConvolutionTransposed3D&& operation); - ConvolutionTransposed3D(const ConvolutionTransposed3D&) = delete; - ConvolutionTransposed3D& operator=(const ConvolutionTransposed3D&) = delete; - - private: - friend ConvolutionTransposed3D CreateConvolutionTransposed3D( - const DeviceInfo& device_info, const OperationDef& definition, - const ConvolutionTransposed3DAttributes& attr); - ConvolutionTransposed3D(const OperationDef& definition, - const ConvolutionTransposed3DAttributes& attr, - const DeviceInfo& device_info); - template - void UploadWeights(const tflite::gpu::Tensor& 
weights, - bool weights_are_buffer); - - template - void RearrangeWeightsData(const tflite::gpu::Tensor& weights, - absl::Span dst, bool weights_are_buffer); - - std::string GenerateConvolutionTransposed3DCode(const OperationDef& op_def, - bool weights_are_buffer, - const int4& block_size); - - int3 stride_; - int4 block_size_ = int4(1, 1, 1, 1); // WHDS -}; - -template -void ConvolutionTransposed3D::UploadWeights( - const tflite::gpu::Tensor& weights, bool weights_are_buffer) { - const int dst_depth = - AlignByN(DivideRoundUp(weights.shape.o, 4), block_size_.z); - const int src_depth = DivideRoundUp(weights.shape.i, 4); - const int kernel_x = weights.shape.w; - const int kernel_y = weights.shape.h; - const int kernel_z = weights.shape.d; - int texture_width = dst_depth; - int texture_height = src_depth * kernel_x * kernel_y * kernel_z; - - const int elements_count = - kernel_x * kernel_y * kernel_z * src_depth * dst_depth * 4; - const bool f32_weights = definition_.precision == CalculationsPrecision::F32; - - const int float4_size = f32_weights ? 16 : 8; - std::vector data(float4_size * elements_count); - - if (f32_weights) { - float4* ptr = reinterpret_cast(data.data()); - RearrangeWeightsData(weights, absl::MakeSpan(ptr, elements_count), - weights_are_buffer); - } else { - half4* ptr = reinterpret_cast(data.data()); - RearrangeWeightsData(weights, absl::MakeSpan(ptr, elements_count), - weights_are_buffer); - } - - if (weights_are_buffer) { - BufferDescriptor desc; - desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16; - desc.element_size = 16; - desc.size = float4_size * elements_count; - desc.data = std::move(data); - args_.AddObject("weights", - absl::make_unique(std::move(desc))); - } else { - int sub_size = float4_size * elements_count / 4; - Texture2DDescriptor desc0; - desc0.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16; - desc0.size = int2(texture_width, texture_height); - desc0.data.resize(sub_size); - memcpy(desc0.data.data(), data.data(), sub_size); - args_.AddObject("weights0", - absl::make_unique(std::move(desc0))); - - Texture2DDescriptor desc1; - desc1.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16; - desc1.size = int2(texture_width, texture_height); - desc1.data.resize(sub_size); - memcpy(desc1.data.data(), data.data() + sub_size, sub_size); - args_.AddObject("weights1", - absl::make_unique(std::move(desc1))); - - Texture2DDescriptor desc2; - desc2.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16; - desc2.size = int2(texture_width, texture_height); - desc2.data.resize(sub_size); - memcpy(desc2.data.data(), data.data() + sub_size * 2, sub_size); - args_.AddObject("weights2", - absl::make_unique(std::move(desc2))); - - Texture2DDescriptor desc3; - desc3.element_type = f32_weights ? 
DataType::FLOAT32 : DataType::FLOAT16; - desc3.size = int2(texture_width, texture_height); - desc3.data.resize(sub_size); - memcpy(desc3.data.data(), data.data() + sub_size * 3, sub_size); - args_.AddObject("weights3", - absl::make_unique(std::move(desc3))); - } -} - -template -void ConvolutionTransposed3D::RearrangeWeightsData( - const tflite::gpu::Tensor& weights, absl::Span dst, - bool weights_are_buffer) { - const int dst_depth = - AlignByN(DivideRoundUp(weights.shape.o, 4), block_size_.w); - const int src_depth = DivideRoundUp(weights.shape.i, 4); - const int kernel_x = weights.shape.w; - const int kernel_y = weights.shape.h; - const int kernel_z = weights.shape.d; - int texture_width = dst_depth; - int texture_height = src_depth * kernel_x * kernel_y * kernel_z; - - int counter = 0; - for (int d = 0; d < dst_depth / block_size_.w; ++d) { - for (int z = 0; z < kernel_z; ++z) { - for (int y = 0; y < kernel_y; ++y) { - for (int x = 0; x < kernel_x; ++x) { - for (int s = 0; s < src_depth; ++s) { - for (int sub_d = 0; sub_d < block_size_.w; ++sub_d) { - T filters[4]; - for (int i = 0; i < 4; ++i) { - for (int j = 0; j < 4; ++j) { - const int s_ch = s * 4 + j; - const int d_ch = (d * block_size_.w + sub_d) * 4 + i; - if (s_ch < weights.shape.i && d_ch < weights.shape.o) { - const int f_index = - weights.shape.LinearIndex({d_ch, y, x, z, s_ch}); - filters[j][i] = weights.data[f_index]; - } else { - filters[j][i] = 0.0f; - } - } - } - if (weights_are_buffer) { - dst[counter++] = filters[0]; - dst[counter++] = filters[1]; - dst[counter++] = filters[2]; - dst[counter++] = filters[3]; - } else { - int x_coord = d * block_size_.w + sub_d; - int y_coord = - ((z * kernel_y + y) * kernel_x + x) * src_depth + s; - int offset = y_coord * dst_depth + x_coord; - dst[offset + texture_width * texture_height * 0] = filters[0]; - dst[offset + texture_width * texture_height * 1] = filters[1]; - dst[offset + texture_width * texture_height * 2] = filters[2]; - dst[offset + texture_width * texture_height * 3] = filters[3]; - } - } - } - } - } - } - } -} - -ConvolutionTransposed3D CreateConvolutionTransposed3D( - const DeviceInfo& device_info, const OperationDef& definition, - const ConvolutionTransposed3DAttributes& attr); - -} // namespace cl -} // namespace gpu -} // namespace tflite - -#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_CONVOLUTION_TRANSPOSED_3D_H_ diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.cc b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.cc index af952dd3f78..7880f31013a 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.cc @@ -29,10 +29,9 @@ namespace gpu { namespace cl { ConvolutionTransposed3x3::ConvolutionTransposed3x3( const OperationDef& definition, const DeviceInfo& device_info, int2 padding) - : GPUOperation(definition), - padding_(padding), - work_group_launch_order_(2, 0, 1) { + : GPUOperation(definition), padding_(padding) { work_group_size_ = int3(8, 4, 1); + work_group_launch_order_ = int3(2, 0, 1); if (device_info.IsPowerVR()) { weights_upload_type_ = WeightsUploadType::LOCAL_MEM_ASYNC; } else if (device_info.IsNvidia() || device_info.IsIntel()) { @@ -54,14 +53,12 @@ ConvolutionTransposed3x3::ConvolutionTransposed3x3( ConvolutionTransposed3x3&& operation) : GPUOperation(std::move(operation)), padding_(operation.padding_), - work_group_launch_order_(operation.work_group_launch_order_), 
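// Editor's note: a small, self-contained illustration (not part of the patch)
// of what the work_group_launch_order_ = int3(2, 0, 1) assignment above
// expresses. The launch order permutes which logical grid axis each dispatch
// dimension walks; since the member is no longer declared in
// ConvolutionTransposed3x3, it appears to be inherited from GPUOperation, so
// the kernel now only states the order and returns the raw grid from
// GetGridSize() instead of applying the permutation inline, as the code
// removed further down in this patch used to do. Names below are local to
// this sketch.
#include <cstdio>

struct Int3 { int x, y, z; };

namespace {
int DivUp(int n, int d) { return (n + d - 1) / d; }

Int3 PermutedDispatchSize(const Int3& grid, const Int3& wg_size, const Int3& order) {
  // Work-group counts per logical grid axis.
  const int wg_count[3] = {DivUp(grid.x, wg_size.x), DivUp(grid.y, wg_size.y),
                           DivUp(grid.z, wg_size.z)};
  const int ord[3] = {order.x, order.y, order.z};
  // Remap the counts through the launch order, then scale back to work items.
  return Int3{wg_count[ord[0]] * wg_size.x, wg_count[ord[1]] * wg_size.y,
              wg_count[ord[2]] * wg_size.z};
}
}  // namespace

int main() {
  const Int3 grid{100, 40, 16};  // invented: e.g. ceil(W/2)*B, ceil(H/2), slices
  const Int3 wg{8, 4, 1};        // work_group_size_ used by the 3x3 kernel
  const Int3 d = PermutedDispatchSize(grid, wg, Int3{2, 0, 1});
  std::printf("%d %d %d\n", d.x, d.y, d.z);
  return 0;
}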
weights_upload_type_(operation.weights_upload_type_) {} ConvolutionTransposed3x3& ConvolutionTransposed3x3::operator=( ConvolutionTransposed3x3&& operation) { if (this != &operation) { std::swap(padding_, operation.padding_); - std::swap(work_group_launch_order_, operation.work_group_launch_order_); std::swap(weights_upload_type_, operation.weights_upload_type_); GPUOperation::operator=(std::move(operation)); } @@ -305,27 +302,33 @@ std::string ConvolutionTransposed3x3::GenerateConvolutionTransposedCode( return c; } -absl::Status ConvolutionTransposed3x3::BindArguments() { - RETURN_IF_ERROR(args_.SetInt("filter_offset", 4 * 9 * src_[0]->Slices())); +absl::Status ConvolutionTransposed3x3::BindArguments(ArgumentsBinder* args) { + RETURN_IF_ERROR(args->SetInt("filter_offset", 4 * 9 * src_[0]->Slices())); const int padding_x = padding_.x >= 1 ? (padding_.x - 1) / 2 : (padding_.x - 2) / 2; const int padding_y = padding_.y >= 1 ? (padding_.y - 1) / 2 : (padding_.y - 2) / 2; - RETURN_IF_ERROR(args_.SetInt("padding_x", padding_x * src_[0]->Batch())); - return args_.SetInt("padding_y", padding_y); + RETURN_IF_ERROR(args->SetInt("padding_x", padding_x * src_[0]->Batch())); + return args->SetInt("padding_y", padding_y); +} + +void ConvolutionTransposed3x3::GetPossibleKernelWorkGroups( + TuningType tuning_type, const DeviceInfo& device_info, + const KernelInfo& kernel_info, std::vector* work_groups) const { + if (weights_upload_type_ == WeightsUploadType::LOCAL_MEM_ASYNC || + weights_upload_type_ == WeightsUploadType::LOCAL_MEM_BY_THREADS) { + work_groups->push_back(work_group_size_); + return; + } + GetPossibleWorkGroupsConv(tuning_type, device_info, kernel_info, grid_size_, + work_groups); } int3 ConvolutionTransposed3x3::GetGridSize() const { const int grid_x = DivideRoundUp(dst_[0]->Width(), 2) * dst_[0]->Batch(); const int grid_y = DivideRoundUp(dst_[0]->Height(), 2); const int grid_z = dst_[0]->Slices(); - int3 wg; - wg.x = DivideRoundUp(grid_x, work_group_size_.x); - wg.y = DivideRoundUp(grid_y, work_group_size_.y); - wg.z = DivideRoundUp(grid_z, work_group_size_.z); - return int3(wg[work_group_launch_order_[0]] * work_group_size_.x, - wg[work_group_launch_order_[1]] * work_group_size_.y, - wg[work_group_launch_order_[2]] * work_group_size_.z); + return int3(grid_x, grid_y, grid_z); } bool IsConvolutionTransposed3x3Supported( diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.h b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.h index ad3e459da3e..074fc23b0e7 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.h @@ -40,10 +40,8 @@ class ConvolutionTransposed3x3 : public GPUOperation { void GetPossibleKernelWorkGroups( TuningType tuning_type, const DeviceInfo& device_info, const KernelInfo& kernel_info, - std::vector* work_groups) const override { - work_groups->push_back(work_group_size_); - } - absl::Status BindArguments() override; + std::vector* work_groups) const override; + absl::Status BindArguments(ArgumentsBinder* args) override; int3 GetGridSize() const override; // Move only @@ -78,7 +76,6 @@ class ConvolutionTransposed3x3 : public GPUOperation { int2 padding, int3 work_group_launch_order); int2 padding_; - int3 work_group_launch_order_; WeightsUploadType weights_upload_type_; }; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.cc 
b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.cc index d606a822d7e..0f389361724 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.cc @@ -296,8 +296,8 @@ std::string ConvolutionTransposed4x4::GenerateConvolutionTransposedCode( return c; } -absl::Status ConvolutionTransposed4x4::BindArguments() { - return args_.SetInt("filter_offset", 4 * 16 * src_[0]->Slices()); +absl::Status ConvolutionTransposed4x4::BindArguments(ArgumentsBinder* args) { + return args->SetInt("filter_offset", 4 * 16 * src_[0]->Slices()); } int3 ConvolutionTransposed4x4::GetGridSize() const { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.h b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.h index 2577eb47513..17d63233864 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.h @@ -43,7 +43,7 @@ class ConvolutionTransposed4x4 : public GPUOperation { std::vector* work_groups) const override { work_groups->push_back(work_group_size_); } - absl::Status BindArguments() override; + absl::Status BindArguments(ArgumentsBinder* args) override; int3 GetGridSize() const override; // Move only diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.cc b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.cc index 91e26b27cdf..05d5d086bc7 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.cc @@ -66,100 +66,25 @@ std::string GetSrcValue(int channel_multiplier, const std::string coords) { return c; } -} // namespace -DepthwiseConvolution::DepthwiseConvolution( - const OperationDef& definition, - const DepthwiseConvolution2DAttributes& attr, bool weights_are_buffer) - : GPUOperation(definition), - weights_are_buffer_(weights_are_buffer), - kernel_size_(attr.weights.shape.w, attr.weights.shape.h, 0, 0), - stride_(attr.strides.w, attr.strides.h, 0, 0), - padding_(-attr.padding.prepended.w, -attr.padding.prepended.h, 0, 0), - dilation_(attr.dilations.w, attr.dilations.h, 0, 0), - channel_multiplier_(attr.weights.shape.o) { - work_group_size_ = int3(8, 8, 1); - const bool stride_correction = - definition_.IsBatchSupported() && stride_.x != 1; - code_ = GenerateDepthwiseConvolutionCode( - definition_, stride_correction, channel_multiplier_, weights_are_buffer_); -} - -DepthwiseConvolution::DepthwiseConvolution( - const OperationDef& definition, - const DepthwiseConvolution3DAttributes& attr, bool weights_are_buffer) - : GPUOperation(definition), - weights_are_buffer_(weights_are_buffer), - kernel_size_(attr.weights.shape.w, attr.weights.shape.h, - attr.weights.shape.d, 0), - stride_(attr.strides.w, attr.strides.h, attr.strides.d, 0), - padding_(-attr.padding.prepended.w, -attr.padding.prepended.h, - -attr.padding.prepended.d, 0), - dilation_(attr.dilations.w, attr.dilations.h, attr.dilations.d, 0), - channel_multiplier_(attr.weights.shape.o) { - work_group_size_ = int3(8, 8, 1); - const bool stride_correction = - definition_.IsBatchSupported() && stride_.x != 1; - code_ = GenerateDepthwiseConvolutionCode( - definition_, stride_correction, channel_multiplier_, weights_are_buffer_); -} - -DepthwiseConvolution::DepthwiseConvolution(DepthwiseConvolution&& operation) - : GPUOperation(std::move(operation)), - weights_are_buffer_(operation.weights_are_buffer_), - 
kernel_size_(operation.kernel_size_), - stride_(operation.stride_), - padding_(operation.padding_), - dilation_(operation.dilation_), - channel_multiplier_(operation.channel_multiplier_) {} - -DepthwiseConvolution& DepthwiseConvolution::operator=( - DepthwiseConvolution&& operation) { - if (this != &operation) { - std::swap(weights_are_buffer_, operation.weights_are_buffer_); - std::swap(kernel_size_, operation.kernel_size_); - std::swap(stride_, operation.stride_); - std::swap(padding_, operation.padding_); - std::swap(dilation_, operation.dilation_); - std::swap(channel_multiplier_, operation.channel_multiplier_); - GPUOperation::operator=(std::move(operation)); - } - return *this; -} - -std::string DepthwiseConvolution::GenerateDepthwiseConvolutionCode( +std::string GenerateDepthwiseConvolutionCode( const OperationDef& op_def, bool stride_correction, int channel_multiplier, - bool weights_are_buffer) { + bool weights_are_buffer, bool dynamic_weights, GPUOperation* op) { auto src_desc = op_def.src_tensors[0]; src_desc.SetTextureAddressMode(TextureAddressMode::ZERO); if (op_def.IsBatchSupported()) { src_desc.SetStateVar("BatchedWidth", "true"); } - AddSrcTensor("src_tensor", src_desc); + op->AddSrcTensor("src_tensor", src_desc); + if (dynamic_weights) { + op->AddSrcTensor("weights", op_def.src_tensors[1]); + } auto dst_desc = op_def.dst_tensors[0]; if (op_def.IsBatchSupported()) { dst_desc.SetStateVar("BatchedWidth", "true"); } - AddDstTensor("dst_tensor", dst_desc); - - args_.AddInt("kernel_size_x"); - args_.AddInt("stride_x"); - args_.AddInt("padding_x"); - args_.AddInt("dilation_x"); - args_.AddInt("kernel_size_y"); - args_.AddInt("stride_y"); - args_.AddInt("padding_y"); - args_.AddInt("dilation_y"); - if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) { - args_.AddInt("kernel_size_z"); - args_.AddInt("stride_z"); - args_.AddInt("padding_z"); - args_.AddInt("dilation_z"); - } - if (!IsSpecializedCase(channel_multiplier)) { - args_.AddInt("ch_multiplier"); - } + op->AddDstTensor("dst_tensor", dst_desc); const auto src_tensor_type = op_def.src_tensors[0].storage_type; @@ -171,14 +96,14 @@ std::string DepthwiseConvolution::GenerateDepthwiseConvolutionCode( c += "__kernel void main_function(\n"; c += "$0) {\n"; c += " int X = get_global_id(0);\n"; - c += " int Y = get_global_id(1);\n"; if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) { - c += " int linear_id_2 = get_global_id(2);\n"; - c += " int S = linear_id_2 / args.dst_tensor.Depth();\n"; - c += " int Z = linear_id_2 % args.dst_tensor.Depth();\n"; + c += " int linear_id_1 = get_global_id(1);\n"; + c += " int Y = linear_id_1 / args.dst_tensor.Depth();\n"; + c += " int Z = linear_id_1 % args.dst_tensor.Depth();\n"; } else { - c += " int S = get_global_id(2);\n"; + c += " int Y = get_global_id(1);\n"; } + c += " int S = get_global_id(2);\n"; c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || " "S >= args.dst_tensor.Slices()) { \n"; c += " return; \n"; @@ -186,23 +111,36 @@ std::string DepthwiseConvolution::GenerateDepthwiseConvolutionCode( c += " ACCUM_FLT4 r = (ACCUM_FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n"; if (stride_correction) { c += " int x_offseted = " + - GetXStrideCorrected("X", "args.src_tensor.Batch()", "args.stride_x", - "args.padding_x") + + GetXStrideCorrectedV2("X", "args.src_tensor.Batch()", "args.stride_x", + "args.padding_x") + ";\n"; } else { - c += " int x_offseted = X * args.stride_x + args.padding_x;\n"; + if (op_def.IsBatchSupported()) { + c += " int x_offseted = X * args.stride_x + args.padding_x * " 
+ "args.src_tensor.Batch();\n"; + } else { + c += " int x_offseted = X * args.stride_x + args.padding_x;\n"; + } } c += " int y_offseted = Y * args.stride_y + args.padding_y;\n"; - std::string weights_offset = "args.kernel_size_x * args.kernel_size_y"; - if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) { - c += " int z_offseted = Z * args.stride_z + args.padding_z;\n"; - weights_offset += " * args.kernel_size_z"; - } - if (weights_are_buffer) { - c += " int fx_c = S * " + weights_offset + ";\n"; - } else { - c += " int fx_c = 0;\n"; + if (!dynamic_weights) { + std::string weights_offset = "args.kernel_size_x * args.kernel_size_y"; + if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) { + c += " int z_offseted = Z * args.stride_z + args.padding_z;\n"; + weights_offset += " * args.kernel_size_z"; + } + if (weights_are_buffer) { + c += " int fx_c = S * " + weights_offset + ";\n"; + } else { + c += " int fx_c = 0;\n"; + } } + std::string kernel_size_x = + dynamic_weights ? "args.weights.Width()" : "args.kernel_size_x"; + std::string kernel_size_y = + dynamic_weights ? "args.weights.Height()" : "args.kernel_size_y"; + std::string kernel_size_z = + dynamic_weights ? "args.weights.Depth()" : "args.kernel_size_z"; std::string flat_coords = "x_c, y_c"; if (manual_clamp) { @@ -210,26 +148,35 @@ std::string DepthwiseConvolution::GenerateDepthwiseConvolutionCode( if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) { check += " && !outside_z"; flat_coords += ", z_c"; - c += " for (int kz = 0; kz < args.kernel_size_z; ++kz) {\n"; + c += " for (int kz = 0; kz < " + kernel_size_z + "; ++kz) {\n"; c += " int z_c = z_offseted + kz * args.dilation_z;\n"; c += " bool outside_z = z_c < 0 || z_c >= args.src_tensor.Depth();\n"; } - c += " for (int ky = 0; ky < args.kernel_size_y; ++ky) {\n"; + c += " for (int ky = 0; ky < " + kernel_size_y + "; ++ky) {\n"; c += " int y_c = y_offseted + ky * args.dilation_y;\n"; c += " bool outside_y = y_c < 0 || y_c >= args.src_tensor.Height();\n"; - c += " for (int kx = 0; kx < args.kernel_size_x; ++kx) {\n"; - c += " int x_c = x_offseted + kx * args.dilation_x;\n"; + c += " for (int kx = 0; kx < " + kernel_size_x + "; ++kx) {\n"; + const std::string dilation_x = + op_def.IsBatchSupported() ? 
"args.dilation_x * args.src_tensor.Batch()" + : "args.dilation_x"; + c += " int x_c = x_offseted + kx * " + dilation_x + ";\n"; c += " bool outside_x = x_c < 0 || x_c >= args.src_tensor.Width();\n"; c += " if (" + check + ") {\n"; - if (weights_are_buffer) { - c += " FLT4 f = args.weights.Read(fx_c);\n"; + if (dynamic_weights) { + c += " FLT4 f = args.weights.Read(kx, ky, S);\n"; } else { - c += " FLT4 f = args.weights.Read(fx_c, S);\n"; + if (weights_are_buffer) { + c += " FLT4 f = args.weights.Read(fx_c);\n"; + } else { + c += " FLT4 f = args.weights.Read(fx_c, S);\n"; + } } c += GetSrcValue(channel_multiplier, flat_coords); c += " r += TO_ACCUM_TYPE(src_final * f);\n"; c += " };\n"; - c += " fx_c++;\n"; + if (!dynamic_weights) { + c += " fx_c++;\n"; + } c += " }\n"; c += " }\n"; if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) { @@ -238,7 +185,7 @@ std::string DepthwiseConvolution::GenerateDepthwiseConvolutionCode( } else { // Texture types with ZERO clamping if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) { flat_coords += ", z_c"; - c += " for (int kz = 0; kz < args.kernel_size_z; ++kz) {\n"; + c += " for (int kz = 0; kz < " + kernel_size_z + "; ++kz) {\n"; c += " int z_c = z_offseted + kz * args.dilation_z;\n"; if (src_tensor_type != TensorStorageType::TEXTURE_3D) { // Only TEXTURE_3D supports clamping @@ -249,17 +196,24 @@ std::string DepthwiseConvolution::GenerateDepthwiseConvolutionCode( c += " }\n"; } } - c += " for (int ky = 0; ky < args.kernel_size_y; ++ky) {\n"; + c += " for (int ky = 0; ky < " + kernel_size_y + "; ++ky) {\n"; c += " int y_c = y_offseted + ky * args.dilation_y;\n"; - c += " for (int kx = 0; kx < args.kernel_size_x; ++kx) {\n"; - c += " int x_c = x_offseted + kx * args.dilation_x;\n"; + c += " for (int kx = 0; kx < " + kernel_size_x + "; ++kx) {\n"; + const std::string dilation_x = + op_def.IsBatchSupported() ? 
"args.dilation_x * args.src_tensor.Batch()" + : "args.dilation_x"; + c += " int x_c = x_offseted + kx * " + dilation_x + ";\n"; c += GetSrcValue(channel_multiplier, flat_coords); - if (weights_are_buffer) { - c += " FLT4 f = args.weights.Read(fx_c);\n"; + if (dynamic_weights) { + c += " FLT4 f = args.weights.Read(kx, ky, S);\n"; } else { - c += " FLT4 f = args.weights.Read(fx_c, S);\n"; + if (weights_are_buffer) { + c += " FLT4 f = args.weights.Read(fx_c);\n"; + } else { + c += " FLT4 f = args.weights.Read(fx_c, S);\n"; + } + c += " fx_c++;\n"; } - c += " fx_c++;\n"; c += " r += TO_ACCUM_TYPE(src_final * f);\n"; c += " }\n"; c += " }\n"; @@ -277,67 +231,106 @@ std::string DepthwiseConvolution::GenerateDepthwiseConvolutionCode( return c; } +} // namespace -absl::Status DepthwiseConvolution::BindArguments() { - RETURN_IF_ERROR(args_.SetInt("kernel_size_x", kernel_size_.x)); - RETURN_IF_ERROR(args_.SetInt("stride_x", stride_.x)); - RETURN_IF_ERROR(args_.SetInt("padding_x", padding_.x * src_[0]->Batch())); - RETURN_IF_ERROR(args_.SetInt("dilation_x", dilation_.x * src_[0]->Batch())); - RETURN_IF_ERROR(args_.SetInt("kernel_size_y", kernel_size_.y)); - RETURN_IF_ERROR(args_.SetInt("stride_y", stride_.y)); - RETURN_IF_ERROR(args_.SetInt("padding_y", padding_.y)); - RETURN_IF_ERROR(args_.SetInt("dilation_y", dilation_.y)); - if (definition_.dst_tensors[0].HasAxis(Axis::DEPTH)) { - RETURN_IF_ERROR(args_.SetInt("kernel_size_z", kernel_size_.z)); - RETURN_IF_ERROR(args_.SetInt("stride_z", stride_.z)); - RETURN_IF_ERROR(args_.SetInt("padding_z", padding_.z)); - RETURN_IF_ERROR(args_.SetInt("dilation_z", dilation_.z)); - } - if (!IsSpecializedCase(channel_multiplier_)) { - RETURN_IF_ERROR(args_.SetInt("ch_multiplier", channel_multiplier_)); - } - return absl::OkStatus(); -} - -int3 DepthwiseConvolution::GetGridSize() const { - const int grid_x = dst_[0]->Width() * dst_[0]->Batch(); - const int grid_y = dst_[0]->Height(); - const int grid_z = dst_[0]->Slices() * dst_[0]->Depth(); - return int3(grid_x, grid_y, grid_z); -} - -DepthwiseConvolution CreateDepthwiseConvolution( +GPUOperation CreateDepthwiseConvolution2D( const DeviceInfo& device_info, const OperationDef& definition, const DepthwiseConvolution2DAttributes& attr) { bool weights_are_buffer = device_info.IsMali(); - DepthwiseConvolution result(definition, attr, weights_are_buffer); - result.UploadWeights(attr.weights); + GPUOperation op(definition); + op.args_.AddInt("kernel_size_x", attr.weights.shape.w); + op.args_.AddInt("stride_x", attr.strides.w); + op.args_.AddInt("padding_x", -attr.padding.prepended.w); + op.args_.AddInt("dilation_x", attr.dilations.w); + op.args_.AddInt("kernel_size_y", attr.weights.shape.h); + op.args_.AddInt("stride_y", attr.strides.h); + op.args_.AddInt("padding_y", -attr.padding.prepended.h); + op.args_.AddInt("dilation_y", attr.dilations.h); + if (!IsSpecializedCase(attr.weights.shape.o)) { + op.args_.AddInt("ch_multiplier", attr.weights.shape.o); + } + const bool stride_correction = + definition.IsBatchSupported() && attr.strides.w != 1; + op.code_ = GenerateDepthwiseConvolutionCode(definition, stride_correction, + attr.weights.shape.o, + weights_are_buffer, false, &op); + UploadWeightsForDWConv2D(attr.weights, weights_are_buffer, + definition.precision, &op); + op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ; TensorLinearDescriptor desc; desc.storage_type = weights_are_buffer ? 
LinearStorageType::BUFFER : LinearStorageType::TEXTURE_2D; desc.element_type = definition.GetDataType(); desc.UploadLinearData(attr.bias); - result.args_.AddObject( + op.args_.AddObject( "biases", absl::make_unique(std::move(desc))); - return result; + return op; } -DepthwiseConvolution CreateDepthwiseConvolution( +GPUOperation CreateDepthwiseConvolution2DDynamicWeights( + const DeviceInfo& device_info, const OperationDef& definition, + const DepthwiseConvolution2DAttributes& attr) { + GPUOperation op(definition); + op.args_.AddInt("stride_x", attr.strides.w); + op.args_.AddInt("padding_x", -attr.padding.prepended.w); + op.args_.AddInt("dilation_x", attr.dilations.w); + op.args_.AddInt("stride_y", attr.strides.h); + op.args_.AddInt("padding_y", -attr.padding.prepended.h); + op.args_.AddInt("dilation_y", attr.dilations.h); + const bool stride_correction = + definition.IsBatchSupported() && attr.strides.w != 1; + op.code_ = GenerateDepthwiseConvolutionCode(definition, stride_correction, 1, + false, true, &op); + op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ; + + TensorLinearDescriptor desc; + desc.storage_type = device_info.IsMali() ? LinearStorageType::BUFFER + : LinearStorageType::TEXTURE_2D; + desc.element_type = definition.GetDataType(); + desc.UploadLinearData(attr.bias); + op.args_.AddObject( + "biases", absl::make_unique(std::move(desc))); + return op; +} + +GPUOperation CreateDepthwiseConvolution3D( const DeviceInfo& device_info, const OperationDef& definition, const DepthwiseConvolution3DAttributes& attr) { bool weights_are_buffer = device_info.IsMali(); - DepthwiseConvolution result(definition, attr, weights_are_buffer); - result.UploadWeights(attr.weights); + GPUOperation op(definition); + op.args_.AddInt("kernel_size_x", attr.weights.shape.w); + op.args_.AddInt("stride_x", attr.strides.w); + op.args_.AddInt("padding_x", -attr.padding.prepended.w); + op.args_.AddInt("dilation_x", attr.dilations.w); + op.args_.AddInt("kernel_size_y", attr.weights.shape.h); + op.args_.AddInt("stride_y", attr.strides.h); + op.args_.AddInt("padding_y", -attr.padding.prepended.h); + op.args_.AddInt("dilation_y", attr.dilations.h); + op.args_.AddInt("kernel_size_z", attr.weights.shape.d); + op.args_.AddInt("stride_z", attr.strides.d); + op.args_.AddInt("padding_z", -attr.padding.prepended.d); + op.args_.AddInt("dilation_z", attr.dilations.d); + if (!IsSpecializedCase(attr.weights.shape.o)) { + op.args_.AddInt("ch_multiplier", attr.weights.shape.o); + } + const bool stride_correction = + definition.IsBatchSupported() && attr.strides.w != 1; + op.code_ = GenerateDepthwiseConvolutionCode(definition, stride_correction, + attr.weights.shape.o, + weights_are_buffer, false, &op); + UploadWeightsForDWConv3D(attr.weights, weights_are_buffer, + definition.precision, &op); + op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ; TensorLinearDescriptor desc; desc.storage_type = weights_are_buffer ? 
LinearStorageType::BUFFER : LinearStorageType::TEXTURE_2D; desc.element_type = definition.GetDataType(); desc.UploadLinearData(attr.bias); - result.args_.AddObject( + op.args_.AddObject( "biases", absl::make_unique(std::move(desc))); - return result; + return op; } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.h b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.h index afa6375eb83..3bb034849bc 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.h @@ -35,102 +35,9 @@ namespace tflite { namespace gpu { namespace cl { -class DepthwiseConvolution : public GPUOperation { - public: - DepthwiseConvolution() = default; - absl::Status BindArguments() override; - int3 GetGridSize() const override; - - // Move only - DepthwiseConvolution(DepthwiseConvolution&& operation); - DepthwiseConvolution& operator=(DepthwiseConvolution&& operation); - DepthwiseConvolution(const DepthwiseConvolution&) = delete; - DepthwiseConvolution& operator=(const DepthwiseConvolution&) = delete; - - private: - friend DepthwiseConvolution CreateDepthwiseConvolution( - const DeviceInfo& device_info, const OperationDef& definition, - const DepthwiseConvolution2DAttributes& attr); - friend DepthwiseConvolution CreateDepthwiseConvolution( - const DeviceInfo& device_info, const OperationDef& definition, - const DepthwiseConvolution3DAttributes& attr); - DepthwiseConvolution(const OperationDef& definition, - const DepthwiseConvolution2DAttributes& attr, - bool weights_are_buffer); - DepthwiseConvolution(const OperationDef& definition, - const DepthwiseConvolution3DAttributes& attr, - bool weights_are_buffer); - - template - void UploadWeights(const tflite::gpu::Tensor& weights); - - template - void RearrangeWeightsData(const tflite::gpu::Tensor& weights, - absl::Span dst); - - template - void UploadWeights(const tflite::gpu::Tensor& weights); - - template - void RearrangeWeightsData(const tflite::gpu::Tensor& weights, - absl::Span dst); - - std::string GenerateDepthwiseConvolutionCode(const OperationDef& op_def, - bool stride_correction, - int channel_multiplier, - bool weights_are_buffer); - - bool weights_are_buffer_; - - int4 kernel_size_; - int4 stride_; - int4 padding_; - int4 dilation_; - int channel_multiplier_; -}; - -template -void DepthwiseConvolution::UploadWeights( - const tflite::gpu::Tensor& weights) { - const int dst_channels = weights.shape.i * weights.shape.o; - const int dst_slices = DivideRoundUp(dst_channels, 4); - const int kernel_x = weights.shape.w; - const int kernel_y = weights.shape.h; - - const int elements_count = kernel_x * kernel_y * dst_slices; - - const bool fp32_weights = definition_.precision == CalculationsPrecision::F32; - const int float4_size = fp32_weights ? 16 : 8; - - std::vector data(float4_size * elements_count); - - if (fp32_weights) { - float4* ptr = reinterpret_cast(data.data()); - RearrangeWeightsData(weights, absl::MakeSpan(ptr, elements_count)); - } else { - half4* ptr = reinterpret_cast(data.data()); - RearrangeWeightsData(weights, absl::MakeSpan(ptr, elements_count)); - } - - if (weights_are_buffer_) { - BufferDescriptor desc; - desc.element_type = fp32_weights ? DataType::FLOAT32 : DataType::FLOAT16; - desc.element_size = 4; - desc.size = float4_size * elements_count; - desc.data = std::move(data); - args_.AddObject("weights", absl::make_unique(desc)); - } else { - Texture2DDescriptor desc; - desc.element_type = fp32_weights ? 
DataType::FLOAT32 : DataType::FLOAT16; - desc.size = int2(kernel_x * kernel_y, dst_slices); - desc.data = std::move(data); - args_.AddObject("weights", absl::make_unique(desc)); - } -} - template -void DepthwiseConvolution::RearrangeWeightsData( - const tflite::gpu::Tensor& weights, absl::Span dst) { +void RearrangeWeightsForDWConv2D(const tflite::gpu::Tensor& weights, + absl::Span dst) { const int dst_channels = weights.shape.i * weights.shape.o; const int dst_depth = DivideRoundUp(dst_channels, 4); const int kernel_x = weights.shape.w; @@ -158,50 +65,50 @@ void DepthwiseConvolution::RearrangeWeightsData( } template -void DepthwiseConvolution::UploadWeights( - const tflite::gpu::Tensor& weights) { +void UploadWeightsForDWConv2D(const tflite::gpu::Tensor& weights, + bool weights_are_buffer, + CalculationsPrecision precision, + GPUOperation* op) { const int dst_channels = weights.shape.i * weights.shape.o; const int dst_slices = DivideRoundUp(dst_channels, 4); const int kernel_x = weights.shape.w; const int kernel_y = weights.shape.h; - const int kernel_z = weights.shape.d; - const int elements_count = kernel_x * kernel_y * kernel_z * dst_slices; + const int elements_count = kernel_x * kernel_y * dst_slices; - const bool fp32_weights = definition_.precision == CalculationsPrecision::F32; + const bool fp32_weights = precision == CalculationsPrecision::F32; const int float4_size = fp32_weights ? 16 : 8; std::vector data(float4_size * elements_count); if (fp32_weights) { float4* ptr = reinterpret_cast(data.data()); - RearrangeWeightsData(weights, absl::MakeSpan(ptr, elements_count)); + RearrangeWeightsForDWConv2D(weights, absl::MakeSpan(ptr, elements_count)); } else { half4* ptr = reinterpret_cast(data.data()); - RearrangeWeightsData(weights, absl::MakeSpan(ptr, elements_count)); + RearrangeWeightsForDWConv2D(weights, absl::MakeSpan(ptr, elements_count)); } - if (weights_are_buffer_) { + if (weights_are_buffer) { BufferDescriptor desc; desc.element_type = fp32_weights ? DataType::FLOAT32 : DataType::FLOAT16; desc.element_size = 4; desc.size = float4_size * elements_count; desc.data = std::move(data); - args_.AddObject("weights", - absl::make_unique(std::move(desc))); + op->args_.AddObject("weights", absl::make_unique(desc)); } else { Texture2DDescriptor desc; desc.element_type = fp32_weights ? 
DataType::FLOAT32 : DataType::FLOAT16; - desc.size = int2(kernel_x * kernel_y * kernel_z, dst_slices); + desc.size = int2(kernel_x * kernel_y, dst_slices); desc.data = std::move(data); - args_.AddObject("weights", - absl::make_unique(std::move(desc))); + op->args_.AddObject("weights", + absl::make_unique(desc)); } } template -void DepthwiseConvolution::RearrangeWeightsData( - const tflite::gpu::Tensor& weights, absl::Span dst) { +void RearrangeWeightsForDWConv3D(const tflite::gpu::Tensor& weights, + absl::Span dst) { const int dst_channels = weights.shape.i * weights.shape.o; const int dst_slices = DivideRoundUp(dst_channels, 4); const int kernel_x = weights.shape.w; @@ -231,11 +138,59 @@ void DepthwiseConvolution::RearrangeWeightsData( } } -DepthwiseConvolution CreateDepthwiseConvolution( +template +void UploadWeightsForDWConv3D(const tflite::gpu::Tensor& weights, + bool weights_are_buffer, + CalculationsPrecision precision, + GPUOperation* op) { + const int dst_channels = weights.shape.i * weights.shape.o; + const int dst_slices = DivideRoundUp(dst_channels, 4); + const int kernel_x = weights.shape.w; + const int kernel_y = weights.shape.h; + const int kernel_z = weights.shape.d; + + const int elements_count = kernel_x * kernel_y * kernel_z * dst_slices; + + const bool fp32_weights = precision == CalculationsPrecision::F32; + const int float4_size = fp32_weights ? 16 : 8; + + std::vector data(float4_size * elements_count); + + if (fp32_weights) { + float4* ptr = reinterpret_cast(data.data()); + RearrangeWeightsForDWConv3D(weights, absl::MakeSpan(ptr, elements_count)); + } else { + half4* ptr = reinterpret_cast(data.data()); + RearrangeWeightsForDWConv3D(weights, absl::MakeSpan(ptr, elements_count)); + } + + if (weights_are_buffer) { + BufferDescriptor desc; + desc.element_type = fp32_weights ? DataType::FLOAT32 : DataType::FLOAT16; + desc.element_size = 4; + desc.size = float4_size * elements_count; + desc.data = std::move(data); + op->args_.AddObject("weights", + absl::make_unique(std::move(desc))); + } else { + Texture2DDescriptor desc; + desc.element_type = fp32_weights ? 
DataType::FLOAT32 : DataType::FLOAT16; + desc.size = int2(kernel_x * kernel_y * kernel_z, dst_slices); + desc.data = std::move(data); + op->args_.AddObject( + "weights", absl::make_unique(std::move(desc))); + } +} + +GPUOperation CreateDepthwiseConvolution2D( const DeviceInfo& device_info, const OperationDef& definition, const DepthwiseConvolution2DAttributes& attr); -DepthwiseConvolution CreateDepthwiseConvolution( +GPUOperation CreateDepthwiseConvolution2DDynamicWeights( + const DeviceInfo& device_info, const OperationDef& definition, + const DepthwiseConvolution2DAttributes& attr); + +GPUOperation CreateDepthwiseConvolution3D( const DeviceInfo& device_info, const OperationDef& definition, const DepthwiseConvolution3DAttributes& attr); diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_test.cc index 5c3e596a2e5..eb43c0c30e3 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_test.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_test.cc @@ -55,7 +55,7 @@ TEST_F(OpenCLOperationTest, DepthwiseConvSimpleWeights) { op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); TensorFloat32 dst_tensor; - DepthwiseConvolution operation = CreateDepthwiseConvolution( + GPUOperation operation = CreateDepthwiseConvolution2D( creation_context_.GetDeviceInfo(), op_def, attr); ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation, BHWC(1, 2, 2, 2), &dst_tensor)); @@ -90,7 +90,7 @@ TEST_F(OpenCLOperationTest, DepthwiseConvNoMultiplier) { op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); TensorFloat32 dst_tensor; - DepthwiseConvolution operation = CreateDepthwiseConvolution( + GPUOperation operation = CreateDepthwiseConvolution2D( creation_context_.GetDeviceInfo(), op_def, attr); ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation, BHWC(1, 2, 2, 2), &dst_tensor)); @@ -126,7 +126,7 @@ TEST_F(OpenCLOperationTest, DepthwiseConvMultiplier2) { op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); TensorFloat32 dst_tensor; - DepthwiseConvolution operation = CreateDepthwiseConvolution( + GPUOperation operation = CreateDepthwiseConvolution2D( creation_context_.GetDeviceInfo(), op_def, attr); ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation, BHWC(1, 2, 2, 4), &dst_tensor)); diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc index afec0ab8a56..f50045131c2 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc @@ -42,10 +42,10 @@ std::string GetOneInputCode(const OperationType& op_type, result = "\n"; break; case OperationType::ELU: - result = "$0.x = $0.x < (FLT)(0.0f) ? exp($0.x) - (FLT)(1.0f) : $0.x;\n"; - result += "$0.y = $0.y < (FLT)(0.0f) ? exp($0.y) - (FLT)(1.0f) : $0.y;\n"; - result += "$0.z = $0.z < (FLT)(0.0f) ? exp($0.z) - (FLT)(1.0f) : $0.z;\n"; - result += "$0.w = $0.w < (FLT)(0.0f) ? exp($0.w) - (FLT)(1.0f) : $0.w;\n"; + result = "$0.x = $0.x < (FLT)(0.0f) ? expm1($0.x) : $0.x;\n"; + result += "$0.y = $0.y < (FLT)(0.0f) ? expm1($0.y) : $0.y;\n"; + result += "$0.z = $0.z < (FLT)(0.0f) ? expm1($0.z) : $0.z;\n"; + result += "$0.w = $0.w < (FLT)(0.0f) ? 
expm1($0.w) : $0.w;\n"; break; case OperationType::EXP: result = "$0 = exp($0);\n"; @@ -58,23 +58,17 @@ std::string GetOneInputCode(const OperationType& op_type, case OperationType::LOG: result = "$0 = log($0);\n"; break; + case OperationType::NEG: + result = "$0 = -($0);\n"; + break; case OperationType::RSQRT: result = "$0 = rsqrt($0);\n"; break; case OperationType::SIGMOID: if (precision != CalculationsPrecision::F32) { result = - "$0.x = convert_half(native_recip(1.0f + " - "native_exp(convert_float(-$0.x))));\n"; - result += - "$0.y = convert_half(native_recip(1.0f + " - "native_exp(convert_float(-$0.y))));\n"; - result += - "$0.z = convert_half(native_recip(1.0f + " - "native_exp(convert_float(-$0.z))));\n"; - result += - "$0.w = convert_half(native_recip(1.0f + " - "native_exp(convert_float(-$0.w))));\n"; + "$0 = convert_half4(native_recip(1.0f + " + "native_exp(convert_float4(-$0))));\n"; } else { result = "$0 = (FLT4)(1.0f) / ((FLT4)(1.0f) + exp(-($0)));\n"; } @@ -89,7 +83,12 @@ std::string GetOneInputCode(const OperationType& op_type, result = "$0 *= $0;\n"; break; case OperationType::TANH: - result = "$0 = tanh($0);\n"; + if (precision != CalculationsPrecision::F32) { + result = "float4 t = native_exp(convert_float4($0 * 2.0h));\n"; + result += "$0 = convert_half4(native_divide(t - 1.0f, t + 1.0f));\n"; + } else { + result = "$0 = tanh($0);\n"; + } break; default: return "Unknown operation type;\n"; @@ -128,6 +127,43 @@ std::string GetTwoInputCode(const OperationType& op_type, case OperationType::SUB: result += "$0 = $1 - $2;\n"; break; + // Comparison operators + case OperationType::LESS: + result = "$0.x = $1.x < $2.x ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.y = $1.y < $2.y ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.z = $1.z < $2.z ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.w = $1.w < $2.w ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + break; + case OperationType::LESS_EQUAL: + result = "$0.x = $1.x <= $2.x ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.y = $1.y <= $2.y ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.z = $1.z <= $2.z ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.w = $1.w <= $2.w ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + break; + case OperationType::GREATER: + result = "$0.x = $1.x > $2.x ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.y = $1.y > $2.y ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.z = $1.z > $2.z ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.w = $1.w > $2.w ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + break; + case OperationType::GREATER_EQUAL: + result = "$0.x = $1.x >= $2.x ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.y = $1.y >= $2.y ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.z = $1.z >= $2.z ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.w = $1.w >= $2.w ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + break; + case OperationType::EQUAL: + result = "$0.x = $1.x == $2.x ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.y = $1.y == $2.y ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.z = $1.z == $2.z ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.w = $1.w == $2.w ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + break; + case OperationType::NOT_EQUAL: + result = "$0.x = $1.x != $2.x ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.y = $1.y != $2.y ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.z = $1.z != $2.z ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.w = $1.w != $2.w ? 
(FLT)(1.0f) : (FLT)(0.0f);\n"; + break; default: return "Unknown operation type;\n"; } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise_test.cc index d883a734214..b48f66ce600 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise_test.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise_test.cc @@ -208,6 +208,30 @@ TEST_F(OpenCLOperationTest, Log) { } } +TEST_F(OpenCLOperationTest, Neg) { + TensorFloat32 src_tensor; + src_tensor.shape = BHWC(1, 2, 1, 2); + src_tensor.data = {1.0f, -2.0f, 0.0f, 4.0f}; + + for (auto storage : env_.GetSupportedStorages()) { + for (auto precision : env_.GetSupportedPrecisions()) { + const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f; + OperationDef op_def; + op_def.precision = precision; + auto data_type = DeduceDataTypeFromPrecision(precision); + op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); + op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); + TensorFloat32 dst_tensor; + GPUOperation operation = + CreateElementwiseOneInput(op_def, OperationType::NEG); + ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation, + BHWC(1, 2, 1, 2), &dst_tensor)); + EXPECT_THAT(dst_tensor.data, + Pointwise(FloatNear(eps), {-1.0f, 2.0f, 0.0f, -4.0f})); + } + } +} + TEST_F(OpenCLOperationTest, Rsqrt) { TensorFloat32 src_tensor; src_tensor.shape = BHWC(1, 2, 1, 2); @@ -817,6 +841,174 @@ TEST_F(OpenCLOperationTest, SubWithScalarAtFirstPosition) { } } +TEST_F(OpenCLOperationTest, Less) { + TensorFloat32 src_tensor_0, src_tensor_1; + src_tensor_0.shape = BHWC(1, 2, 1, 2); + src_tensor_1.shape = BHWC(1, 2, 1, 2); + src_tensor_0.data = {0.0f, 1.0f, 2.0f, 3.0f}; + src_tensor_1.data = {1.0f, 0.0f, 2.0f, -4.0f}; + + for (auto storage : env_.GetSupportedStorages()) { + for (auto precision : env_.GetSupportedPrecisions()) { + const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f; + OperationDef op_def; + op_def.precision = precision; + auto data_type = DeduceDataTypeFromPrecision(precision); + op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); + op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); + op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); + TensorFloat32 dst_tensor; + GPUOperation operation = CreateElementwiseTwoInput( + op_def, OperationType::LESS, src_tensor_1.shape); + ASSERT_OK(ExecuteGPUOperation({src_tensor_0, src_tensor_1}, + creation_context_, &operation, + BHWC(1, 2, 1, 2), &dst_tensor)); + EXPECT_THAT(dst_tensor.data, + Pointwise(FloatNear(eps), {1.0f, 0.0f, 0.0f, 0.0f})); + } + } +} + +TEST_F(OpenCLOperationTest, LessEqual) { + TensorFloat32 src_tensor_0; + src_tensor_0.shape = BHWC(1, 2, 1, 2); + src_tensor_0.data = {0.0f, 1.0f, 2.0f, 3.0f}; + + ElementwiseAttributes attr; + attr.param = 2.0f; + + for (auto storage : env_.GetSupportedStorages()) { + for (auto precision : env_.GetSupportedPrecisions()) { + const float eps = precision == CalculationsPrecision::F32 ? 
1e-6f : 1e-2f; + OperationDef op_def; + op_def.precision = precision; + auto data_type = DeduceDataTypeFromPrecision(precision); + op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); + op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); + TensorFloat32 dst_tensor; + GPUOperation operation = + CreateElementwise(creation_context_.GetDeviceInfo(), op_def, + OperationType::LESS_EQUAL, attr); + ASSERT_OK(ExecuteGPUOperation(src_tensor_0, creation_context_, &operation, + BHWC(1, 2, 1, 2), &dst_tensor)); + EXPECT_THAT(dst_tensor.data, + Pointwise(FloatNear(eps), {1.0f, 1.0f, 1.0f, 0.0f})); + } + } +} + +TEST_F(OpenCLOperationTest, Greater) { + TensorFloat32 src_tensor_0; + src_tensor_0.shape = BHWC(1, 2, 1, 2); + src_tensor_0.data = {0.0f, 1.0f, 2.0f, 3.0f}; + + ElementwiseAttributes attr; + attr.param = 2.0f; + + for (auto storage : env_.GetSupportedStorages()) { + for (auto precision : env_.GetSupportedPrecisions()) { + const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f; + OperationDef op_def; + op_def.precision = precision; + auto data_type = DeduceDataTypeFromPrecision(precision); + op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); + op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); + TensorFloat32 dst_tensor; + GPUOperation operation = + CreateElementwise(creation_context_.GetDeviceInfo(), op_def, + OperationType::GREATER, attr); + ASSERT_OK(ExecuteGPUOperation(src_tensor_0, creation_context_, &operation, + BHWC(1, 2, 1, 2), &dst_tensor)); + EXPECT_THAT(dst_tensor.data, + Pointwise(FloatNear(eps), {0.0f, 0.0f, 0.0f, 1.0f})); + } + } +} + +TEST_F(OpenCLOperationTest, GreaterEqual) { + TensorFloat32 src_tensor_0; + src_tensor_0.shape = BHWC(1, 2, 1, 2); + src_tensor_0.data = {0.0f, 1.0f, 2.0f, 3.0f}; + + ElementwiseAttributes attr; + attr.param = 2.0f; + + for (auto storage : env_.GetSupportedStorages()) { + for (auto precision : env_.GetSupportedPrecisions()) { + const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f; + OperationDef op_def; + op_def.precision = precision; + auto data_type = DeduceDataTypeFromPrecision(precision); + op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); + op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); + TensorFloat32 dst_tensor; + GPUOperation operation = + CreateElementwise(creation_context_.GetDeviceInfo(), op_def, + OperationType::GREATER_EQUAL, attr); + ASSERT_OK(ExecuteGPUOperation(src_tensor_0, creation_context_, &operation, + BHWC(1, 2, 1, 2), &dst_tensor)); + EXPECT_THAT(dst_tensor.data, + Pointwise(FloatNear(eps), {0.0f, 0.0f, 1.0f, 1.0f})); + } + } +} + +TEST_F(OpenCLOperationTest, Equal) { + TensorFloat32 src_tensor_0; + src_tensor_0.shape = BHWC(1, 2, 1, 2); + src_tensor_0.data = {0.0f, 1.0f, 2.0f, 3.0f}; + + ElementwiseAttributes attr; + attr.param = 2.0f; + + for (auto storage : env_.GetSupportedStorages()) { + for (auto precision : env_.GetSupportedPrecisions()) { + const float eps = precision == CalculationsPrecision::F32 ? 
1e-6f : 1e-2f; + OperationDef op_def; + op_def.precision = precision; + auto data_type = DeduceDataTypeFromPrecision(precision); + op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); + op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); + TensorFloat32 dst_tensor; + GPUOperation operation = + CreateElementwise(creation_context_.GetDeviceInfo(), op_def, + OperationType::EQUAL, attr); + ASSERT_OK(ExecuteGPUOperation(src_tensor_0, creation_context_, &operation, + BHWC(1, 2, 1, 2), &dst_tensor)); + EXPECT_THAT(dst_tensor.data, + Pointwise(FloatNear(eps), {0.0f, 0.0f, 1.0f, 0.0f})); + } + } +} + +TEST_F(OpenCLOperationTest, NotEqual) { + TensorFloat32 src_tensor_0; + src_tensor_0.shape = BHWC(1, 2, 1, 2); + src_tensor_0.data = {0.0f, 1.0f, 2.0f, 3.0f}; + + ElementwiseAttributes attr; + attr.param = 2.0f; + + for (auto storage : env_.GetSupportedStorages()) { + for (auto precision : env_.GetSupportedPrecisions()) { + const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f; + OperationDef op_def; + op_def.precision = precision; + auto data_type = DeduceDataTypeFromPrecision(precision); + op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); + op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); + TensorFloat32 dst_tensor; + GPUOperation operation = + CreateElementwise(creation_context_.GetDeviceInfo(), op_def, + OperationType::NOT_EQUAL, attr); + ASSERT_OK(ExecuteGPUOperation(src_tensor_0, creation_context_, &operation, + BHWC(1, 2, 1, 2), &dst_tensor)); + EXPECT_THAT(dst_tensor.data, + Pointwise(FloatNear(eps), {1.0f, 1.0f, 0.0f, 1.0f})); + } + } +} + } // namespace } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.cc b/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.cc index 999344384aa..1940a1a020c 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.cc @@ -17,28 +17,50 @@ limitations under the License. 
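// --- Illustrative sketch (not part of the patch) ---------------------------
// The element-wise delegate changes above (expm1-based ELU, comparison ops
// that emit 1.0f / 0.0f per lane) can be summarized by a minimal host-side
// reference of the scalar semantics. The function names below are
// hypothetical and exist only to document intent.
#include <cmath>

static float EluRef(float x) {
  // expm1(x) == exp(x) - 1, but without the cancellation error of the naive
  // form for small |x|, which is why the generated OpenCL now uses it.
  return x < 0.0f ? std::expm1(x) : x;
}

static float LessRef(float a, float b) {
  // Comparisons produce float 1.0/0.0 so the result can flow through the
  // same FLT4 pipeline as every other element-wise operation.
  return a < b ? 1.0f : 0.0f;
}
// ---------------------------------------------------------------------------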
#include #include +#include +#include "absl/memory/memory.h" +#include "tensorflow/lite/delegates/gpu/cl/arguments.h" +#include "tensorflow/lite/delegates/gpu/cl/device_info.h" +#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" +#include "tensorflow/lite/delegates/gpu/cl/linear_storage.h" +#include "tensorflow/lite/delegates/gpu/cl/precision.h" +#include "tensorflow/lite/delegates/gpu/cl/tensor.h" +#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" namespace tflite { namespace gpu { namespace cl { +namespace { +bool UseBufferForWeights(const DeviceInfo& device_info) { + return device_info.IsAdreno() || device_info.IsAMD() || device_info.IsMali(); +} +} // namespace FullyConnected::FullyConnected(const OperationDef& definition, const DeviceInfo& device_info) : GPUOperation(definition) { if (device_info.IsAdreno()) { if (device_info.IsAdreno3xx()) { - work_group_size_ = int3(8, 4, 1); - } else if (device_info.IsAdreno4xx()) { work_group_size_ = int3(16, 4, 1); + } else if (device_info.IsAdreno4xx()) { + work_group_size_ = int3(32, 4, 1); } else { work_group_size_ = int3(32, 4, 1); } + } else if (device_info.IsIntel()) { + work_group_size_ = int3(8, 4, 1); + } else if (device_info.IsNvidia()) { + work_group_size_ = int3(8, 4, 1); + } else if (device_info.IsPowerVR()) { + work_group_size_ = int3(8, 4, 1); } else { work_group_size_ = int3(16, 4, 1); } - code_ = GetFullyConnectedKernelCode(definition_, work_group_size_); + code_ = GetFullyConnectedKernelCode(definition_, device_info); } FullyConnected::FullyConnected(FullyConnected&& kernel) @@ -58,10 +80,12 @@ FullyConnected& FullyConnected::operator=(FullyConnected&& kernel) { // optimized shaders std::string FullyConnected::GetFullyConnectedKernelCode( - const OperationDef& op_def, const int3& work_group_size) { + const OperationDef& op_def, const DeviceInfo& device_info) { AddSrcTensor("src_tensor", op_def.src_tensors[0]); AddDstTensor("dst_tensor", op_def.dst_tensors[0]); + const bool weights_are_buffer = UseBufferForWeights(device_info); + std::string c = GetCommonDefines(op_def.precision); switch (op_def.precision) { case CalculationsPrecision::F32: @@ -73,35 +97,54 @@ std::string FullyConnected::GetFullyConnectedKernelCode( break; } - const std::string wg_x = std::to_string(work_group_size.x); - const std::string wg_y = std::to_string(work_group_size.y); - c += "__kernel void main_function(\n"; - c += "$0) {\n"; - c += " int gid = get_global_id(0);\n"; - c += " bool inside = gid < args.dst_tensor.Slices();\n"; - c += " gid = min(gid, args.dst_tensor.Slices() - 1);\n"; - c += " int2 tid = (int2)(get_local_id(0), get_local_id(1));\n"; - c += " ACCUM_FLT4 s = (ACCUM_FLT4)(0.0f);\n"; - c += " for (uint c = tid.y; c < args.src_tensor.Slices(); c += " + wg_y + - ") {\n"; - c += " FLT4 v = args.src_tensor.Read(0, 0, c);\n"; - c += " FLT16 w = args.weights.Read(c * args.dst_tensor.Slices() + gid);\n"; - c += " s.x += dot(v, w.s0123);\n"; - c += " s.y += dot(v, w.s4567);\n"; - c += " s.z += dot(v, w.s89ab);\n"; - c += " s.w += dot(v, w.scdef);\n"; - c += " }\n"; - c += " __local ACCUM_FLT4 temp[" + wg_x + "][" + wg_y + "];\n"; - c += " temp[tid.x][tid.y] = s;\n"; - c += " barrier(CLK_LOCAL_MEM_FENCE);\n"; - c += " if (tid.y == 0 && inside) {\n"; - for (int i = 1; i < work_group_size.y; ++i) { + c += "#define WG_X " + 
std::to_string(work_group_size_.x) + "\n"; + c += "#define WG_Y " + std::to_string(work_group_size_.y) + "\n"; + + c += R"(__kernel void main_function($0) { + int gid = get_global_id(0); + int2 tid = (int2)(get_local_id(0), get_local_id(1)); + ACCUM_FLT4 s = (ACCUM_FLT4)(0.0f); + if (gid < args.dst_tensor.Slices()) { + for (int c = tid.y; c < args.src_tensor.Slices(); c += WG_Y) { + FLT4 v = args.src_tensor.Read(0, 0, c); +)"; + if (weights_are_buffer) { + c += R"(FLT16 w = args.weights.Read(c * args.dst_tensor.Slices() + gid); + FLT4 partial = v.s0 * w.s0123; + partial = mad(v.s1, w.s4567, partial); + partial = mad(v.s2, w.s89ab, partial); + partial = mad(v.s3, w.scdef, partial); + s += TO_ACCUM_TYPE(partial); +)"; + } else { + c += R"(FLT4 w0 = args.weights.Read(c * 4 + 0, gid); + FLT4 w1 = args.weights.Read(c * 4 + 1, gid); + FLT4 w2 = args.weights.Read(c * 4 + 2, gid); + FLT4 w3 = args.weights.Read(c * 4 + 3, gid); + FLT4 partial = v.s0 * w0; + partial = mad(v.s1, w1, partial); + partial = mad(v.s2, w2, partial); + partial = mad(v.s3, w3, partial); + s += TO_ACCUM_TYPE(partial); +)"; + } + c += R"( } + } + __local ACCUM_FLT4 temp[WG_X][WG_Y]; + temp[tid.x][tid.y] = s; + barrier(CLK_LOCAL_MEM_FENCE); + if (gid >= args.dst_tensor.Slices()) { + return; + } + if (tid.y == 0) { +)"; + for (int i = 1; i < work_group_size_.y; ++i) { c += " s += temp[tid.x][" + std::to_string(i) + "];\n"; } - c += " FLT4 r0 = TO_FLT4(s) + args.biases.Read(gid);\n"; - c += " args.dst_tensor.Write(r0, 0, 0, gid);\n"; - c += " }\n"; - c += "}\n"; + c += R"( FLT4 r0 = TO_FLT4(s) + args.biases.Read(gid); + args.dst_tensor.Write(r0, 0, 0, gid); + } +})"; return c; } @@ -114,7 +157,7 @@ FullyConnected CreateFullyConnected(const DeviceInfo& device_info, const OperationDef& definition, const FullyConnectedAttributes& attr) { FullyConnected result(definition, device_info); - result.UploadWeights(attr.weights); + result.UploadWeights(attr.weights, UseBufferForWeights(device_info)); TensorLinearDescriptor desc; desc.storage_type = LinearStorageType::TEXTURE_2D; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.h b/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.h index f1fc7dc199f..ec572b24fb5 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.h @@ -16,19 +16,27 @@ limitations under the License. 
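// --- Illustrative sketch (not part of the patch) ---------------------------
// CPU model of the reduction used by the rewritten fully connected kernel
// above: each of the WG_Y lanes accumulates a partial sum over the source
// slices it owns (c = tid.y, tid.y + WG_Y, ...), and lane 0 then adds the
// partials, mirroring the __local temp[WG_X][WG_Y] buffer plus barrier.
// All names below are hypothetical.
#include <vector>

static float ReduceLikeFullyConnected(const std::vector<float>& slice_products,
                                      int wg_y) {
  std::vector<float> partial(wg_y, 0.0f);
  for (int c = 0; c < static_cast<int>(slice_products.size()); ++c) {
    partial[c % wg_y] += slice_products[c];  // work split across WG_Y lanes
  }
  float s = partial[0];                      // lane tid.y == 0 finishes the sum
  for (int i = 1; i < wg_y; ++i) {
    s += partial[i];
  }
  return s;
}
// ---------------------------------------------------------------------------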
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_FULLY_CONNECTED_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_FULLY_CONNECTED_H_
+#include
+
+#include
+#include
#include
+#include "absl/memory/memory.h"
+#include "tensorflow/lite/delegates/gpu/cl/arguments.h"
#include "tensorflow/lite/delegates/gpu/cl/buffer.h"
+#include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
+#include "tensorflow/lite/delegates/gpu/cl/device_info.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
-#include "tensorflow/lite/delegates/gpu/cl/linear_storage.h"
-#include "tensorflow/lite/delegates/gpu/cl/tensor.h"
-#include "tensorflow/lite/delegates/gpu/cl/util.h"
+#include "tensorflow/lite/delegates/gpu/cl/kernels/tuning_parameters.h"
+#include "tensorflow/lite/delegates/gpu/cl/precision.h"
+#include "tensorflow/lite/delegates/gpu/cl/texture2d.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
-#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
+#include "tensorflow/lite/delegates/gpu/common/util.h"
namespace tflite {
namespace gpu {
@@ -36,52 +44,77 @@ namespace cl {
template
void RearrangeFCWeightsToIOO4I4(const tflite::gpu::Tensor& weights,
- absl::Span dst) {
+ S* dst) {
const int src_channels = weights.shape.i;
const int padded_src_channels = AlignByN(src_channels, 4);
const int dst_channels = weights.shape.o;
const int padded_dst_channels = AlignByN(dst_channels, 4);
- // The weights are to be rearranged in such a way that the first 4 elements of
- // each row, starting from row_0, are copied onto the destination buffer. The
- // next set of 4 elements are then copied and so on. As an example, an 8x8
- // matrix would be rearranged as below.
+ // Change the traversal order of the weight matrix in the following way:
+ // The matrix is segmented into blocks of 4x4. If (any) dimension of the matrix
+ // size is not divisible by 4, then pad with zeros. Each block is stored
+ // contiguously. The 16 elements within a block are ordered as 4 elements of
+ // the first column, 4 elements of the second, etc. Blocks are then traversed
+ // columns first, rows last. As an example, an 8x8 matrix would be traversed
+ // as below.
//
- // | a0 a1 a2 a3 a4 a5 a6 a7 | | a0 a1 a2 a3 b0 b1 b2 b3 |
- // | b0 b1 b2 b3 b4 b5 b6 b7 | | c0 c1 c2 c3 d0 d1 d2 d3 |
- // | c0 c1 c2 c3 c4 c5 c6 c7 | | e0 e1 e2 e3 f0 f1 f2 f3 |
- // | d0 d1 d2 d3 d4 d5 d6 d7 | ---------> | g0 g1 g2 g3 h0 h1 h2 h3 |
- // | e0 e1 e2 e3 e4 e5 e6 e7 | | a4 a5 a6 a7 b4 b5 b6 b7 |
- // | f0 f1 f2 f3 f4 f5 f6 f7 | | c4 c5 c6 c7 d4 d5 d6 d7 |
- // | g0 g1 g2 g3 g4 g5 g6 g7 | | e4 e5 e6 e7 f4 f5 f6 f7 |
- // | h0 h1 h2 h3 h4 h5 h6 h7 | | g4 g5 g6 g7 h4 h5 h6 h7 |
+ // | 0 4 8 12 32 36 40 44 |
+ // | 1 5 9 13 33 37 41 45 |
+ // | 2 6 10 14 34 38 42 46 |
+ // | 3 7 11 15 35 39 43 47 |
+ // | 16 20 24 28 48 52 56 60 |
+ // | 17 21 25 29 49 53 57 61 |
+ // | 18 22 26 30 50 54 58 62 |
+ // | 19 23 27 31 51 55 59 63 |
+ //
+ // The benefit of doing this is that reading 16 contiguous elements gives a 4x4
+ // block of the matrix, where the first 4 elements are the first column of the
+ // block, the second 4 elements are the second column of the block, etc.
+ // Subsequent blocks contain elements of the same 4 columns.
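// --- Illustrative sketch (not part of the patch) ---------------------------
// Standalone helper showing the destination index implied by the block layout
// described in the comment above; it mirrors the dst_index computation in the
// loop that follows. The helper name is hypothetical and the function is only
// documentation.
static int BlockLayoutIndex(int y, int x, int padded_dst_channels) {
  const int block_x = x / 4, x_in_block = x % 4;  // x = input channel
  const int block_y = y / 4, y_in_block = y % 4;  // y = output channel
  return block_x * padded_dst_channels * 4        // skip whole block-columns
         + block_y * 16                           // skip blocks above, 16 elements each
         + x_in_block * 4 + y_in_block;           // column-major inside the block
}
// For the 8x8 example above (padded_dst_channels == 8), BlockLayoutIndex(1, 5, 8)
// evaluates to 32 + 0 + 4 + 1 = 37, matching row 1, column 5 of the diagram.
// ---------------------------------------------------------------------------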
- for (int y = 0; y < dst_channels; y++) { - int x = 0; - for (; x + 4 <= src_channels; x += 4) { - const int idx_data_0 = src_channels * y + x; - S filter = S(weights.data[idx_data_0], weights.data[idx_data_0 + 1], - weights.data[idx_data_0 + 2], weights.data[idx_data_0 + 3]); - dst[y + padded_dst_channels * x / 4] = filter; - } - - // If the width is not a multiple of 4, padding is required and the padded - // region is filled with zeros. - if (src_channels != padded_src_channels) { - const int idx_data_0 = src_channels * y + x; - - S filter = S(x < src_channels ? weights.data[idx_data_0] : 0.0, - x + 1 < src_channels ? weights.data[idx_data_0 + 1] : 0.0, - x + 2 < src_channels ? weights.data[idx_data_0 + 2] : 0.0, - x + 3 < src_channels ? weights.data[idx_data_0 + 3] : 0.0); - dst[y + padded_dst_channels * x / 4] = filter; + for (int block_y = 0; 4 * block_y < padded_dst_channels; block_y++) { + for (int y_in_block = 0; y_in_block < 4; y_in_block++) { + for (int block_x = 0; 4 * block_x < padded_src_channels; block_x++) { + for (int x_in_block = 0; x_in_block < 4; x_in_block++) { + int y = 4 * block_y + y_in_block; + int x = 4 * block_x + x_in_block; + // Consider destination as an array with extents + // [padded_src_channels/4][padded_dst_channels/4][4][4] + int dst_index = block_x * padded_dst_channels * 4 + block_y * 16 + + x_in_block * 4 + y_in_block; + if (x < src_channels && y < dst_channels) { + dst[dst_index] = weights.data[src_channels * y + x]; + } else { + dst[dst_index] = 0.0f; + } + } + } } } +} - // Fill the padded columns with zeros. - for (int y = dst_channels; y < padded_dst_channels; y++) { - for (int x = 0; x < padded_src_channels; x += 4) { - dst[y + padded_dst_channels * x / 4] = S(0.0); +template +void RearrangeFCWeightsToOIO4I4(const tflite::gpu::Tensor& weights, + S* dst) { + const int src_channels = weights.shape.i; + const int src_depth = DivideRoundUp(src_channels, 4); + const int dst_channels = weights.shape.o; + const int dst_depth = DivideRoundUp(dst_channels, 4); + + int counter = 0; + for (int d = 0; d < dst_depth; ++d) { + for (int s = 0; s < src_depth; ++s) { + for (int i = 0; i < 4; ++i) { + const int src_ch = s * 4 + i; + for (int j = 0; j < 4; ++j) { + const int dst_ch = d * 4 + j; + if (src_ch < src_channels && dst_ch < dst_channels) { + dst[counter++] = weights.data[dst_ch * src_channels + src_ch]; + } else { + dst[counter++] = 0.0f; + } + } + } } } } @@ -110,15 +143,16 @@ class FullyConnected : public GPUOperation { const FullyConnectedAttributes& attr); template - void UploadWeights(const tflite::gpu::Tensor& weights); + void UploadWeights(const tflite::gpu::Tensor& weights, + bool weights_are_buffer); std::string GetFullyConnectedKernelCode(const OperationDef& op_def, - const int3& work_group_size); + const DeviceInfo& device_info); }; template -void FullyConnected::UploadWeights( - const tflite::gpu::Tensor& weights) { +void FullyConnected::UploadWeights(const tflite::gpu::Tensor& weights, + bool weights_are_buffer) { const int src_depth = DivideRoundUp(weights.shape.i, 4); const int dst_depth = DivideRoundUp(weights.shape.o, 4); @@ -127,22 +161,40 @@ void FullyConnected::UploadWeights( const int float4_size = f32_weights ? 16 : 8; - BufferDescriptor desc; - desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16; - desc.element_size = 16; - desc.size = float4_size * elements_count; - desc.data.resize(desc.size); + if (weights_are_buffer) { + BufferDescriptor desc; + desc.element_type = f32_weights ? 
DataType::FLOAT32 : DataType::FLOAT16; + desc.element_size = 16; + desc.size = float4_size * elements_count; + desc.data.resize(desc.size); - if (f32_weights) { - float4* ptr = reinterpret_cast(desc.data.data()); - RearrangeFCWeightsToIOO4I4(weights, absl::MakeSpan(ptr, elements_count)); + if (f32_weights) { + float* ptr = reinterpret_cast(desc.data.data()); + RearrangeFCWeightsToIOO4I4(weights, ptr); + } else { + half* ptr = reinterpret_cast(desc.data.data()); + RearrangeFCWeightsToIOO4I4(weights, ptr); + } + + args_.AddObject("weights", + absl::make_unique(std::move(desc))); } else { - half4* ptr = reinterpret_cast(desc.data.data()); - RearrangeFCWeightsToIOO4I4(weights, absl::MakeSpan(ptr, elements_count)); - } + Texture2DDescriptor desc; + desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16; + desc.size = int2(src_depth * 4, dst_depth); + desc.data.resize(float4_size * elements_count); - args_.AddObject("weights", - absl::make_unique(std::move(desc))); + if (f32_weights) { + float* ptr = reinterpret_cast(desc.data.data()); + RearrangeFCWeightsToOIO4I4(weights, ptr); + } else { + half* ptr = reinterpret_cast(desc.data.data()); + RearrangeFCWeightsToOIO4I4(weights, ptr); + } + + args_.AddObject("weights", + absl::make_unique(std::move(desc))); + } } FullyConnected CreateFullyConnected(const DeviceInfo& device_info, diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected_test.cc index f58487c1941..c9853187b3c 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected_test.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected_test.cc @@ -19,9 +19,15 @@ limitations under the License. #include #include +#include "tensorflow/lite/delegates/gpu/cl/environment.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/cl_test.h" +#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" +#include "tensorflow/lite/delegates/gpu/cl/precision.h" +#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" -#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" using ::testing::ElementsAreArray; using ::testing::FloatNear; @@ -39,7 +45,8 @@ TEST_F(OpenCLOperationTest, FullyConnected) { FullyConnectedAttributes attr; attr.weights.shape = OHWI(2, 1, 1, 4); - attr.weights.data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; + attr.weights.data = {0.0f, 1.0f, 2.0f, 3.0f, // + 4.0f, 5.0f, 6.0f, 7.0f}; attr.bias.shape = Linear(2); attr.bias.data = {0.5f, -0.5f}; @@ -56,7 +63,101 @@ TEST_F(OpenCLOperationTest, FullyConnected) { CreateFullyConnected(creation_context_.GetDeviceInfo(), op_def, attr); ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation, BHWC(1, 1, 1, 2), &dst_tensor)); - EXPECT_THAT(dst_tensor.data, Pointwise(FloatNear(eps), {14.5f, 37.5f})); + EXPECT_THAT(dst_tensor.data, Pointwise(FloatNear(eps), {14.5f, 37.5f})) + << "Failed using precision " << ToString(precision); + } + } +} + +TEST_F(OpenCLOperationTest, FullyConnectedLarge) { + TensorFloat32 src_tensor; + src_tensor.shape = BHWC(1, 1, 1, 8); + src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; + + FullyConnectedAttributes attr; + attr.weights.shape = OHWI(12, 1, 1, 8); + attr.weights.data = { + 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 
7.0f, // + 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, // + 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, // + 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, // + 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, // + 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, // + 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, // + 56.0f, 57.0f, 58.0f, 59.0f, 60.0f, 61.0f, 62.0f, 63.0f, // + 64.0f, 65.0f, 66.0f, 67.0f, 68.0f, 69.0f, 70.0f, 71.0f, // + 72.0f, 73.0f, 74.0f, 75.0f, 76.0f, 77.0f, 78.0f, 79.0f, // + 80.0f, 81.0f, 82.0f, 83.0f, 84.0f, 85.0f, 86.0f, 87.0f, // + 88.0f, 89.0f, 90.0f, 91.0f, 92.0f, 93.0f, 94.0f, 95.0f, // + }; + attr.bias.shape = Linear(12); + attr.bias.data = {-0.6f, -0.5f, -0.4f, -0.3f, -0.2f, -0.1f, + 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f}; + + for (auto storage : env_.GetSupportedStorages()) { + for (auto precision : env_.GetSupportedPrecisions()) { + const float eps = precision == CalculationsPrecision::F32 ? 0.0f : 0.601f; + OperationDef op_def; + op_def.precision = precision; + auto data_type = DeduceDataTypeFromPrecision(precision); + op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); + op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); + TensorFloat32 dst_tensor; + FullyConnected operation = + CreateFullyConnected(creation_context_.GetDeviceInfo(), op_def, attr); + ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation, + BHWC(1, 1, 1, 12), &dst_tensor)); + EXPECT_THAT( + dst_tensor.data, + Pointwise(FloatNear(eps), + {139.4f, 363.5f, 587.6f, 811.7f, 1035.8f, 1259.9f, 1484.1f, + 1708.2f, 1932.3f, 2156.4f, 2380.5f, 2604.6f})) + << "Failed using precision " << ToString(precision); + } + } +} + +TEST_F(OpenCLOperationTest, FullyConnectedExtraLarge) { + static const int kInputSize = 1024; + static const int kOutputSize = 1024; + TensorFloat32 src_tensor; + src_tensor.shape = BHWC(1, 1, 1, kInputSize); + src_tensor.data.assign(kInputSize, 1.1f); + + FullyConnectedAttributes attr; + attr.weights.shape = OHWI(1024, 1, 1, kInputSize); + attr.weights.data.assign(kOutputSize * kInputSize, 2.2f); + attr.bias.shape = Linear(kOutputSize); + attr.bias.data.assign(kOutputSize, 3.3f); + + std::vector expected(kOutputSize, 2481.38f); + + for (auto storage : env_.GetSupportedStorages()) { + for (auto precision : env_.GetSupportedPrecisions()) { + float eps; + switch (precision) { + case CalculationsPrecision::F32: + eps = 2.45e-3f; + break; + case CalculationsPrecision::F32_F16: + eps = 1.38f; + break; + case CalculationsPrecision::F16: + eps = 38.7f; + break; + } + OperationDef op_def; + op_def.precision = precision; + auto data_type = DeduceDataTypeFromPrecision(precision); + op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); + op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); + TensorFloat32 dst_tensor; + FullyConnected operation = + CreateFullyConnected(creation_context_.GetDeviceInfo(), op_def, attr); + ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation, + BHWC(1, 1, 1, kOutputSize), &dst_tensor)); + EXPECT_THAT(dst_tensor.data, Pointwise(FloatNear(eps), expected)) + << "Failed using precision " << ToString(precision); } } } @@ -64,53 +165,74 @@ TEST_F(OpenCLOperationTest, FullyConnected) { TEST_F(OpenCLOperationTest, RearrageWeights) { tflite::gpu::Tensor weights; weights.shape = OHWI(8, 1, 1, 8); - weights.data = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 10.0, 11.0, - 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 20.0, 21.0, 22.0, 23.0, - 24.0, 
25.0, 26.0, 27.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, - 36.0, 37.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, - 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 60.0, 61.0, - 62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 70.0, 71.0, 72.0, 73.0, - 74.0, 75.0, 76.0, 77.0}; - - std::vector expected_rearranged_data = { - 0.0, 1.0, 2.0, 3.0, 10.0, 11.0, 12.0, 13.0, 20.0, 21.0, 22.0, - 23.0, 30.0, 31.0, 32.0, 33.0, 40.0, 41.0, 42.0, 43.0, 50.0, 51.0, - 52.0, 53.0, 60.0, 61.0, 62.0, 63.0, 70.0, 71.0, 72.0, 73.0, 4.0, - 5.0, 6.0, 7.0, 14.0, 15.0, 16.0, 17.0, 24.0, 25.0, 26.0, 27.0, - 34.0, 35.0, 36.0, 37.0, 44.0, 45.0, 46.0, 47.0, 54.0, 55.0, 56.0, - 57.0, 64.0, 65.0, 66.0, 67.0, 74.0, 75.0, 76.0, 77.0, + weights.data = { + 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, // + 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, // + 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, // + 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, // + 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, // + 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, // + 60.0f, 61.0f, 62.0f, 63.0f, 64.0f, 65.0f, 66.0f, 67.0f, // + 70.0f, 71.0f, 72.0f, 73.0f, 74.0f, 75.0f, 76.0f, 77.0f // }; + std::vector expected_rearranged_data = { + // Top-left block + 0.0f, 10.0f, 20.0f, 30.0f, 1.0f, 11.0f, 21.0f, 31.0f, 2.0f, 12.0f, 22.0f, + 32.0f, 3.0f, 13.0f, 23.0f, 33.0f, + // Bottom-left block + 40.0f, 50.0f, 60.0f, 70.0f, 41.0f, 51.0f, 61.0f, 71.0f, 42.0f, 52.0f, + 62.0f, 72.0f, 43.0f, 53.0f, 63.0f, 73.0f, + // Top-right block + 4.0f, 14.0f, 24.0f, 34.0f, 5.0f, 15.0f, 25.0f, 35.0f, 6.0f, 16.0f, 26.0f, + 36.0f, 7.0f, 17.0f, 27.0f, 37.0f, + // Bottom-right block + 44.0f, 54.0f, 64.0f, 74.0f, 45.0f, 55.0f, 65.0f, 75.0f, 46.0f, 56.0f, + 66.0f, 76.0f, 47.0f, 57.0f, 67.0f, 77.0f}; + std::vector data(8 * 8); - float4* data_ptr = static_cast(static_cast(data.data())); - RearrangeFCWeightsToIOO4I4(weights, absl::MakeSpan(data_ptr, 8 * 8 / 4)); + RearrangeFCWeightsToIOO4I4(weights, data.data()); EXPECT_THAT(data, ElementsAreArray(expected_rearranged_data)); } TEST_F(OpenCLOperationTest, RearrageWeightsWhenPaddingIsRequired) { tflite::gpu::Tensor weights; - weights.shape = OHWI(7, 1, 1, 7); + weights.shape = OHWI(9, 1, 1, 7); weights.data = { - 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 10.0, 11.0, 12.0, - 13.0, 14.0, 15.0, 16.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, - 26.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 40.0, 41.0, - 42.0, 43.0, 44.0, 45.0, 46.0, 50.0, 51.0, 52.0, 53.0, 54.0, - 55.0, 56.0, 60.0, 61.0, 62.0, 63.0, 64.0, 65.0, 66.0, + 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, // + 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, // + 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, // + 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, // + 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, // + 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, // + 60.0f, 61.0f, 62.0f, 63.0f, 64.0f, 65.0f, 66.0f, // + 70.0f, 71.0f, 72.0f, 73.0f, 74.0f, 75.0f, 76.0f, // + 80.0f, 81.0f, 82.0f, 83.0f, 84.0f, 85.0f, 86.0f, // }; std::vector expected_rearranged_data = { - 0.0, 1.0, 2.0, 3.0, 10.0, 11.0, 12.0, 13.0, 20.0, 21.0, 22.0, - 23.0, 30.0, 31.0, 32.0, 33.0, 40.0, 41.0, 42.0, 43.0, 50.0, 51.0, - 52.0, 53.0, 60.0, 61.0, 62.0, 63.0, 0.0, 0.0, 0.0, 0.0, 4.0, - 5.0, 6.0, 0.0, 14.0, 15.0, 16.0, 0.0, 24.0, 25.0, 26.0, 0.0, - 34.0, 35.0, 36.0, 0.0, 44.0, 45.0, 46.0, 0.0, 54.0, 55.0, 56.0, - 0.0, 64.0, 65.0, 66.0, 0.0, 0.0, 0.0, 0.0, 0.0, - }; + // Top-left block + 0.0f, 10.0f, 20.0f, 30.0f, 1.0f, 11.0f, 21.0f, 31.0f, 2.0f, 12.0f, 22.0f, + 
32.0f, 3.0f, 13.0f, 23.0f, 33.0f, + // Mid-left block + 40.0f, 50.0f, 60.0f, 70.0f, 41.0f, 51.0f, 61.0f, 71.0f, 42.0f, 52.0f, + 62.0f, 72.0f, 43.0f, 53.0f, 63.0f, 73.0f, + // Bottom-left block + 80.0f, 0.0f, 0.0f, 0.0f, 81.0f, 0.0f, 0.0f, 0.0f, 82.0f, 0.0f, 0.0f, 0.0f, + 83.0f, 0.0f, 0.0f, 0.0f, + // Top-right block + 4.0f, 14.0f, 24.0f, 34.0f, 5.0f, 15.0f, 25.0f, 35.0f, 6.0f, 16.0f, 26.0f, + 36.0f, 0.0f, 0.0f, 0.0f, 0.0f, + // Mid-left block + 44.0f, 54.0f, 64.0f, 74.0f, 45.0f, 55.0f, 65.0f, 75.0f, 46.0f, 56.0f, + 66.0f, 76.0f, 0.0f, 0.0f, 0.0f, 0.0f, + // Bottom-right block + 84.0f, 0.0f, 0.0f, 0.0f, 85.0f, 0.0f, 0.0f, 0.0f, 86.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f}; - std::vector data(8 * 8); - float4* data_ptr = static_cast(static_cast(data.data())); - RearrangeFCWeightsToIOO4I4(weights, absl::MakeSpan(data_ptr, 8 * 8 / 4)); + std::vector data(12 * 8); + RearrangeFCWeightsToIOO4I4(weights, data.data()); EXPECT_THAT(data, ElementsAreArray(expected_rearranged_data)); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc b/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc index f9d6ec762ec..b39f03af846 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc @@ -49,6 +49,33 @@ std::string GetElementWiseCode(const OperationDef& op_def, return c; } +int3 GetWorkGroupsCount(int grid_dimension, const int3& grid_size, + const int3& work_group_size, + const int3& work_group_launch_order) { + int3 work_groups_count; + if (grid_dimension == 1) { + work_groups_count.x = DivideRoundUp(grid_size.x, work_group_size.x); + work_groups_count.y = 1; + work_groups_count.z = 1; + } else if (grid_dimension == 2) { + int3 wgs; + wgs.x = DivideRoundUp(grid_size.x, work_group_size.x); + wgs.y = DivideRoundUp(grid_size.y, work_group_size.y); + work_groups_count.x = wgs[work_group_launch_order[0]]; + work_groups_count.y = wgs[work_group_launch_order[1]]; + work_groups_count.z = 1; + } else { // grid_dimension == 3 + int3 wgs; + wgs.x = DivideRoundUp(grid_size.x, work_group_size.x); + wgs.y = DivideRoundUp(grid_size.y, work_group_size.y); + wgs.z = DivideRoundUp(grid_size.z, work_group_size.z); + work_groups_count.x = wgs[work_group_launch_order[0]]; + work_groups_count.y = wgs[work_group_launch_order[1]]; + work_groups_count.z = wgs[work_group_launch_order[2]]; + } + return work_groups_count; +} + } // namespace DataType OperationDef::GetDataType() const { @@ -106,9 +133,12 @@ GPUOperation::GPUOperation(GPUOperation&& operation) src_(std::move(operation.src_)), dst_(std::move(operation.dst_)), kernel_(std::move(operation.kernel_)), + grid_dimension_(operation.grid_dimension_), + work_group_launch_order_(operation.work_group_launch_order_), grid_size_(operation.grid_size_), src_tensors_names_(std::move(operation.src_tensors_names_)), dst_tensors_names_(std::move(operation.dst_tensors_names_)), + work_groups_count_(operation.work_groups_count_), linkable_count_(operation.linkable_count_), elementwise_code_(std::move(operation.elementwise_code_)) {} @@ -126,9 +156,12 @@ GPUOperation& GPUOperation::operator=(GPUOperation&& operation) { src_ = std::move(operation.src_); dst_ = std::move(operation.dst_); kernel_ = std::move(operation.kernel_); + std::swap(grid_dimension_, operation.grid_dimension_); + std::swap(work_group_launch_order_, operation.work_group_launch_order_); std::swap(grid_size_, operation.grid_size_); src_tensors_names_ = std::move(operation.src_tensors_names_); dst_tensors_names_ = 
std::move(operation.dst_tensors_names_); + std::swap(work_groups_count_, operation.work_groups_count_); std::swap(linkable_count_, operation.linkable_count_); elementwise_code_ = std::move(operation.elementwise_code_); } @@ -183,12 +216,15 @@ absl::Status GPUOperation::UpdateParams() { for (int i = 0; i < dst_tensors_names_.size(); ++i) { RETURN_IF_ERROR(args_.SetObjectRef(dst_tensors_names_[i], dst_[i])); } - RETURN_IF_ERROR(BindArguments()); + RETURN_IF_ERROR(BindArguments(&args_)); grid_size_ = GetGridSize(); + work_groups_count_ = GetWorkGroupsCount( + grid_dimension_, grid_size_, work_group_size_, work_group_launch_order_); return absl::OkStatus(); } -absl::Status GPUOperation::Compile(const CreationContext& creation_context) { +absl::Status GPUOperation::AssembleCode(const DeviceInfo& device_info, + CLContext* context) { if (elementwise_) { auto src_desc = absl::make_unique(definition_.src_tensors[0]); @@ -206,29 +242,35 @@ absl::Status GPUOperation::Compile(const CreationContext& creation_context) { dst_tensors_names_.insert(dst_tensors_names_.begin(), "dst_tensor"); args_.AddObjectRef("dst_tensor", AccessType::WRITE, std::move(dst_desc)); - std::string code = - GetElementWiseCode(definition_, check_src_channels_size_); elementwise_code_ = "{\n" + code_ + "\n}\n" + elementwise_code_; - RETURN_IF_ERROR(args_.AllocateObjects(creation_context.context)); + code_ = GetElementWiseCode(definition_, check_src_channels_size_); + RETURN_IF_ERROR(args_.AllocateObjects(context)); RETURN_IF_ERROR(args_.TransformToCLCode( - creation_context.device->info_, - {{dst_tensors_names_[0], elementwise_code_}}, &code)); - code = absl::Substitute(code, args_.GetListOfArgs()); - RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel( - code, "main_function", *creation_context.context, - *creation_context.device, &kernel_)); + device_info, {{dst_tensors_names_[0], elementwise_code_}}, &code_)); } else { - RETURN_IF_ERROR(args_.AllocateObjects(creation_context.context)); + RETURN_IF_ERROR(args_.AllocateObjects(context)); RETURN_IF_ERROR(args_.TransformToCLCode( - creation_context.device->info_, - {{dst_tensors_names_[0], elementwise_code_}}, &code_)); - RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel( - code_, "main_function", compiler_options_, *creation_context.context, - *creation_context.device, &kernel_)); + device_info, {{dst_tensors_names_[0], elementwise_code_}}, &code_)); } + return absl::OkStatus(); +} + +absl::Status GPUOperation::Compile(const CreationContext& creation_context) { + RETURN_IF_ERROR( + AssembleCode(creation_context.GetDeviceInfo(), creation_context.context)); + RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel( + code_, "main_function", compiler_options_, *creation_context.context, + *creation_context.device, &kernel_)); return PostCompileCheck(creation_context.device->info_, kernel_.info_); } +absl::Status GPUOperation::CompileDeserialized( + const CreationContext& creation_context) { + return creation_context.cache->GetOrCreateCLKernel( + code_, "main_function", compiler_options_, *creation_context.context, + *creation_context.device, &kernel_); +} + void GPUOperation::GetPossibleKernelWorkGroups( TuningType tuning_type, const DeviceInfo& device_info, const KernelInfo& kernel_info, std::vector* work_groups) const { @@ -246,14 +288,26 @@ absl::Status GPUOperation::Tune(const TuningParameters& params) { } if (possible_work_groups.size() == 1) { work_group_size_ = possible_work_groups[0]; + work_groups_count_ = + GetWorkGroupsCount(grid_dimension_, 
grid_size_, work_group_size_, + work_group_launch_order_); return absl::OkStatus(); } else { + std::vector work_groups_count(possible_work_groups.size()); + for (int i = 0; i < work_groups_count.size(); ++i) { + work_groups_count[i] = + GetWorkGroupsCount(grid_dimension_, grid_size_, + possible_work_groups[i], work_group_launch_order_); + } RETURN_IF_ERROR(args_.Bind(kernel_.kernel())); int best_work_group_index; RETURN_IF_ERROR(params.queue->GetBestWorkGroupIndex( - kernel_, *params.info, grid_size_, possible_work_groups, + kernel_, *params.info, work_groups_count, possible_work_groups, &best_work_group_index)); work_group_size_ = possible_work_groups[best_work_group_index]; + work_groups_count_ = + GetWorkGroupsCount(grid_dimension_, grid_size_, work_group_size_, + work_group_launch_order_); return absl::OkStatus(); } } @@ -283,7 +337,7 @@ int3 GPUOperation::GetGridSize() const { const int grid_z = 1; return int3(grid_x, grid_y, grid_z); } - return int3(0, 0, 0); + return grid_size_; } void GPUOperation::AddUniquePostfix(const std::string& unique_postfix) { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h b/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h index 2fa8c90c1da..57d8690c54e 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h @@ -16,20 +16,24 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_GPU_OPERATION_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_GPU_OPERATION_H_ -#include #include #include #include "tensorflow/lite/delegates/gpu/cl/arguments.h" #include "tensorflow/lite/delegates/gpu/cl/buffer.h" +#include "tensorflow/lite/delegates/gpu/cl/cl_command_queue.h" #include "tensorflow/lite/delegates/gpu/cl/cl_context.h" #include "tensorflow/lite/delegates/gpu/cl/cl_device.h" +#include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h" +#include "tensorflow/lite/delegates/gpu/cl/cl_program.h" +#include "tensorflow/lite/delegates/gpu/cl/device_info.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/tuning_parameters.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" #include "tensorflow/lite/delegates/gpu/cl/precision.h" #include "tensorflow/lite/delegates/gpu/cl/program_cache.h" +#include "tensorflow/lite/delegates/gpu/cl/serialization_generated.h" #include "tensorflow/lite/delegates/gpu/cl/tensor.h" #include "tensorflow/lite/delegates/gpu/cl/tensor_type.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/types.h" @@ -117,7 +121,7 @@ class GPUOperation { absl::Status AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(args_.Bind(kernel_.kernel())); - return queue->DispatchImplicit(kernel_, grid_size_, work_group_size_); + return queue->Dispatch(kernel_, work_groups_count_, work_group_size_); } virtual void GetPossibleKernelWorkGroups( @@ -126,8 +130,12 @@ class GPUOperation { absl::Status Tune(const TuningParameters& params); + absl::Status AssembleCode(const DeviceInfo& device_info, CLContext* context); + absl::Status Compile(const CreationContext& creation_context); + absl::Status CompileDeserialized(const CreationContext& creation_context); + virtual absl::Status PostCompileCheck(const DeviceInfo& device_info, const KernelInfo& kernel_info) { return absl::OkStatus(); @@ -161,7 +169,14 @@ class GPUOperation { bool check_src_channels_size_ = false; protected: - virtual absl::Status 
BindArguments() { return absl::OkStatus(); } + friend flatbuffers::Offset Encode( + const GPUOperation& op, flatbuffers::FlatBufferBuilder* builder); + friend absl::Status Decode(CLContext* context, + const data::GPUOperation* fb_op, GPUOperation* op); + + virtual absl::Status BindArguments(ArgumentsBinder* args) { + return absl::OkStatus(); + } virtual int3 GetGridSize() const; // Defines operation calculation precision and format of src/dst tensors. @@ -169,11 +184,14 @@ class GPUOperation { std::vector src_; std::vector dst_; CLKernel kernel_; + int grid_dimension_ = 3; // can be 1, 2 or 3 + int3 work_group_launch_order_ = int3(0, 1, 2); int3 grid_size_ = int3(0, 0, 0); std::vector src_tensors_names_; std::vector dst_tensors_names_; private: + int3 work_groups_count_ = int3(0, 0, 0); int linkable_count_ = 0; std::string elementwise_code_; // temporary, used during op construction }; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/lstm_full_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/lstm_full_test.cc new file mode 100644 index 00000000000..08cb622ff91 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/cl/kernels/lstm_full_test.cc @@ -0,0 +1,1181 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite LSTM op. 
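+//
+// LSTMOpModel below wraps SingleOpModel: it registers the LSTM's 20 (legacy)
+// or 24 input tensors as constants, attaches the GPU delegate in OpenCL-only
+// mode, and its destructor verifies that no op fell back to a CPU kernel.
+// LstmOpTest::VerifyGoldens then drives the model step by step and compares
+// the outputs against precomputed goldens within a tolerance.
+//
+// A minimal usage sketch (see the tests below for the full constructor
+// argument list; `input` here stands for one step of n_input floats and
+// `expected`/`tolerance` are as in VerifyGoldens):
+//
+//   LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, ...);
+//   lstm.SetInput(/*offset=*/0, input.data(), input.data() + n_input);
+//   lstm.Invoke();
+//   EXPECT_THAT(lstm.GetOutput(),
+//               ElementsAreArray(ArrayFloatNear(expected, tolerance)));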
+ +#include +#include + +#include +#include +#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "tensorflow/lite/delegates/gpu/delegate.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class LSTMOpModel : public SingleOpModel { + public: + LSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, bool use_cifg, + bool use_peephole, bool use_projection_weights, + bool use_projection_bias, const TensorType weight_type, + bool model_has_legacy_20_inputs, bool is_layer_norm, + bool asymmetric_quantize_inputs, + std::initializer_list input_to_input_weights, + std::initializer_list input_to_forget_weights, + std::initializer_list input_to_cell_weights, + std::initializer_list input_to_output_weights, + std::initializer_list recurrent_to_input_weights, + std::initializer_list recurrent_to_forget_weights, + std::initializer_list recurrent_to_cell_weights, + std::initializer_list recurrent_to_output_weights, + std::initializer_list cell_to_input_weights, + std::initializer_list cell_to_forget_weights, + std::initializer_list cell_to_output_weights, + std::initializer_list input_gate_bias, + std::initializer_list forget_gate_bias, + std::initializer_list cell_gate_bias, + std::initializer_list output_gate_bias, + std::initializer_list projection_weights, + std::initializer_list projection_bias, + std::initializer_list input_layer_norm_coefficients, + std::initializer_list forget_layer_norm_coefficients, + std::initializer_list cell_layer_norm_coefficients, + std::initializer_list output_layer_norm_coefficients) + : n_input_(n_input), + n_output_(n_output), + n_batch_(n_batch), + weight_type_(weight_type) { + input_ = AddInput({TensorType_FLOAT32, {n_batch, n_input}}); + + if (use_cifg) { + AddNullInput(); + } else { + AddConstInput({weight_type, {n_cell, n_input}}, input_to_input_weights); + } + AddConstInput({weight_type, {n_cell, n_input}}, input_to_forget_weights); + AddConstInput({weight_type, {n_cell, n_input}}, input_to_cell_weights); + AddConstInput({weight_type, {n_cell, n_input}}, input_to_output_weights); + + if (use_cifg) { + AddNullInput(); + } else { + AddConstInput({weight_type, {n_cell, n_output}}, + recurrent_to_input_weights); + } + AddConstInput({weight_type, {n_cell, n_output}}, + recurrent_to_forget_weights); + AddConstInput({weight_type, {n_cell, n_output}}, recurrent_to_cell_weights); + AddConstInput({weight_type, {n_cell, n_output}}, + recurrent_to_output_weights); + + if (use_peephole) { + if (use_cifg) { + AddNullInput(); + } else { + AddConstInput({weight_type, {n_cell}}, cell_to_input_weights); + } + AddConstInput({weight_type, {n_cell}}, cell_to_forget_weights); + AddConstInput({weight_type, {n_cell}}, cell_to_output_weights); + } else { + AddNullInput(); + AddNullInput(); + AddNullInput(); + } + + if (use_cifg) { + AddNullInput(); + } else { + AddConstInput({TensorType_FLOAT32, {n_cell}}, input_gate_bias); + } + AddConstInput({TensorType_FLOAT32, {n_cell}}, forget_gate_bias); + AddConstInput({TensorType_FLOAT32, {n_cell}}, cell_gate_bias); + AddConstInput({TensorType_FLOAT32, {n_cell}}, output_gate_bias); + + if (use_projection_weights) { + AddConstInput({weight_type, {n_output, n_cell}}, projection_weights); + } else { + AddNullInput(); + } + if (use_projection_bias) { + CHECK(use_projection_weights); + AddConstInput({TensorType_FLOAT32, {n_output}}, projection_bias); + } else 
{ + AddNullInput(); + } + + // Adding the 2 state tensors. + AddVariableInput({TensorType_FLOAT32, {n_batch, n_output}}); + AddVariableInput({TensorType_FLOAT32, {n_batch, n_cell}}); + + // Layer norm weights. + if (!model_has_legacy_20_inputs) { + if (is_layer_norm) { + if (use_cifg) { + AddNullInput(); + } else { + AddConstInput({TensorType_FLOAT32, {n_cell}}, + input_layer_norm_coefficients); + } + AddConstInput({TensorType_FLOAT32, {n_cell}}, + forget_layer_norm_coefficients); + AddConstInput({TensorType_FLOAT32, {n_cell}}, + cell_layer_norm_coefficients); + AddConstInput({TensorType_FLOAT32, {n_cell}}, + output_layer_norm_coefficients); + } else { + AddNullInput(); + AddNullInput(); + AddNullInput(); + AddNullInput(); + } + } + + output_ = AddOutput({TensorType_FLOAT32, {n_batch, n_output}}); + + // TODO(b/161825581): Add tests where cell_clip and/or proj_clip is not the + // default 0. + SetBuiltinOp( + BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions, + CreateLSTMOptions(builder_, ActivationFunctionType_TANH, + /*cell_clip=*/0.0f, /*proj_clip=*/0.0f, + LSTMKernelType_FULL, asymmetric_quantize_inputs) + .Union()); + + // Input shapes are already set up, no need to pass them again. + BuildInterpreter(/*input_shapes=*/{}, /*num_threads=*/-1, + /*allow_fp32_relax_to_fp16=*/false, + /*apply_delegate=*/false); + + auto options = TfLiteGpuDelegateOptionsV2Default(); + // MeanStddevNormalization is only implemented in OpenCL now. + options.experimental_flags |= TFLITE_GPU_EXPERIMENTAL_FLAGS_CL_ONLY; + SetDelegate(TfLiteGpuDelegateV2Create(&options)); + } + + ~LSTMOpModel() { EXPECT_EQ(CountOpsExecutedByCpuKernel(), 0); } + + void SetInput(int offset, const float* begin, const float* end) { + SingleOpModel::PopulateTensor(input_, offset, const_cast(begin), + const_cast(end)); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + int num_inputs() { return n_input_; } + int num_outputs() { return n_output_; } + int num_batches() { return n_batch_; } + + protected: + int input_; + int output_; + + int n_input_; + int n_output_; + int n_batch_; + + private: + const TensorType weight_type_; +}; + +// GetParam() => model_has_legacy_20_inputs +class LstmOpTest : public ::testing::TestWithParam { + protected: + // Weights of the LSTM model. Some are optional. + std::initializer_list input_to_input_weights_; + std::initializer_list input_to_forget_weights_; + std::initializer_list input_to_cell_weights_; + std::initializer_list input_to_output_weights_; + std::initializer_list recurrent_to_input_weights_; + std::initializer_list recurrent_to_forget_weights_; + std::initializer_list recurrent_to_cell_weights_; + std::initializer_list recurrent_to_output_weights_; + std::initializer_list cell_to_input_weights_; + std::initializer_list cell_to_forget_weights_; + std::initializer_list cell_to_output_weights_; + std::initializer_list input_gate_bias_; + std::initializer_list forget_gate_bias_; + std::initializer_list cell_gate_bias_; + std::initializer_list output_gate_bias_; + std::initializer_list projection_weights_; + std::initializer_list input_layer_norm_coefficients_; + std::initializer_list forget_layer_norm_coefficients_; + std::initializer_list cell_layer_norm_coefficients_; + std::initializer_list output_layer_norm_coefficients_; + + // LSTM input is stored as num_steps * num_batch * num_inputs vector. + std::vector>> lstm_input_; + // LSTM output is stored as num_steps * num_batch * num_outputs vector. 
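+  // Both containers are indexed as [step][batch][element]; VerifyGoldens
+  // below walks them in lockstep.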
+ std::vector>> lstm_golden_output_; + + // Compares output up to tolerance to the result of the lstm given the input. + void VerifyGoldens(LSTMOpModel* lstm, float tolerance) { + EXPECT_EQ(lstm->ApplyDelegate(), kTfLiteOk); + + const int num_inputs = lstm->num_inputs(); + const int num_outputs = lstm->num_outputs(); + const int num_batches = lstm->num_batches(); + + ASSERT_EQ(lstm_input_.size(), lstm_golden_output_.size()); + const int num_steps = lstm_input_.size(); + + for (int i = 0; i < num_steps; ++i) { + ASSERT_EQ(num_batches, lstm_input_[i].size()); + for (int b = 0; b < num_batches; ++b) { + ASSERT_EQ(num_inputs, lstm_input_[i][b].size()); + const float* batch_start = lstm_input_[i][b].data(); + const float* batch_end = batch_start + num_inputs; + lstm->SetInput(b * num_inputs, batch_start, batch_end); + } + + lstm->Invoke(); + + std::vector expected; + ASSERT_EQ(num_batches, lstm_golden_output_[i].size()); + for (int b = 0; b < num_batches; ++b) { + ASSERT_EQ(num_outputs, lstm_golden_output_[i][b].size()); + const float* batch_start = lstm_golden_output_[i][b].data(); + const float* batch_end = batch_start + num_outputs; + expected.insert(expected.end(), batch_start, batch_end); + } + + EXPECT_THAT(lstm->GetOutput(), + ElementsAreArray(ArrayFloatNear(expected, tolerance))); + } + } +}; + +TEST_P(LstmOpTest, NoCifg_NoPeephole_NoProjection_NoLayerNorm) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + + bool model_has_legacy_20_inputs = GetParam(); + + input_to_input_weights_ = {-0.45018822, -0.02338299, -0.0870589, -0.34550029, + 0.04266912, -0.15680569, -0.34856534, 0.43890524}; + input_to_cell_weights_ = {-0.50013041, 0.1370284, 0.11810488, 0.2013163, + -0.20583314, 0.44344562, 0.22077113, -0.29909778}; + input_to_forget_weights_ = {0.09701663, 0.20334584, -0.50592935, + -0.31343272, -0.40032279, 0.44781327, + 0.01387155, -0.35593212}; + input_to_output_weights_ = {-0.25065863, -0.28290087, 0.04613829, 0.40525138, + 0.44272184, 0.03897077, -0.1556896, 0.19487578}; + input_gate_bias_ = {0., 0., 0., 0.}; + cell_gate_bias_ = {0., 0., 0., 0.}; + forget_gate_bias_ = {1., 1., 1., 1.}; + output_gate_bias_ = {0., 0., 0., 0.}; + + recurrent_to_input_weights_ = { + -0.0063535, -0.2042388, 0.31454784, -0.35746509, + 0.28902304, 0.08183324, -0.16555229, 0.02286911, + -0.13566875, 0.03034258, 0.48091322, -0.12528998, + 0.24077177, -0.51332325, -0.33502164, 0.10629296}; + + recurrent_to_cell_weights_ = { + -0.3407414, 0.24443203, -0.2078532, 0.26320225, + 0.05695659, -0.00123841, -0.4744786, -0.35869038, + -0.06418842, -0.13502428, -0.501764, 0.22830659, + -0.46367589, 0.26016325, -0.03894562, -0.16368064}; + + recurrent_to_forget_weights_ = { + -0.48684245, -0.06655136, 0.42224967, 0.2112639, + 0.27654213, 0.20864892, -0.07646349, 0.45877004, + 0.00141793, -0.14609534, 0.36447752, 0.09196436, + 0.28053468, 0.01560611, -0.20127171, -0.01140004}; + + recurrent_to_output_weights_ = { + 0.43385774, -0.17194885, 0.2718237, 0.09215671, + 0.24107647, -0.39835793, 0.18212086, 0.01301402, + 0.48572797, -0.50656658, 0.20047462, -0.20607421, + -0.51818722, -0.15390486, 0.0468148, 0.39922136}; + + // num_steps * num_batch * num_inputs + lstm_input_ = {{{2., 3.}}, {{3., 4.}}, {{1., 1.}}}; + // num_steps * num_batch * num_outputs + lstm_golden_output_ = {{{-0.02973187, 0.1229473, 0.20885126, -0.15358765}}, + {{-0.03716109, 0.12507336, 0.41193449, -0.20860538}}, + {{-0.15053082, 
0.09120187, 0.24278517, -0.12222792}}}; + + LSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/false, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, /*weight_type=*/TensorType_FLOAT32, + model_has_legacy_20_inputs, + /*is_layer_norm=*/false, /*asymmetric_quantize_inputs=*/false, + input_to_input_weights_, input_to_forget_weights_, input_to_cell_weights_, + input_to_output_weights_, recurrent_to_input_weights_, + recurrent_to_forget_weights_, recurrent_to_cell_weights_, + recurrent_to_output_weights_, cell_to_input_weights_, + cell_to_forget_weights_, cell_to_output_weights_, input_gate_bias_, + forget_gate_bias_, cell_gate_bias_, output_gate_bias_, + projection_weights_, {}, input_layer_norm_coefficients_, + forget_layer_norm_coefficients_, cell_layer_norm_coefficients_, + output_layer_norm_coefficients_); + + VerifyGoldens(&lstm, 0.00001f); +} + +TEST_P(LstmOpTest, Cifg_Peephole_NoProjection_NoLayerNorm) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + + bool model_has_legacy_20_inputs = GetParam(); + + input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726, 0.05100781, + 0.04717243, 0.48944736, -0.38535351, -0.17212132}; + + input_to_forget_weights_ = {-0.55291498, -0.42866567, 0.13056988, -0.3633365, + -0.22755712, 0.28253698, 0.24407166, 0.33826375}; + + input_to_output_weights_ = {0.10725588, -0.02335852, -0.55932593, + -0.09426838, -0.44257352, 0.54939759, + 0.01533556, 0.42751634}; + cell_gate_bias_ = {0., 0., 0., 0.}; + forget_gate_bias_ = {1., 1., 1., 1.}; + output_gate_bias_ = {0., 0., 0., 0.}; + + recurrent_to_cell_weights_ = { + 0.54066205, -0.32668582, -0.43562764, -0.56094903, + 0.42957711, 0.01841056, -0.32764608, -0.33027974, + -0.10826075, 0.20675004, 0.19069612, -0.03026325, + -0.54532051, 0.33003211, 0.44901288, 0.21193194}; + + recurrent_to_forget_weights_ = { + -0.13832897, -0.0515101, -0.2359007, -0.16661474, + -0.14340827, 0.36986142, 0.23414481, 0.55899, + 0.10798943, -0.41174671, 0.17751795, -0.34484994, + -0.35874045, -0.11352962, 0.27268326, 0.54058349}; + + recurrent_to_output_weights_ = { + 0.41613156, 0.42610586, -0.16495961, -0.5663873, + 0.30579174, -0.05115908, -0.33941799, 0.23364776, + 0.11178309, 0.09481031, -0.26424935, 0.46261835, + 0.50248802, 0.26114327, -0.43736315, 0.33149987}; + + cell_to_forget_weights_ = {0.47485286, -0.51955009, -0.24458408, 0.31544167}; + cell_to_output_weights_ = {-0.17135078, 0.82760304, 0.85573703, -0.77109635}; + + lstm_input_ = {{{2., 3.}}, {{3., 4.}}, {{1., 1.}}}; + lstm_golden_output_ = {{{-0.36444446, -0.00352185, 0.12886585, -0.05163646}}, + {{-0.42312205, -0.01218222, 0.24201041, -0.08124574}}, + {{-0.358325, -0.04621704, 0.21641694, -0.06471302}}}; + + LSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, + /*use_cifg=*/true, /*use_peephole=*/true, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, /*weight_type=*/TensorType_FLOAT32, + model_has_legacy_20_inputs, + /*is_layer_norm=*/false, /*asymmetric_quantize_inputs=*/false, + input_to_input_weights_, input_to_forget_weights_, input_to_cell_weights_, + input_to_output_weights_, recurrent_to_input_weights_, + recurrent_to_forget_weights_, recurrent_to_cell_weights_, + recurrent_to_output_weights_, cell_to_input_weights_, + cell_to_forget_weights_, cell_to_output_weights_, input_gate_bias_, + forget_gate_bias_, cell_gate_bias_, output_gate_bias_, 
+ projection_weights_, {}, input_layer_norm_coefficients_, + forget_layer_norm_coefficients_, cell_layer_norm_coefficients_, + output_layer_norm_coefficients_); + + VerifyGoldens(&lstm, 0.00001f); +} + +TEST_P(LstmOpTest, NoCifg_Peephole_Projection_NoLayerNorm) { + const int n_batch = 1; + const int n_input = 5; + const int n_cell = 20; + const int n_output = 16; + + bool model_has_legacy_20_inputs = GetParam(); + + input_to_input_weights_ = { + 0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, + 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048, + -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385, + -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282, + -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627, + -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226, + -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059, + 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698, + 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206, + 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585, + -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063, + 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603, + -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682, + -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988, + -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764, + 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476, + -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012, + -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604, + -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654, + -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677}; + + input_to_forget_weights_ = { + -0.0018401089, -0.004852237, 0.03698424, 0.014181704, 0.028273236, + -0.016726194, -0.05249759, -0.10204261, 0.00861066, -0.040979505, + -0.009899187, 0.01923892, -0.028177269, -0.08535103, -0.14585495, + 0.10662567, -0.01909731, -0.017883534, -0.0047269356, -0.045103323, + 0.0030784295, 0.076784775, 0.07463696, 0.094531395, 0.0814421, + -0.12257899, -0.033945758, -0.031303465, 0.045630626, 0.06843887, + -0.13492945, -0.012480007, -0.0811829, -0.07224499, -0.09628791, + 0.045100946, 0.0012300825, 0.013964662, 0.099372394, 0.02543059, + 0.06958324, 0.034257296, 0.0482646, 0.06267997, 0.052625068, + 0.12784666, 0.07077897, 0.025725935, 0.04165009, 0.07241905, + 0.018668644, -0.037377294, -0.06277783, -0.08833636, -0.040120605, + -0.011405586, -0.007808335, -0.010301386, -0.005102167, 0.027717464, + 0.05483423, 0.11449111, 0.11289652, 0.10939839, 0.13396506, + -0.08402166, -0.01901462, -0.044678304, -0.07720565, 0.014350063, + -0.11757958, -0.0652038, -0.08185733, -0.076754324, -0.092614375, + 0.10405491, 0.052960336, 0.035755895, 0.035839386, -0.012540553, + 0.036881298, 0.02913376, 0.03420159, 0.05448447, -0.054523353, + 0.02582715, 0.02327355, -0.011857179, -0.0011980024, -0.034641717, + -0.026125094, -0.17582615, -0.15923657, -0.27486774, -0.0006143371, + 0.0001771948, -8.470171e-05, 0.02651807, 0.045790765, 0.06956496}; + + input_to_cell_weights_ = { + -0.04580283, -0.09549462, -0.032418985, -0.06454633, -0.043528453, + 0.043018587, -0.049152344, -0.12418144, -0.078985475, -0.07596889, + 0.019484362, -0.11434962, -0.0074034138, -0.06314844, -0.092981495, + 0.0062155537, -0.025034338, -0.0028890965, 0.048929527, 0.06235075, + 0.10665918, -0.032036792, -0.08505916, -0.10843358, -0.13002433, + 
-0.036816437, -0.02130134, -0.016518239, 0.0047691227, -0.0025825808, + 0.066017866, 0.029991534, -0.10652836, -0.1037554, -0.13056071, + -0.03266643, -0.033702414, -0.006473424, -0.04611692, 0.014419339, + -0.025174323, 0.0396852, 0.081777506, 0.06157468, 0.10210095, + -0.009658194, 0.046511717, 0.03603906, 0.0069369148, 0.015960095, + -0.06507666, 0.09551598, 0.053568836, 0.06408714, 0.12835667, + -0.008714329, -0.20211966, -0.12093674, 0.029450472, 0.2849013, + -0.029227901, 0.1164364, -0.08560263, 0.09941786, -0.036999565, + -0.028842626, -0.0033637602, -0.017012902, -0.09720865, -0.11193351, + -0.029155117, -0.017936034, -0.009768936, -0.04223324, -0.036159635, + 0.06505112, -0.021742892, -0.023377212, -0.07221364, -0.06430552, + 0.05453865, 0.091149814, 0.06387331, 0.007518393, 0.055960953, + 0.069779344, 0.046411168, 0.10509911, 0.07463894, 0.0075130584, + 0.012850982, 0.04555431, 0.056955688, 0.06555285, 0.050801456, + -0.009862683, 0.00826772, -0.026555609, -0.0073611983, -0.0014897042}; + + input_to_output_weights_ = { + -0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918, + -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534, + 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722, + -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761, + -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394, + 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154, + -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135, + -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564, + -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047, + -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304, + 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946, + 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646, + 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813, + -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403, + 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415, + 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495, + -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158, + 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295, + -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739, + -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956}; + + input_gate_bias_ = {0.02234832, 0.14757581, 0.18176508, 0.10380666, + 0.053110216, -0.06928846, -0.13942584, -0.11816189, + 0.19483899, 0.03652339, -0.10250295, 0.036714908, + -0.18426876, 0.036065217, 0.21810818, 0.02383196, + -0.043370757, 0.08690144, -0.04444982, 0.00030581196}; + + forget_gate_bias_ = {0.035185695, -0.042891346, -0.03032477, 0.23027696, + 0.11098921, 0.15378423, 0.09263801, 0.09790885, + 0.09508917, 0.061199076, 0.07665568, -0.015443159, + -0.03499149, 0.046190713, 0.08895977, 0.10899629, + 0.40694186, 0.06030037, 0.012413437, -0.06108739}; + + cell_gate_bias_ = {-0.024379363, 0.0055531194, 0.23377132, 0.033463873, + -0.1483596, -0.10639995, -0.091433935, 0.058573797, + -0.06809782, -0.07889636, -0.043246906, -0.09829136, + -0.4279842, 0.034901652, 0.18797937, 0.0075234566, + 0.016178843, 0.1749513, 0.13975595, 0.92058027}; + + output_gate_bias_ = {0.046159424, -0.0012809046, 0.03563469, 0.12648113, + 0.027195795, 0.35373217, -0.018957434, 0.008907322, + -0.0762701, 0.12018895, 0.04216877, 0.0022856654, + 0.040952638, 0.3147856, 0.08225149, -0.057416286, + -0.14995944, -0.008040261, 0.13208859, 0.029760877}; + + 
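+  // The recurrent weights below are flattened {n_cell, n_output} = {20, 16}
+  // matrices (320 values each), matching the shapes LSTMOpModel registers via
+  // AddConstInput.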
recurrent_to_input_weights_ = { + -0.001374326, -0.078856036, 0.10672688, 0.029162422, + -0.11585556, 0.02557986, -0.13446963, -0.035785314, + -0.01244275, 0.025961924, -0.02337298, -0.044228926, + -0.055839065, -0.046598054, -0.010546039, -0.06900766, + 0.027239809, 0.022582639, -0.013296484, -0.05459212, + 0.08981, -0.045407712, 0.08682226, -0.06867011, + -0.14390695, -0.02916037, 0.000996957, 0.091420636, + 0.14283475, -0.07390571, -0.06402044, 0.062524505, + -0.093129106, 0.04860203, -0.08364217, -0.08119002, + 0.009352075, 0.22920375, 0.0016303885, 0.11583097, + -0.13732095, 0.012405723, -0.07551853, 0.06343048, + 0.12162708, -0.031923793, -0.014335606, 0.01790974, + -0.10650317, -0.0724401, 0.08554849, -0.05727212, + 0.06556731, -0.042729504, -0.043227166, 0.011683251, + -0.013082158, -0.029302018, -0.010899579, -0.062036745, + -0.022509435, -0.00964907, -0.01567329, 0.04260106, + -0.07787477, -0.11576462, 0.017356863, 0.048673786, + -0.017577527, -0.05527947, -0.082487635, -0.040137455, + -0.10820036, -0.04666372, 0.022746278, -0.07851417, + 0.01068115, 0.032956902, 0.022433773, 0.0026891115, + 0.08944216, -0.0685835, 0.010513544, 0.07228705, + 0.02032331, -0.059686817, -0.0005566496, -0.086984694, + 0.040414046, -0.1380399, 0.094208956, -0.05722982, + 0.012092817, -0.04989123, -0.086576, -0.003399834, + -0.04696032, -0.045747425, 0.10091314, 0.048676282, + -0.029037097, 0.031399418, -0.0040285117, 0.047237843, + 0.09504992, 0.041799378, -0.049185462, -0.031518843, + -0.10516937, 0.026374253, 0.10058866, -0.0033195973, + -0.041975245, 0.0073591834, 0.0033782164, -0.004325073, + -0.10167381, 0.042500053, -0.01447153, 0.06464186, + -0.017142897, 0.03312627, 0.009205989, 0.024138335, + -0.011337001, 0.035530265, -0.010912711, 0.0706555, + -0.005894094, 0.051841937, -0.1401738, -0.02351249, + 0.0365468, 0.07590991, 0.08838724, 0.021681072, + -0.10086113, 0.019608743, -0.06195883, 0.077335775, + 0.023646897, -0.095322326, 0.02233014, 0.09756986, + -0.048691444, -0.009579111, 0.07595467, 0.11480546, + -0.09801813, 0.019894179, 0.08502348, 0.004032281, + 0.037211012, 0.068537936, -0.048005626, -0.091520436, + -0.028379958, -0.01556313, 0.06554592, -0.045599163, + -0.01672207, -0.020169014, -0.011877351, -0.20212261, + 0.010889619, 0.0047078193, 0.038385306, 0.08540671, + -0.017140968, -0.0035865551, 0.016678626, 0.005633034, + 0.015963363, 0.00871737, 0.060130805, 0.028611384, + 0.10109069, -0.015060172, -0.07894427, 0.06401885, + 0.011584063, -0.024466386, 0.0047652307, -0.09041358, + 0.030737216, -0.0046374933, 0.14215417, -0.11823516, + 0.019899689, 0.006106124, -0.027092824, 0.0786356, + 0.05052217, -0.058925, -0.011402121, -0.024987547, + -0.0013661642, -0.06832946, -0.015667673, -0.1083353, + -0.00096863037, -0.06988685, -0.053350925, -0.027275559, + -0.033664223, -0.07978348, -0.025200296, -0.017207067, + -0.058403496, -0.055697463, 0.005798788, 0.12965427, + -0.062582195, 0.0013350133, -0.10482091, 0.0379771, + 0.072521195, -0.0029455067, -0.13797039, -0.03628521, + 0.013806405, -0.017858358, -0.01008298, -0.07700066, + -0.017081132, 0.019358726, 0.0027079724, 0.004635139, + 0.062634714, -0.02338735, -0.039547626, -0.02050681, + 0.03385117, -0.083611414, 0.002862572, -0.09421313, + 0.058618143, -0.08598433, 0.00972939, 0.023867095, + -0.053934585, -0.023203006, 0.07452513, -0.048767887, + -0.07314807, -0.056307215, -0.10433547, -0.06440842, + 0.04328182, 0.04389765, -0.020006588, -0.09076438, + -0.11652589, -0.021705797, 0.03345259, -0.010329105, + -0.025767034, 0.013057034, 
-0.07316461, -0.10145612, + 0.06358255, 0.18531723, 0.07759293, 0.12006465, + 0.1305557, 0.058638252, -0.03393652, 0.09622831, + -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845, + -0.005644518, 0.06857898, -0.12598175, -0.035084512, + 0.03156317, -0.12794146, -0.031963028, 0.04692781, + 0.030070418, 0.0071660685, -0.095516115, -0.004643372, + 0.040170413, -0.062104587, -0.0037324072, 0.0554317, + 0.08184801, -0.019164372, 0.06791302, 0.034257166, + -0.10307039, 0.021943003, 0.046745934, 0.0790918, + -0.0265588, -0.007824208, 0.042546265, -0.00977924, + -0.0002440307, -0.017384544, -0.017990116, 0.12252321, + -0.014512694, -0.08251313, 0.08861942, 0.13589665, + 0.026351685, 0.012641483, 0.07466548, 0.044301085, + -0.045414884, -0.051112458, 0.03444247, -0.08502782, + -0.04106223, -0.028126027, 0.028473156, 0.10467447}; + + recurrent_to_cell_weights_ = { + -0.037322544, 0.018592842, 0.0056175636, -0.06253426, + 0.055647098, -0.05713207, -0.05626563, 0.005559383, + 0.03375411, -0.025757805, -0.088049285, 0.06017052, + -0.06570978, 0.007384076, 0.035123326, -0.07920549, + 0.053676967, 0.044480428, -0.07663568, 0.0071805613, + 0.08089997, 0.05143358, 0.038261272, 0.03339287, + -0.027673481, 0.044746667, 0.028349208, 0.020090483, + -0.019443132, -0.030755889, -0.0040000007, 0.04465846, + -0.021585021, 0.0031670958, 0.0053199246, -0.056117613, + -0.10893326, 0.076739706, -0.08509834, -0.027997585, + 0.037871376, 0.01449768, -0.09002357, -0.06111149, + -0.046195522, 0.0422062, -0.005683705, -0.1253618, + -0.012925729, -0.04890792, 0.06985068, 0.037654128, + 0.03398274, -0.004781977, 0.007032333, -0.031787455, + 0.010868644, -0.031489216, 0.09525667, 0.013939797, + 0.0058680447, 0.0167067, 0.02668468, -0.04797466, + -0.048885044, -0.12722108, 0.035304096, 0.06554885, + 0.00972396, -0.039238118, -0.05159735, -0.11329045, + 0.1613692, -0.03750952, 0.06529313, -0.071974665, + -0.11769596, 0.015524369, -0.0013754242, -0.12446318, + 0.02786344, -0.014179351, 0.005264273, 0.14376344, + 0.015983658, 0.03406988, -0.06939408, 0.040699873, + 0.02111075, 0.09669095, 0.041345075, -0.08316494, + -0.07684199, -0.045768797, 0.032298047, -0.041805092, + 0.0119405, 0.0061010392, 0.12652606, 0.0064572375, + -0.024950314, 0.11574242, 0.04508852, -0.04335324, + 0.06760663, -0.027437469, 0.07216407, 0.06977076, + -0.05438599, 0.034033038, -0.028602652, 0.05346137, + 0.043184172, -0.037189785, 0.10420091, 0.00882477, + -0.054019816, -0.074273005, -0.030617684, -0.0028467078, + 0.024302477, -0.0038869337, 0.005332455, 0.0013399826, + 0.04361412, -0.007001822, 0.09631092, -0.06702025, + -0.042049985, -0.035070654, -0.04103342, -0.10273396, + 0.0544271, 0.037184782, -0.13150354, -0.0058036847, + -0.008264958, 0.042035464, 0.05891794, 0.029673764, + 0.0063542654, 0.044788733, 0.054816857, 0.062257513, + -0.00093483756, 0.048938446, -0.004952862, -0.007730018, + -0.04043371, -0.017094059, 0.07229206, -0.023670016, + -0.052195564, -0.025616996, -0.01520939, 0.045104615, + -0.007376126, 0.003533447, 0.006570588, 0.056037236, + 0.12436656, 0.051817212, 0.028532185, -0.08686856, + 0.11868599, 0.07663395, -0.07323171, 0.03463402, + -0.050708205, -0.04458982, -0.11590894, 0.021273347, + 0.1251325, -0.15313013, -0.12224372, 0.17228661, + 0.023029093, 0.086124025, 0.006445803, -0.03496501, + 0.028332196, 0.04449512, -0.042436164, -0.026587414, + -0.006041347, -0.09292539, -0.05678812, 0.03897832, + 0.09465633, 0.008115513, -0.02171956, 0.08304309, + 0.071401566, 0.019622514, 0.032163795, -0.004167056, + 0.02295182, 
0.030739572, 0.056506045, 0.004612461, + 0.06524936, 0.059999723, 0.046395954, -0.0045512207, + -0.1335546, -0.030136576, 0.11584653, -0.014678886, + 0.0020118146, -0.09688814, -0.0790206, 0.039770417, + -0.0329582, 0.07922767, 0.029322514, 0.026405897, + 0.04207835, -0.07073373, 0.063781224, 0.0859677, + -0.10925287, -0.07011058, 0.048005477, 0.03438226, + -0.09606514, -0.006669445, -0.043381985, 0.04240257, + -0.06955775, -0.06769346, 0.043903265, -0.026784198, + -0.017840602, 0.024307009, -0.040079936, -0.019946516, + 0.045318738, -0.12233574, 0.026170589, 0.0074471775, + 0.15978073, 0.10185836, 0.10298046, -0.015476589, + -0.039390966, -0.072174534, 0.0739445, -0.1211869, + -0.0347889, -0.07943156, 0.014809798, -0.12412325, + -0.0030663363, 0.039695457, 0.0647603, -0.08291318, + -0.018529687, -0.004423833, 0.0037507233, 0.084633216, + -0.01514876, -0.056505352, -0.012800942, -0.06994386, + 0.012962922, -0.031234352, 0.07029052, 0.016418684, + 0.03618972, 0.055686004, -0.08663945, -0.017404709, + -0.054761406, 0.029065743, 0.052404847, 0.020238016, + 0.0048197987, -0.0214882, 0.07078733, 0.013016777, + 0.06262858, 0.009184685, 0.020785125, -0.043904778, + -0.0270329, -0.03299152, -0.060088247, -0.015162964, + -0.001828936, 0.12642565, -0.056757294, 0.013586685, + 0.09232601, -0.035886683, 0.06000002, 0.05229691, + -0.052580316, -0.082029596, -0.010794592, 0.012947712, + -0.036429964, -0.085508935, -0.13127148, -0.017744139, + 0.031502828, 0.036232427, -0.031581745, 0.023051167, + -0.05325106, -0.03421577, 0.028793324, -0.034633752, + -0.009881397, -0.043551125, -0.018609839, 0.0019097115, + -0.008799762, 0.056595087, 0.0022273948, 0.055752404}; + + recurrent_to_forget_weights_ = { + -0.057784554, -0.026057621, -0.068447545, -0.022581743, + 0.14811787, 0.10826372, 0.09471067, 0.03987225, + -0.0039523416, 0.00030638507, 0.053185795, 0.10572994, + 0.08414449, -0.022036452, -0.00066928595, -0.09203576, + 0.032950465, -0.10985798, -0.023809856, 0.0021431844, + -0.02196096, -0.00326074, 0.00058621005, -0.074678116, + -0.06193199, 0.055729095, 0.03736828, 0.020123724, + 0.061878487, -0.04729229, 0.034919553, -0.07585433, + -0.04421272, -0.044019096, 0.085488975, 0.04058006, + -0.06890133, -0.030951202, -0.024628663, -0.07672815, + 0.034293607, 0.08556707, -0.05293577, -0.033561368, + -0.04899627, 0.0241671, 0.015736353, -0.095442444, + -0.029564252, 0.016493602, -0.035026584, 0.022337519, + -0.026871363, 0.004780428, 0.0077918363, -0.03601621, + 0.016435321, -0.03263031, -0.09543275, -0.047392778, + 0.013454138, 0.028934088, 0.01685226, -0.086110644, + -0.046250615, -0.01847454, 0.047608484, 0.07339695, + 0.034546845, -0.04881143, 0.009128804, -0.08802852, + 0.03761666, 0.008096139, -0.014454086, 0.014361001, + -0.023502491, -0.0011840804, -0.07607001, 0.001856849, + -0.06509276, -0.006021153, -0.08570962, -0.1451793, + 0.060212336, 0.055259194, 0.06974018, 0.049454916, + -0.027794661, -0.08077226, -0.016179763, 0.1169753, + 0.17213494, -0.0056326236, -0.053934924, -0.0124349, + -0.11520337, 0.05409887, 0.088759385, 0.0019655675, + 0.0042065294, 0.03881498, 0.019844765, 0.041858196, + -0.05695512, 0.047233116, 0.038937137, -0.06542224, + 0.014429736, -0.09719407, 0.13908425, -0.05379757, + 0.012321099, 0.082840554, -0.029899208, 0.044217527, + 0.059855383, 0.07711018, -0.045319796, 0.0948846, + -0.011724666, -0.0033288454, -0.033542685, -0.04764985, + -0.13873616, 0.040668588, 0.034832682, -0.015319203, + -0.018715994, 0.046002675, 0.0599172, -0.043107376, + 0.0294216, -0.002314414, 
-0.022424703, 0.0030315618, + 0.0014641669, 0.0029166266, -0.11878115, 0.013738511, + 0.12375372, -0.0006038222, 0.029104086, 0.087442465, + 0.052958444, 0.07558703, 0.04817258, 0.044462286, + -0.015213451, -0.08783778, -0.0561384, -0.003008196, + 0.047060397, -0.002058388, 0.03429439, -0.018839769, + 0.024734668, 0.024614193, -0.042046934, 0.09597743, + -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786, + -0.02558259, -0.022822596, -0.023273505, -0.02464396, + -0.10991725, -0.006240552, 0.0074488563, 0.024044557, + 0.04383914, -0.046476185, 0.028658995, 0.060410924, + 0.050786525, 0.009452605, -0.0073054377, -0.024810238, + 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517, + 0.015898481, 0.021362653, -0.030262267, 0.016587038, + -0.011442813, 0.041154444, -0.007631438, -0.03423484, + -0.010977775, 0.036152758, 0.0066366293, 0.11915515, + 0.02318443, -0.041350313, 0.021485701, -0.10906167, + -0.028218046, -0.00954771, 0.020531068, -0.11995105, + -0.03672871, 0.024019798, 0.014255957, -0.05221243, + -0.00661567, -0.04630967, 0.033188973, 0.10107534, + -0.014027541, 0.030796422, -0.10270911, -0.035999842, + 0.15443139, 0.07684145, 0.036571592, -0.035900835, + -0.0034699554, 0.06209149, 0.015920248, -0.031122351, + -0.03858649, 0.01849943, 0.13872518, 0.01503974, + 0.069941424, -0.06948533, -0.0088794185, 0.061282158, + -0.047401894, 0.03100163, -0.041533746, -0.10430945, + 0.044574402, -0.01425562, -0.024290353, 0.034563623, + 0.05866852, 0.023947537, -0.09445152, 0.035450947, + 0.02247216, -0.0042998926, 0.061146557, -0.10250651, + 0.020881841, -0.06747029, 0.10062043, -0.0023941975, + 0.03532124, -0.016341697, 0.09685456, -0.016764693, + 0.051808182, 0.05875331, -0.04536488, 0.001626336, + -0.028892258, -0.01048663, -0.009793449, -0.017093895, + 0.010987891, 0.02357273, -0.00010856845, 0.0099760275, + -0.001845119, -0.03551521, 0.0018358806, 0.05763657, + -0.01769146, 0.040995963, 0.02235177, -0.060430344, + 0.11475477, -0.023854522, 0.10071741, 0.0686208, + -0.014250481, 0.034261297, 0.047418304, 0.08562733, + -0.030519066, 0.0060542435, 0.014653856, -0.038836084, + 0.04096551, 0.032249358, -0.08355519, -0.026823482, + 0.056386515, -0.010401743, -0.028396193, 0.08507674, + 0.014410365, 0.020995233, 0.17040324, 0.11511526, + 0.02459721, 0.0066619175, 0.025853224, -0.023133837, + -0.081302024, 0.017264642, -0.009585969, 0.09491168, + -0.051313367, 0.054532815, -0.014298593, 0.10657464, + 0.007076659, 0.10964551, 0.0409152, 0.008275321, + -0.07283536, 0.07937492, 0.04192024, -0.1075027}; + + recurrent_to_output_weights_ = { + 0.025825322, -0.05813119, 0.09495884, -0.045984812, -0.01255415, + -0.0026479573, -0.08196161, -0.054914974, -0.0046604523, -0.029587349, + -0.044576716, -0.07480124, -0.082868785, 0.023254942, 0.027502948, + -0.0039728214, -0.08683098, -0.08116779, -0.014675607, -0.037924774, + -0.023314456, -0.007401714, -0.09255757, 0.029460307, -0.08829125, + -0.005139627, -0.08989442, -0.0555066, 0.13596267, -0.025062224, + -0.048351806, -0.03850004, 0.07266485, -0.022414139, 0.05940088, + 0.075114764, 0.09597592, -0.010211725, -0.0049794707, -0.011523867, + -0.025980417, 0.072999895, 0.11091378, -0.081685916, 0.014416728, + 0.043229222, 0.034178585, -0.07530371, 0.035837382, -0.085607, + -0.007721233, -0.03287832, -0.043848954, -0.06404588, -0.06632928, + -0.073643476, 0.008214239, -0.045984086, 0.039764922, 0.03474462, + 0.060612556, -0.080590084, 0.049127717, 0.04151091, -0.030063879, + 0.008801774, -0.023021035, -0.019558564, 0.05158114, -0.010947698, + -0.011825728, 
0.0075720972, 0.0699727, -0.0039981045, 0.069350146, + 0.08799282, 0.016156472, 0.035502106, 0.11695009, 0.006217345, + 0.13392477, -0.037875112, 0.025745004, 0.08940699, -0.00924166, + 0.0046702605, -0.036598757, -0.08811812, 0.10522024, -0.032441203, + 0.008176899, -0.04454919, 0.07058152, 0.0067963637, 0.039206743, + 0.03259838, 0.03725492, -0.09515802, 0.013326398, -0.052055415, + -0.025676316, 0.03198509, -0.015951829, -0.058556724, 0.036879618, + 0.043357447, 0.028362012, -0.05908629, 0.0059240665, -0.04995891, + -0.019187413, 0.0276265, -0.01628143, 0.0025863599, 0.08800015, + 0.035250366, -0.022165963, -0.07328642, -0.009415526, -0.07455109, + 0.11690406, 0.0363299, 0.07411125, 0.042103454, -0.009660886, + 0.019076364, 0.018299393, -0.046004917, 0.08891175, 0.0431396, + -0.026327137, -0.051502608, 0.08979574, -0.051670972, 0.04940282, + -0.07491107, -0.021240504, 0.022596184, -0.034280192, 0.060163025, + -0.058211457, -0.051837247, -0.01349775, -0.04639988, -0.035936575, + -0.011681591, 0.064818054, 0.0073146066, -0.021745546, -0.043124277, + -0.06471268, -0.07053354, -0.029321948, -0.05330136, 0.016933719, + -0.053782392, 0.13747959, -0.1361751, -0.11569455, 0.0033329215, + 0.05693899, -0.053219706, 0.063698, 0.07977434, -0.07924483, + 0.06936997, 0.0034815092, -0.007305279, -0.037325785, -0.07251102, + -0.033633437, -0.08677009, 0.091591336, -0.14165086, 0.021752775, + 0.019683983, 0.0011612234, -0.058154266, 0.049996935, 0.0288841, + -0.0024567875, -0.14345716, 0.010955264, -0.10234828, 0.1183656, + -0.0010731248, -0.023590032, -0.072285876, -0.0724771, -0.026382286, + -0.0014920527, 0.042667855, 0.0018776858, 0.02986552, 0.009814309, + 0.0733756, 0.12289186, 0.018043943, -0.0458958, 0.049412545, + 0.033632483, 0.05495232, 0.036686596, -0.013781798, -0.010036754, + 0.02576849, -0.08307328, 0.010112348, 0.042521734, -0.05869831, + -0.071689695, 0.03876447, -0.13275425, -0.0352966, -0.023077697, + 0.10285965, 0.084736146, 0.15568255, -0.00040734606, 0.027835453, + -0.10292561, -0.032401145, 0.10053256, -0.026142767, -0.08271222, + -0.0030240538, -0.016368777, 0.1070414, 0.042672627, 0.013456989, + -0.0437609, -0.022309763, 0.11576483, 0.04108048, 0.061026827, + -0.0190714, -0.0869359, 0.037901703, 0.0610107, 0.07202949, + 0.01675338, 0.086139716, -0.08795751, -0.014898893, -0.023771819, + -0.01965048, 0.007955471, -0.043740474, 0.03346837, -0.10549954, + 0.090567775, 0.042013682, -0.03176985, 0.12569028, -0.02421228, + -0.029526481, 0.023851605, 0.031539805, 0.05292009, -0.02344001, + -0.07811758, -0.08834428, 0.10094801, 0.16594367, -0.06861939, + -0.021256343, -0.041093912, -0.06669611, 0.035498552, 0.021757556, + -0.09302526, -0.015403468, -0.06614931, -0.051798206, -0.013874718, + 0.03630673, 0.010412845, -0.08077351, 0.046185967, 0.0035662893, + 0.03541868, -0.094149634, -0.034814864, 0.003128424, -0.020674974, + -0.03944324, -0.008110165, -0.11113267, 0.08484226, 0.043586485, + 0.040582247, 0.0968012, -0.065249965, -0.028036479, 0.0050708856, + 0.0017462453, 0.0326779, 0.041296225, 0.09164146, -0.047743853, + -0.015952192, -0.034451712, 0.084197424, -0.05347844, -0.11768019, + 0.085926116, -0.08251791, -0.045081906, 0.0948852, 0.068401024, + 0.024856757, 0.06978981, -0.057309967, -0.012775832, -0.0032452994, + 0.01977615, -0.041040014, -0.024264973, 0.063464895, 0.05431621, + }; + + cell_to_input_weights_ = { + 0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458, + -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174, + -0.10116027, -0.011334949, 
0.12411352, -0.076769054, -0.052169047, + 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175}; + + cell_to_forget_weights_ = { + -0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276, + -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766, + -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774, + 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355}; + + cell_to_output_weights_ = {0.08286371, -0.08261836, -0.51210177, 0.002913762, + 0.17764764, -0.5495371, -0.08460716, -0.24552552, + 0.030037103, 0.04123544, -0.11940523, 0.007358328, + 0.1890978, 0.4833202, -0.34441817, 0.36312827, + -0.26375428, 0.1457655, -0.19724406, 0.15548733}; + + projection_weights_ = { + -0.009802181, 0.09401916, 0.0717386, -0.13895074, 0.09641832, + 0.060420845, 0.08539281, 0.054285463, 0.061395317, 0.034448683, + -0.042991187, 0.019801661, -0.16840284, -0.015726732, -0.23041931, + -0.024478018, -0.10959692, -0.013875541, 0.18600968, -0.061274476, + 0.0138165, -0.08160894, -0.07661644, 0.032372914, 0.16169067, + 0.22465782, -0.03993472, -0.004017731, 0.08633481, -0.28869787, + 0.08682067, 0.17240396, 0.014975425, 0.056431185, 0.031037588, + 0.16702051, 0.0077946745, 0.15140012, 0.29405436, 0.120285, + -0.188994, -0.027265169, 0.043389652, -0.022061434, 0.014777949, + -0.20203483, 0.094781205, 0.19100232, 0.13987629, -0.036132768, + -0.06426278, -0.05108664, 0.13221376, 0.009441198, -0.16715929, + 0.15859416, -0.040437475, 0.050779544, -0.022187516, 0.012166504, + 0.027685808, -0.07675938, -0.0055694645, -0.09444123, 0.0046453946, + 0.050794356, 0.10770313, -0.20790008, -0.07149004, -0.11425117, + 0.008225835, -0.035802525, 0.14374903, 0.15262283, 0.048710253, + 0.1847461, -0.007487823, 0.11000021, -0.09542012, 0.22619456, + -0.029149994, 0.08527916, 0.009043713, 0.0042746216, 0.016261552, + 0.022461696, 0.12689082, -0.043589946, -0.12035478, -0.08361797, + -0.050666027, -0.1248618, -0.1275799, -0.071875185, 0.07377272, + 0.09944291, -0.18897448, -0.1593054, -0.06526116, -0.040107165, + -0.004618631, -0.067624845, -0.007576253, 0.10727444, 0.041546922, + -0.20424393, 0.06907816, 0.050412357, 0.00724631, 0.039827548, + 0.12449835, 0.10747581, 0.13708383, 0.09134148, -0.12617786, + -0.06428341, 0.09956831, 0.1208086, -0.14676677, -0.0727722, + 0.1126304, 0.010139365, 0.015571211, -0.038128063, 0.022913318, + -0.042050496, 0.16842307, -0.060597885, 0.10531834, -0.06411776, + -0.07451711, -0.03410368, -0.13393489, 0.06534304, 0.003620307, + 0.04490757, 0.05970546, 0.05197996, 0.02839995, 0.10434969, + -0.013699693, -0.028353551, -0.07260381, 0.047201227, -0.024575593, + -0.036445823, 0.07155557, 0.009672501, -0.02328883, 0.009533515, + -0.03606021, -0.07421458, -0.028082801, -0.2678904, -0.13221288, + 0.18419984, -0.13012612, -0.014588381, -0.035059117, -0.04824723, + 0.07830115, -0.056184657, 0.03277091, 0.025466874, 0.14494097, + -0.12522776, -0.098633975, -0.10766018, -0.08317623, 0.08594209, + 0.07749552, 0.039474737, 0.1776665, -0.07409566, -0.0477268, + 0.29323658, 0.10801441, 0.1154011, 0.013952499, 0.10739139, + 0.10708251, -0.051456142, 0.0074137426, -0.10430189, 0.10034707, + 0.045594677, 0.0635285, -0.0715442, -0.089667566, -0.10811871, + 0.00026344223, 0.08298446, -0.009525053, 0.006585689, -0.24567553, + -0.09450807, 0.09648481, 0.026996298, -0.06419476, -0.04752702, + -0.11063944, -0.23441927, -0.17608605, -0.052156363, 0.067035615, + 0.19271925, -0.0032889997, -0.043264326, 0.09663576, -0.057112187, + -0.10100678, 0.0628376, 0.04447668, 
0.017961001, -0.10094388, + -0.10190601, 0.18335468, 0.10494553, -0.052095775, -0.0026118709, + 0.10539724, -0.04383912, -0.042349473, 0.08438151, -0.1947263, + 0.02251204, 0.11216432, -0.10307853, 0.17351969, -0.039091777, + 0.08066188, -0.00561982, 0.12633002, 0.11335965, -0.0088127935, + -0.019777594, 0.06864014, -0.059751723, 0.016233567, -0.06894641, + -0.28651384, -0.004228674, 0.019708522, -0.16305895, -0.07468996, + -0.0855457, 0.099339016, -0.07580735, -0.13775392, 0.08434318, + 0.08330512, -0.12131499, 0.031935584, 0.09180414, -0.08876437, + -0.08049874, 0.008753825, 0.03498998, 0.030215185, 0.03907079, + 0.089751154, 0.029194152, -0.03337423, -0.019092513, 0.04331237, + 0.04299654, -0.036394123, -0.12915532, 0.09793732, 0.07512415, + -0.11319543, -0.032502122, 0.15661901, 0.07671967, -0.005491124, + -0.19379048, -0.218606, 0.21448623, 0.017840758, 0.1416943, + -0.07051762, 0.19488361, 0.02664691, -0.18104725, -0.09334311, + 0.15026465, -0.15493552, -0.057762887, -0.11604192, -0.262013, + -0.01391798, 0.012185008, 0.11156489, -0.07483202, 0.06693364, + -0.26151478, 0.046425626, 0.036540434, -0.16435726, 0.17338543, + -0.21401681, -0.11385144, -0.08283257, -0.069031075, 0.030635102, + 0.010969227, 0.11109743, 0.010919218, 0.027526086, 0.13519906, + 0.01891392, -0.046839405, -0.040167913, 0.017953383, -0.09700955, + 0.0061885654, -0.07000971, 0.026893595, -0.038844477, 0.14543656}; + + lstm_input_ = {// Step 1 + {{0.787926, 0.151646, 0.071352, 0.118426, 0.458058}}, + // Step 2 + {{0.596268, 0.998386, 0.568695, 0.864524, 0.571277}}, + // Step 3 + {{0.073204, 0.296072, 0.743333, 0.069199, 0.045348}}, + // Step 4 + {{0.867394, 0.291279, 0.013714, 0.482521, 0.626339}}}; + + lstm_golden_output_ = { + {{-0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576, -0.0211779, + 0.0283512, -0.0114597, 0.00907307, -0.0244004, -0.0152191, -0.0259063, + 0.00914318, 0.00415118, 0.017147, 0.0134203}}, + + {{-0.0166936, 0.0381209, 0.000889694, 0.0143363, -0.0328911, -0.0234288, + 0.0333051, -0.012229, 0.0110322, -0.0457725, -0.000832209, -0.0202817, + 0.0327257, 0.0121308, 0.0155969, 0.0312091}}, + + {{-0.0213783, 0.0350169, 0.000324794, 0.0276012, -0.0263374, -0.0371449, + 0.0446149, -0.0205474, 0.0103729, -0.0576349, -0.0150052, -0.0292043, + 0.0376827, 0.0136115, 0.0243435, 0.0354492}}, + + {{-0.0189322, 0.0464512, -0.00251373, 0.0225745, -0.0308346, -0.0317124, + 0.0460407, -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193, + 0.0286833, 0.00824207, 0.0264887, 0.0305169}}}; + + LSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/true, + /*use_projection_weights=*/true, + /*use_projection_bias=*/false, /*weight_type=*/TensorType_FLOAT32, + model_has_legacy_20_inputs, + /*is_layer_norm=*/false, /*asymmetric_quantize_inputs=*/false, + input_to_input_weights_, input_to_forget_weights_, input_to_cell_weights_, + input_to_output_weights_, recurrent_to_input_weights_, + recurrent_to_forget_weights_, recurrent_to_cell_weights_, + recurrent_to_output_weights_, cell_to_input_weights_, + cell_to_forget_weights_, cell_to_output_weights_, input_gate_bias_, + forget_gate_bias_, cell_gate_bias_, output_gate_bias_, + projection_weights_, {}, input_layer_norm_coefficients_, + forget_layer_norm_coefficients_, cell_layer_norm_coefficients_, + output_layer_norm_coefficients_); + + VerifyGoldens(&lstm, 0.00001f); +} + +TEST_F(LstmOpTest, NoCifg_Peephole_Projection_LayerNorm) { + const int n_batch = 1; + const int n_input = 5; + const int n_cell = 4; + const int 
n_output = 3; + + input_to_input_weights_ = {0.5, 0.6, 0.7, -0.8, -0.9, 0.1, 0.2, + 0.3, -0.4, 0.5, -0.8, 0.7, -0.6, 0.5, + -0.4, -0.5, -0.4, -0.3, -0.2, -0.1}; + + input_to_forget_weights_ = {-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2, + -0.4, 0.3, -0.8, -0.4, 0.3, -0.5, -0.4, + -0.6, 0.3, -0.4, -0.6, -0.5, -0.5}; + + input_to_cell_weights_ = {-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, + -0.3, -0.2, -0.6, 0.6, -0.1, -0.4, -0.3, + -0.7, 0.7, -0.9, -0.5, 0.8, 0.6}; + + input_to_output_weights_ = {-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, + -0.3, -0.8, -0.2, 0.6, -0.2, 0.4, -0.7, + -0.3, -0.5, 0.1, 0.5, -0.6, -0.4}; + + input_gate_bias_ = {0.03, 0.15, 0.22, 0.38}; + + forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1}; + + cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08}; + + output_gate_bias_ = {0.05, -0.01, 0.2, 0.1}; + + recurrent_to_input_weights_ = {-0.2, -0.3, 0.4, 0.1, -0.5, 0.9, + -0.2, -0.3, -0.7, 0.05, -0.2, -0.6}; + + recurrent_to_cell_weights_ = {-0.3, 0.2, 0.1, -0.3, 0.8, -0.08, + -0.2, 0.3, 0.8, -0.6, -0.1, 0.2}; + + recurrent_to_forget_weights_ = {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4, + 0.9, 0.3, -0.1, 0.2, 0.5, 0.2}; + + recurrent_to_output_weights_ = {0.3, -0.1, 0.1, -0.2, -0.5, -0.7, + -0.2, -0.6, -0.1, -0.4, -0.7, -0.2}; + + cell_to_input_weights_ = {0.05, 0.1, 0.25, 0.15}; + + cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03}; + + cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05}; + + input_layer_norm_coefficients_ = {0.1, 0.2, 0.3, 0.5}; + forget_layer_norm_coefficients_ = {0.2, 0.2, 0.4, 0.3}; + cell_layer_norm_coefficients_ = {0.7, 0.2, 0.3, 0.8}; + output_layer_norm_coefficients_ = {0.6, 0.2, 0.2, 0.5}; + + projection_weights_ = {-0.1, 0.2, 0.01, -0.2, 0.1, 0.5, + 0.3, 0.08, 0.07, 0.2, -0.4, 0.2}; + + lstm_input_ = { + {{0.7, 0.8, 0.1, 0.2, 0.3}}, + {{0.8, 0.1, 0.2, 0.4, 0.5}}, + {{0.2, 0.7, 0.7, 0.1, 0.7}}, + }; + + lstm_golden_output_ = {{{0.0244077, 0.128027, -0.00170918}}, + {{0.0137642, 0.140751, 0.0395835}}, + {{-0.00459231, 0.155278, 0.0837377}}}; + + LSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/true, + /*use_projection_weights=*/true, + /*use_projection_bias=*/false, /*weight_type=*/TensorType_FLOAT32, + /*model_has_legacy_20_inputs=*/false, + /*is_layer_norm=*/true, /*asymmetric_quantize_inputs=*/false, + input_to_input_weights_, input_to_forget_weights_, input_to_cell_weights_, + input_to_output_weights_, recurrent_to_input_weights_, + recurrent_to_forget_weights_, recurrent_to_cell_weights_, + recurrent_to_output_weights_, cell_to_input_weights_, + cell_to_forget_weights_, cell_to_output_weights_, input_gate_bias_, + forget_gate_bias_, cell_gate_bias_, output_gate_bias_, + projection_weights_, {}, input_layer_norm_coefficients_, + forget_layer_norm_coefficients_, cell_layer_norm_coefficients_, + output_layer_norm_coefficients_); + + VerifyGoldens(&lstm, 0.00001f); +} + +TEST_F(LstmOpTest, Cifg_Peephole_Projection_LayerNorm) { + const int n_batch = 1; + const int n_input = 5; + const int n_cell = 4; + const int n_output = 3; + + input_to_forget_weights_ = {-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2, + -0.4, 0.3, -0.8, -0.4, 0.3, -0.5, -0.4, + -0.6, 0.3, -0.4, -0.6, -0.5, -0.5}; + input_to_cell_weights_ = {-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, + -0.3, -0.2, -0.6, 0.6, -0.1, -0.4, -0.3, + -0.7, 0.7, -0.9, -0.5, 0.8, 0.6}; + input_to_output_weights_ = {-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, + -0.3, -0.8, -0.2, 0.6, -0.2, 0.4, -0.7, + -0.3, -0.5, 0.1, 0.5, -0.6, -0.4}; + + forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1}; + 
cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08}; + output_gate_bias_ = {0.05, -0.01, 0.2, 0.1}; + + recurrent_to_cell_weights_ = {-0.3, 0.2, 0.1, -0.3, 0.8, -0.08, + -0.2, 0.3, 0.8, -0.6, -0.1, 0.2}; + recurrent_to_forget_weights_ = {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4, + 0.9, 0.3, -0.1, 0.2, 0.5, 0.2}; + recurrent_to_output_weights_ = {0.3, -0.1, 0.1, -0.2, -0.5, -0.7, + -0.2, -0.6, -0.1, -0.4, -0.7, -0.2}; + + cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03}; + cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05}; + + forget_layer_norm_coefficients_ = {0.2, 0.2, 0.4, 0.3}; + cell_layer_norm_coefficients_ = {0.7, 0.2, 0.3, 0.8}; + output_layer_norm_coefficients_ = {0.6, 0.2, 0.2, 0.5}; + projection_weights_ = {-0.1, 0.2, 0.01, -0.2, 0.1, 0.5, + 0.3, 0.08, 0.07, 0.2, -0.4, 0.2}; + + lstm_input_ = {{{0.7, 0.8, 0.1, 0.2, 0.3}}, + {{0.8, 0.1, 0.2, 0.4, 0.5}}, + {{0.2, 0.7, 0.7, 0.1, 0.7}}}; + lstm_golden_output_ = {{{0.02129706, 0.140816242, 0.0112733059}}, + {{0.0132302344, 0.152308047, 0.0346313119}}, + {{-0.0123688057, 0.165790111, 0.0893077999}}}; + + LSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, + /*use_cifg=*/true, /*use_peephole=*/true, + /*use_projection_weights=*/true, + /*use_projection_bias=*/false, /*weight_type=*/TensorType_FLOAT32, + /*model_has_legacy_20_inputs=*/false, + /*is_layer_norm=*/true, /*asymmetric_quantize_inputs=*/false, + input_to_input_weights_, input_to_forget_weights_, input_to_cell_weights_, + input_to_output_weights_, recurrent_to_input_weights_, + recurrent_to_forget_weights_, recurrent_to_cell_weights_, + recurrent_to_output_weights_, cell_to_input_weights_, + cell_to_forget_weights_, cell_to_output_weights_, input_gate_bias_, + forget_gate_bias_, cell_gate_bias_, output_gate_bias_, + projection_weights_, {}, input_layer_norm_coefficients_, + forget_layer_norm_coefficients_, cell_layer_norm_coefficients_, + output_layer_norm_coefficients_); + + VerifyGoldens(&lstm, 0.00001f); +} + +#ifdef GTEST_HAS_DEATH_TEST +TEST_F(LstmOpTest, InvalidTypes) { + const int n_batch = 1; + const int n_input = 2; + const int n_cell = 4; + const int n_output = 4; + + EXPECT_DEATH( + LSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/false, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*weight_type=*/TensorType_INT32, + /*model_has_legacy_20_inputs=*/true, + /*is_layer_norm=*/false, + /*asymmetric_quantize_inputs=*/false, + /*input_to_input_weights=*/{}, /*input_to_forget_weights=*/{}, + /*input_to_cell_weights=*/{}, /*input_to_output_weights=*/{}, + /*recurrent_to_input_weights=*/{}, + /*recurrent_to_forget_weights=*/{}, /*recurrent_to_cell_weights=*/{}, + /*recurrent_to_output_weights=*/{}, /*cell_to_input_weights=*/{}, + /*cell_to_forget_weights=*/{}, /*cell_to_output_weights=*/{}, + /*input_gate_bias=*/{}, /*forget_gate_bias=*/{}, + /*cell_gate_bias=*/{}, /*output_gate_bias=*/{}, + /*projection_weights=*/{}, /*projection_bias=*/{}, + /*input_layer_norm_coefficients=*/{}, + /*forget_layer_norm_coefficients=*/{}, + /*cell_layer_norm_coefficients=*/{}, + /*output_layer_norm_coefficients=*/{}), + ""); + + EXPECT_DEATH( + LSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/false, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*weight_type=*/TensorType_COMPLEX64, + /*model_has_legacy_20_inputs=*/true, + /*is_layer_norm=*/false, + /*asymmetric_quantize_inputs=*/false, + /*input_to_input_weights=*/{}, /*input_to_forget_weights=*/{}, + 
/*input_to_cell_weights=*/{}, /*input_to_output_weights=*/{}, + /*recurrent_to_input_weights=*/{}, + /*recurrent_to_forget_weights=*/{}, /*recurrent_to_cell_weights=*/{}, + /*recurrent_to_output_weights=*/{}, /*cell_to_input_weights=*/{}, + /*cell_to_forget_weights=*/{}, /*cell_to_output_weights=*/{}, + /*input_gate_bias=*/{}, /*forget_gate_bias=*/{}, + /*cell_gate_bias=*/{}, /*output_gate_bias=*/{}, + /*projection_weights=*/{}, /*projection_bias=*/{}, + /*input_layer_norm_coefficients=*/{}, + /*forget_layer_norm_coefficients=*/{}, + /*cell_layer_norm_coefficients=*/{}, + /*output_layer_norm_coefficients=*/{}), + ""); +} +#endif + +// Test parameter controls model_has_legacy_20_inputs in LSTMOpModel. +INSTANTIATE_TEST_SUITE_P(Parameterized, LstmOpTest, ::testing::Bool()); + +} // namespace +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/mean.cc b/tensorflow/lite/delegates/gpu/cl/kernels/mean.cc index e1628a7e9a7..c5b659463ea 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/mean.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/mean.cc @@ -35,6 +35,15 @@ Mean::Mean(const OperationDef& definition, const DeviceInfo& device_info) if (device_info.IsAdreno3xx()) { work_group_size_ = int3(16, 8, 1); } + if (device_info.IsMali()) { + const MaliInfo& mali_info = device_info.mali_info; + if (mali_info.IsMaliT6xx() || mali_info.IsMaliT7xx() || + mali_info.IsMaliT8xx()) { + work_group_size_ = int3(8, 4, 1); + } else { + work_group_size_ = int3(8, 8, 1); + } + } code_ = GetMeanKernelCode(definition_, work_group_size_); } @@ -108,12 +117,12 @@ std::string Mean::GetMeanKernelCode(const OperationDef& op_def, return c; } -absl::Status Mean::BindArguments() { +absl::Status Mean::BindArguments(ArgumentsBinder* args) { const double total_size = src_[0]->Width() * src_[0]->Height(); const double size_0 = work_group_size_.x * work_group_size_.y; const double size_1 = total_size / size_0; - RETURN_IF_ERROR(args_.SetFloat("inv_multiplier_1", 1.0 / size_1)); - RETURN_IF_ERROR(args_.SetFloat("inv_multiplier_2", 1.0 / size_0)); + RETURN_IF_ERROR(args->SetFloat("inv_multiplier_1", 1.0 / size_1)); + RETURN_IF_ERROR(args->SetFloat("inv_multiplier_2", 1.0 / size_0)); return absl::OkStatus(); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/mean.h b/tensorflow/lite/delegates/gpu/cl/kernels/mean.h index 12735c0b916..3bf2061d329 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/mean.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/mean.h @@ -37,7 +37,7 @@ class Mean : public GPUOperation { std::vector* work_groups) const override { work_groups->push_back(work_group_size_); } - absl::Status BindArguments() override; + absl::Status BindArguments(ArgumentsBinder* args) override; int3 GetGridSize() const override; // Move only diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization.cc b/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization.cc index c36dacdaafc..dabf71066f6 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization.cc @@ -29,15 +29,23 @@ namespace cl { namespace { std::string GetVectorReduceCode() { - return R"(static inline float reduce_vector(float4 v) { + return R"(float reduce_vector(float4 v) { return dot(v, (float4)(1.0f)); })"; } std::string GetReduceCode() { // If it is supported, use the built-in work_group_reduce_add function. - // Otherwise, implement a reduction using __local memory. 
Note this only works - // with power-of-two work group sizes. + // Otherwise, implement a reduction using __local memory. + + // In the reduction step add upper half of the still-to-be-summed vector to + // the lower half, while taking care of odd sizes and rounding. E.g.: + // Number of items still to be summed before: 5 + // Local memory before: [a, b, c, d, e]; + // Local memory after: [a+d, b+e, c, d, e]; + // Threads doing work: id < 2 = floor(5/2) + // Offset to the added items: 3 = ceil(5/2) + // Number of items still to be summed after: 3 = ceil(5/2) return R"( #if (__OPENCL_C_VERSION__ >= 200) && (__OPENCL_C_VERSION__ < 300) && \ !defined(__opencl_c_work_group_collective_functions) @@ -45,35 +53,85 @@ std::string GetReduceCode() { #endif #ifdef __opencl_c_work_group_collective_functions -#define local_reduce(input, tmp) work_group_reduce_add(input) +#define local_reduce(item, tmp) work_group_reduce_add(item) #else // !defined(__opencl_c_work_group_collective_functions) -static inline float local_reduce(float input, __local float* tmp) { +float local_reduce(float item, __local float* tmp) { const int local_id = get_local_id(0); - tmp[local_id] = input; + tmp[local_id] = item; barrier(CLK_LOCAL_MEM_FENCE); - int reduction_size = get_local_size(0) / 2; - while (reduction_size > 0) { - if (local_id < reduction_size) { - tmp[local_id] += tmp[local_id + reduction_size]; + // The number of items still need to be summed + int reduction_size = get_local_size(0); + while (reduction_size > 1) { + const int active_thread_limit = reduction_size / 2; + const int offset = (reduction_size + 1) / 2; + if (local_id < active_thread_limit) { + item += tmp[local_id + offset]; + tmp[local_id] = item; } barrier(CLK_LOCAL_MEM_FENCE); - reduction_size /= 2; + reduction_size = offset; } return tmp[0]; } #endif // defined(__opencl_c_work_group_collective_functions) )"; } + +std::string GetFilterCode() { + return R"( +float4 filter_outside_tensor(float4 x, int num_channels, int slice) { + return select(x, (float4)(0.0f), slice * 4 + (int4)(0, 1, 2, 3) >= num_channels); +} +)"; +} } // namespace MeanStdDevNormalization::MeanStdDevNormalization(const OperationDef& definition, - const DeviceInfo& device_info) + const DeviceInfo& device_info, + const int tensor_slices) : GPUOperation(definition) { // The kernel code does not inherently need a fixed size, but in order to not // hardcode the __local array's size for the reductions, we would need to pass // that size to the kernel at runtime, and that is currently not supported. - // For now, fix workgroup size to 128 threads. - work_group_size_.x = 128; + // For now, fix workgroup size to the biggest supported by the device, but not + // larger than the number of tensor slices. + int desired_work_group_size = + std::min(tensor_slices, device_info.max_work_group_size_x); + if (device_info.IsMali()) { + // Don't use more than 64 work items per work group on ARM Mali. They + // implement local memory using the global memory, larger workgroups have + // severe performance penalty. 
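
For illustration, the odd-size tree reduction that local_reduce performs across work items can be sketched on the host; this is a minimal C++ stand-in (a std::vector plays the role of the __local buffer) under that assumption, not code from this patch:

// Host-side sketch of the odd-size tree reduction used by local_reduce.
#include <vector>

float ReduceOddSizeTree(std::vector<float> tmp) {
  int reduction_size = static_cast<int>(tmp.size());
  while (reduction_size > 1) {
    const int active_thread_limit = reduction_size / 2;  // floor(n/2) lanes add
    const int offset = (reduction_size + 1) / 2;         // ceil(n/2) stride
    for (int i = 0; i < active_thread_limit; ++i) {
      tmp[i] += tmp[i + offset];
    }
    reduction_size = offset;  // ceil(n/2) items remain to be summed
  }
  return tmp[0];
}
// For {a, b, c, d, e}: the first pass yields {a+d, b+e, c}, the next
// {a+d+c, b+e}, and finally {a+b+c+d+e}, matching the comment above.
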
+ desired_work_group_size = 64; + } + if (device_info.IsAdreno()) { + AdrenoInfo info = device_info.adreno_info; + if (device_info.IsAdreno3xx()) { + if (info.gpu_version < 320) { + desired_work_group_size = 64; + } else { + desired_work_group_size = 128; + } + } else if (device_info.IsAdreno4xx()) { + if (info.gpu_version < 430) { + desired_work_group_size = 128; + } else { + desired_work_group_size = 256; + } + } else if (device_info.IsAdreno5xx()) { + if (info.gpu_version < 530) { + desired_work_group_size = 128; + } else { + desired_work_group_size = 256; + } + } + } + if (device_info.IsPowerVR()) { + desired_work_group_size = 64; + } + while (desired_work_group_size >= tensor_slices * 2) { + desired_work_group_size /= 2; + } + work_group_size_.x = desired_work_group_size; work_group_size_.y = 1; // Required work_group_size_.z = 1; // Required code_ = GetNormalizationCode(); @@ -91,6 +149,7 @@ std::string MeanStdDevNormalization::GetNormalizationCode() { std::string c = GetCommonDefines(definition_.precision); c += GetVectorReduceCode(); c += GetReduceCode(); + c += GetFilterCode(); c += "__attribute__((reqd_work_group_size(" + std::to_string(work_group_size_.x) + ", 1, 1)))\n"; c += R"(__kernel void main_function($0) { @@ -99,17 +158,12 @@ std::string MeanStdDevNormalization::GetNormalizationCode() { std::to_string(work_group_size_.x) + R"(]; #endif const int B = get_global_id(1); - if (get_global_id(2) > 0) { return; } - if (B >= args.src_tensor.Batch()) { return; } // Calculate the total sum of the input tensor. // First, get a local sum of input[local_id_x + N*local_size_x] for all N. float4 private_sum4 = (float4)(0.0f); for (int S = get_local_id(0); S < args.src_tensor.Slices(); S += get_local_size(0)) { const float4 t = args.src_tensor.Read(0, 0, S, B); - // Filter out reads beyond the end of the tensor. - const int4 is_after_end_of_tensor = (int4)(0, 1, 2, 3) >= (args.src_tensor.Channels() - S * 4); - const float4 filtered_t = select(t, (float4)(0.0f), is_after_end_of_tensor); - private_sum4 += filtered_t; + private_sum4 += filter_outside_tensor(t, args.src_tensor.Channels(), S); } // Reduce the vector to a single float and do a workgroup reduce. const float private_sum = reduce_vector(private_sum4); @@ -120,19 +174,16 @@ std::string MeanStdDevNormalization::GetNormalizationCode() { float4 private_sum_diff_sq4 = (float4)(0.0f); for (int S = get_local_id(0); S < args.src_tensor.Slices(); S += get_local_size(0)) { const float4 t = args.src_tensor.Read(0, 0, S, B); - const float4 diff = t - mean; - // Filter out reads beyond the end of the tensor. 
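
As a worked illustration of the masking that filter_outside_tensor performs, a hypothetical scalar helper (an assumption for illustration, not part of the patch) makes the boundary case explicit: with 7 channels, slice 1 covers channel indices 4..7, so only its last lane lies past the end of the tensor and is zeroed.

// Hypothetical scalar equivalent of filter_outside_tensor: zero out lanes
// whose channel index lies past the end of the tensor.
float FilterLane(float x, int num_channels, int slice, int lane) {
  return (slice * 4 + lane >= num_channels) ? 0.0f : x;
}
// num_channels = 7, slice = 1: lanes 0..2 (channels 4..6) pass through,
// lane 3 (channel 7) is replaced by 0.0f before it enters the sums.
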
- const int4 is_after_end_of_tensor = (int4)(0, 1, 2, 3) >= (args.src_tensor.Channels() - S * 4); - const float4 filtered_diff = select(diff, (float4)(0.0f), is_after_end_of_tensor); + const float4 diff = filter_outside_tensor(t - mean, args.src_tensor.Channels(), S); // sum_diff_sq += diff² - private_sum_diff_sq4 = mad(filtered_diff, filtered_diff, private_sum_diff_sq4); + private_sum_diff_sq4 = mad(diff, diff, private_sum_diff_sq4); } // Reduce const float private_sum_diff_sq = reduce_vector(private_sum_diff_sq4); const float sum_diff_sq = local_reduce(private_sum_diff_sq, tmp); // Calculate 1/stddev (with the 'regulazing constant' as in tensor_utils.cc) const float variance = sum_diff_sq / args.src_tensor.Channels(); - const float stddev_inv = rsqrt(variance + 1.0e-8f); + const float stddev_inv = native_rsqrt(variance + 1.0e-8f); // Calculate (t-mean)/stddev for each element for (int S = get_local_id(0); S < args.src_tensor.Slices(); S += get_local_size(0)) { const float4 t = args.src_tensor.Read(0, 0, S, B); @@ -153,8 +204,9 @@ int3 MeanStdDevNormalization::GetGridSize() const { } MeanStdDevNormalization CreateMeanStdDevNormalization( - const OperationDef& definition, const DeviceInfo& device_info) { - return MeanStdDevNormalization(definition, device_info); + const OperationDef& definition, const DeviceInfo& device_info, + const int tensor_slices) { + return MeanStdDevNormalization(definition, device_info, tensor_slices); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization.h b/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization.h index e898803e377..3312d23122f 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization.h @@ -30,7 +30,8 @@ namespace cl { class MeanStdDevNormalization : public GPUOperation { public: explicit MeanStdDevNormalization(const OperationDef& definition, - const DeviceInfo& device_info); + const DeviceInfo& device_info, + const int tensor_slices); void GetPossibleKernelWorkGroups( TuningType tuning_type, const DeviceInfo& device_info, @@ -52,7 +53,8 @@ class MeanStdDevNormalization : public GPUOperation { }; MeanStdDevNormalization CreateMeanStdDevNormalization( - const OperationDef& definition, const DeviceInfo& device_info); + const OperationDef& definition, const DeviceInfo& device_info, + const int tensor_slices); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization_test.cc index 8ff34be17d8..7ceaf964edd 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization_test.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization_test.cc @@ -55,7 +55,7 @@ TEST_P(MeanStddevNormalizationTest, SeparateBatches) { op_def.dst_tensors.push_back({data_type, storage, Layout::BHWC}); TensorFloat32 dst_tensor; auto operation = - CreateMeanStdDevNormalization(op_def, env_.GetDevicePtr()->info_); + CreateMeanStdDevNormalization(op_def, env_.GetDevicePtr()->info_, 1); ASSERT_OK(ExecuteGPUOperation({src_tensor}, creation_context_, &operation, BHWC(1, 1, 1, 4), &dst_tensor)); @@ -88,8 +88,6 @@ INSTANTIATE_TEST_SUITE_P( std::make_tuple(100.0f, 100.0f, 2.63e-4f) // large mean, large variance )); -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(MeanStddevNormalizationTest); - TEST_F(OpenCLOperationTest, MeanStddevNormalizationAllBatches) { 
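
For reference, the per-batch computation these tests verify can be written as a short host-side sketch; this is illustrative C++ following the kernel's formula (including the 1.0e-8f regularizing constant), not code from the patch:

// Host-side sketch of mean/stddev normalization over the channel axis.
#include <cmath>
#include <vector>

void ReferenceMeanStdDevNormalization(std::vector<float>* v) {
  float mean = 0.0f;
  for (float x : *v) mean += x;
  mean /= v->size();
  float sum_diff_sq = 0.0f;
  for (float x : *v) sum_diff_sq += (x - mean) * (x - mean);
  const float variance = sum_diff_sq / v->size();
  const float stddev_inv = 1.0f / std::sqrt(variance + 1.0e-8f);
  for (float& x : *v) x = (x - mean) * stddev_inv;
}

Applied to the large-vector input further below (one element at the mean, the rest alternating mean ± diff), this gives 0 for the first element and ±sqrt(N/(N-1)) for the others, which is where that test's expected_output values come from.
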
TensorFloat32 src_tensor; src_tensor.shape = BHWC(9, 1, 1, 4); @@ -106,6 +104,8 @@ TEST_F(OpenCLOperationTest, MeanStddevNormalizationAllBatches) { }; for (auto storage : env_.GetSupportedStorages()) { for (auto precision : env_.GetSupportedPrecisions()) { + const float eps = + precision == CalculationsPrecision::F32 ? 2.53e-05f : 3.57e-4f; OperationDef op_def; op_def.precision = precision; auto data_type = DeduceDataTypeFromPrecision(precision); @@ -113,7 +113,7 @@ TEST_F(OpenCLOperationTest, MeanStddevNormalizationAllBatches) { op_def.dst_tensors.push_back({data_type, storage, Layout::BHWC}); TensorFloat32 dst_tensor; auto operation = - CreateMeanStdDevNormalization(op_def, env_.GetDevicePtr()->info_); + CreateMeanStdDevNormalization(op_def, env_.GetDevicePtr()->info_, 1); ASSERT_OK(ExecuteGPUOperation({src_tensor}, creation_context_, &operation, BHWC(9, 1, 1, 4), &dst_tensor)); @@ -130,8 +130,57 @@ TEST_F(OpenCLOperationTest, MeanStddevNormalizationAllBatches) { -ksqrt16, -ksqrt04, ksqrt04, ksqrt16, // large mean, small variance -ksqrt16, -ksqrt04, ksqrt04, ksqrt16, // large mean, large variance }; - EXPECT_THAT(dst_tensor.data, - Pointwise(FloatNear(3.57e-4f), expected_output)); + EXPECT_THAT(dst_tensor.data, Pointwise(FloatNear(eps), expected_output)) + << "Failed using precision " << ToString(precision); + } + } +} + +TEST_F(OpenCLOperationTest, MeanStddevNormalizationLargeVector) { + const float mean = 100.0f; + const float diff = 1.0f; + // Some large vector that is not a round multiple of any SIMD vector sizes. + constexpr int kVectorSize = 16 * 16 + 16 + 1; + + TensorFloat32 src_tensor; + src_tensor.shape = BHWC(1, 1, 1, kVectorSize); + src_tensor.data.resize(kVectorSize); + // First input is mean. + src_tensor.data[0] = mean; + // Rest is alternating between mean + diff and mean - diff. + for (int i = 1; i < kVectorSize - 1; i += 2) { + src_tensor.data[i + 0] = mean + diff; + src_tensor.data[i + 1] = mean - diff; + } + + for (auto storage : env_.GetSupportedStorages()) { + for (auto precision : env_.GetSupportedPrecisions()) { + const float eps = + precision == CalculationsPrecision::F32 ? 0.0f : 8.60e-4f; + OperationDef op_def; + op_def.precision = precision; + auto data_type = DeduceDataTypeFromPrecision(precision); + op_def.src_tensors.push_back({data_type, storage, Layout::BHWC}); + op_def.dst_tensors.push_back({data_type, storage, Layout::BHWC}); + TensorFloat32 dst_tensor; + auto operation = CreateMeanStdDevNormalization( + op_def, env_.GetDevicePtr()->info_, (kVectorSize + 3) / 4); + ASSERT_OK(ExecuteGPUOperation({src_tensor}, creation_context_, &operation, + BHWC(1, 1, 1, kVectorSize), &dst_tensor)); + + float expected_output[kVectorSize]; + // First output should be 0. + expected_output[0] = 0.0; + // Rest should be alternating between ±√(N/(N-1)). + const float expected_elem = + std::sqrt(static_cast(kVectorSize) / + static_cast(kVectorSize - 1)); + for (int i = 1; i < kVectorSize - 1; i += 2) { + expected_output[i + 0] = +expected_elem; + expected_output[i + 1] = -expected_elem; + } + EXPECT_THAT(dst_tensor.data, Pointwise(FloatNear(eps), expected_output)) + << "Failed using precision " << ToString(precision); } } } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/reduce.cc b/tensorflow/lite/delegates/gpu/cl/kernels/reduce.cc new file mode 100644 index 00000000000..b24d54abbfc --- /dev/null +++ b/tensorflow/lite/delegates/gpu/cl/kernels/reduce.cc @@ -0,0 +1,103 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/cl/kernels/reduce.h" + +#include + +#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" +#include "tensorflow/lite/delegates/gpu/cl/precision.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" + +namespace tflite { +namespace gpu { +namespace cl { +namespace { +std::string GetReduceChannelsKernelCode(const OperationDef& op_def, + const OperationType& op_type) { + std::string c = GetCommonDefines(op_def.precision); + if (op_type == OperationType::REDUCE_SUM) { + c += "#define OP(a, b) ((a) + (b))\n"; + } else if (op_type == OperationType::REDUCE_PRODUCT) { + c += "#define OP(a, b) ((a) * (b))\n"; + } else if (op_type == OperationType::REDUCE_MAXIMUM) { + c += "#define OP(a, b) max(a, b)\n"; + } else if (op_type == OperationType::REDUCE_MINIMUM) { + c += "#define OP(a, b) min(a, b)\n"; + } + c += "__kernel void main_function($0) {\n"; + c += " int X = get_global_id(0);\n"; + c += " int Y = get_global_id(1);\n"; + c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height()) " + "return;\n"; + if (op_type == OperationType::REDUCE_SUM) { + c += " FLT4 reduced = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n"; + } else if (op_type == OperationType::REDUCE_PRODUCT) { + c += " FLT4 reduced = (FLT4)(1.0f, 1.0f, 1.0f, 1.0f);\n"; + } else { + c += " FLT4 V0 = args.src_tensor.Read(X, Y, 0);\n"; + c += " FLT4 reduced = (FLT4)(V0.x, V0.x, V0.x, V0.x);\n"; + } + c += " int s = 0;\n"; + c += " for (; s < args.src_tensor.Slices() - 1; ++s) {\n"; + c += " FLT4 V = args.src_tensor.Read(X, Y, s);\n"; + c += " reduced = OP(reduced, V);\n"; + c += " }\n"; + c += " FLT reduced_final = OP(OP(reduced.x, reduced.y), OP(reduced.z, " + "reduced.w));\n"; + c += " FLT last_reduce;\n"; + c += " FLT4 last_val = args.src_tensor.Read(X, Y, s);\n"; + c += " int ch_rem = args.src_tensor.Channels() % 4;\n"; + c += " if (ch_rem == 0) {\n"; + c += " last_reduce = OP(OP(last_val.x, last_val.y), OP(last_val.z, " + "last_val.w));\n"; + c += " } else if (ch_rem == 1) {\n"; + c += " last_reduce = OP(OP(last_val.x, last_val.y), last_val.z);\n"; + c += " } else if (ch_rem == 2) {\n"; + c += " last_reduce = OP(last_val.x, last_val.y);\n"; + c += " } else {\n"; + c += " last_reduce = last_val.x;\n"; + c += " }\n"; + c += " reduced_final = OP(reduced_final, last_reduce);\n"; + c += " FLT4 result = (FLT4)(reduced_final, 0.0f, 0.0f, 0.0f);\n"; + c += " args.dst_tensor.Write(result, X, Y, 0);\n"; + c += "}\n"; + return c; +} +} // namespace + +GPUOperation CreateReduce(const OperationDef& definition, + const ReduceAttributes& attr, + const OperationType& op_type) { + GPUOperation op(definition); + auto src_desc = definition.src_tensors[0]; + if (definition.IsBatchSupported()) { + src_desc.SetStateVar("BatchedWidth", "true"); + } + op.AddSrcTensor("src_tensor", src_desc); + auto dst_desc = definition.dst_tensors[0]; + 
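
The channel reductions that the new tests below verify can be mirrored by a small host-side reference; this sketch (illustrative C++ under that assumption, not part of the patch) shows the REDUCE_SUM case:

// Host-side reference for reducing the channel axis with OP(a, b) = a + b.
#include <vector>

std::vector<float> ReferenceReduceSumChannels(const std::vector<float>& src,
                                              int positions, int channels) {
  std::vector<float> dst(positions, 0.0f);
  for (int p = 0; p < positions; ++p) {
    for (int c = 0; c < channels; ++c) {
      dst[p] += src[p * channels + c];
    }
  }
  return dst;
}
// E.g. src = {1.1, 2.1, 0.7, 0.3, 1.2, 3.1, 4.1, 0.0, 1.0, 4.4} with
// positions = 2 and channels = 5 gives {5.4, 12.6}, the values expected by
// the ReduceSumChannels test.
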
if (definition.IsBatchSupported()) { + dst_desc.SetStateVar("BatchedWidth", "true"); + } + op.AddDstTensor("dst_tensor", dst_desc); + op.code_ = GetReduceChannelsKernelCode(definition, op_type); + op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_ZIs1; + return op; +} + +} // namespace cl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/reduce.h b/tensorflow/lite/delegates/gpu/cl/kernels/reduce.h new file mode 100644 index 00000000000..def7ced4871 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/cl/kernels/reduce.h @@ -0,0 +1,34 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_REDUCE_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_REDUCE_H_ + +#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" + +namespace tflite { +namespace gpu { +namespace cl { + +GPUOperation CreateReduce(const OperationDef& definition, + const ReduceAttributes& attr, + const OperationType& op_type); + +} // namespace cl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_REDUCE_H_ diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/reduce_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/reduce_test.cc new file mode 100644 index 00000000000..7f100410d3c --- /dev/null +++ b/tensorflow/lite/delegates/gpu/cl/kernels/reduce_test.cc @@ -0,0 +1,141 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/cl/kernels/reduce.h" + +#include +#include +#include + +#include +#include +#include "tensorflow/lite/delegates/gpu/cl/kernels/cl_test.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" + +using ::testing::FloatNear; +using ::testing::Pointwise; + +namespace tflite { +namespace gpu { +namespace cl { +namespace { + +TEST_F(OpenCLOperationTest, ReduceSumChannels) { + TensorFloat32 src_tensor; + src_tensor.shape = BHWC(1, 2, 1, 5); + src_tensor.data = {1.1, 2.1, 0.7, 0.3, 1.2, 3.1, 4.1, 0.0, 1.0, 4.4}; + ReduceAttributes attr; + attr.axis = Axis::CHANNELS; + + for (auto storage : env_.GetSupportedStorages()) { + for (auto precision : env_.GetSupportedPrecisions()) { + const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f; + OperationDef op_def; + op_def.precision = precision; + auto data_type = DeduceDataTypeFromPrecision(precision); + op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); + op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); + TensorFloat32 dst_tensor; + GPUOperation operation = + CreateReduce(op_def, attr, OperationType::REDUCE_SUM); + ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation, + BHWC(1, 2, 1, 1), &dst_tensor)); + EXPECT_THAT(dst_tensor.data, Pointwise(FloatNear(eps), {5.4f, 12.6f})); + } + } +} + +TEST_F(OpenCLOperationTest, ReduceProductChannels) { + TensorFloat32 src_tensor; + src_tensor.shape = BHWC(1, 2, 1, 2); + src_tensor.data = {1.1, 2.0, 3.1, 4.0}; + ReduceAttributes attr; + attr.axis = Axis::CHANNELS; + + for (auto storage : env_.GetSupportedStorages()) { + for (auto precision : env_.GetSupportedPrecisions()) { + const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f; + OperationDef op_def; + op_def.precision = precision; + auto data_type = DeduceDataTypeFromPrecision(precision); + op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); + op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); + TensorFloat32 dst_tensor; + GPUOperation operation = + CreateReduce(op_def, attr, OperationType::REDUCE_PRODUCT); + ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation, + BHWC(1, 2, 1, 1), &dst_tensor)); + EXPECT_THAT(dst_tensor.data, Pointwise(FloatNear(eps), {2.2f, 12.4f})); + } + } +} + +TEST_F(OpenCLOperationTest, ReduceMaxChannels) { + TensorFloat32 src_tensor; + src_tensor.shape = BHWC(1, 2, 1, 6); + src_tensor.data = {1.1, 2.0, -0.3, -100.0, 32.6, 1.1, + -3.1, -4.0, -5.0, -7.0, -2.0, -100.0}; + ReduceAttributes attr; + attr.axis = Axis::CHANNELS; + + for (auto storage : env_.GetSupportedStorages()) { + for (auto precision : env_.GetSupportedPrecisions()) { + const float eps = precision == CalculationsPrecision::F32 ? 
1e-6f : 1e-3f; + OperationDef op_def; + op_def.precision = precision; + auto data_type = DeduceDataTypeFromPrecision(precision); + op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); + op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); + TensorFloat32 dst_tensor; + GPUOperation operation = + CreateReduce(op_def, attr, OperationType::REDUCE_MAXIMUM); + ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation, + BHWC(1, 2, 1, 1), &dst_tensor)); + EXPECT_THAT(dst_tensor.data, Pointwise(FloatNear(eps), {32.6f, -2.0f})); + } + } +} + +TEST_F(OpenCLOperationTest, ReduceMinChannels) { + TensorFloat32 src_tensor; + src_tensor.shape = BHWC(1, 2, 1, 6); + src_tensor.data = {1.1, 2.0, -0.3, -100.0, 32.6, 1.1, + -3.1, -4.0, -5.0, -7.0, -2.0, 100.0}; + ReduceAttributes attr; + attr.axis = Axis::CHANNELS; + + for (auto storage : env_.GetSupportedStorages()) { + for (auto precision : env_.GetSupportedPrecisions()) { + const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f; + OperationDef op_def; + op_def.precision = precision; + auto data_type = DeduceDataTypeFromPrecision(precision); + op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); + op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); + TensorFloat32 dst_tensor; + GPUOperation operation = + CreateReduce(op_def, attr, OperationType::REDUCE_MINIMUM); + ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation, + BHWC(1, 2, 1, 1), &dst_tensor)); + EXPECT_THAT(dst_tensor.data, Pointwise(FloatNear(eps), {-100.0f, -7.0f})); + } + } +} + +} // namespace +} // namespace cl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/resize.cc b/tensorflow/lite/delegates/gpu/cl/kernels/resize.cc index a0fd699062c..91266ef29a6 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/resize.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/resize.cc @@ -132,13 +132,13 @@ std::string Resize::GetResizeCode(const OperationDef& op_def, return c; } -absl::Status Resize::BindArguments() { - RETURN_IF_ERROR(args_.SetInt("border_x", src_[0]->Width() - 1)); - RETURN_IF_ERROR(args_.SetInt("border_y", src_[0]->Height() - 1)); - RETURN_IF_ERROR(args_.SetFloat( +absl::Status Resize::BindArguments(ArgumentsBinder* args) { + RETURN_IF_ERROR(args->SetInt("border_x", src_[0]->Width() - 1)); + RETURN_IF_ERROR(args->SetInt("border_y", src_[0]->Height() - 1)); + RETURN_IF_ERROR(args->SetFloat( "scale_factor_x", CalculateResizeScale(src_[0]->Width(), dst_[0]->Width(), attr_))); - RETURN_IF_ERROR(args_.SetFloat( + RETURN_IF_ERROR(args->SetFloat( "scale_factor_y", CalculateResizeScale(src_[0]->Height(), dst_[0]->Height(), attr_))); return absl::OkStatus(); @@ -286,17 +286,17 @@ std::string Resize3D::GetResize3DCode(const OperationDef& op_def, return c; } -absl::Status Resize3D::BindArguments() { - RETURN_IF_ERROR(args_.SetInt("border_x", src_[0]->Width() - 1)); - RETURN_IF_ERROR(args_.SetInt("border_y", src_[0]->Height() - 1)); - RETURN_IF_ERROR(args_.SetInt("border_z", src_[0]->Depth() - 1)); - RETURN_IF_ERROR(args_.SetFloat( +absl::Status Resize3D::BindArguments(ArgumentsBinder* args) { + RETURN_IF_ERROR(args->SetInt("border_x", src_[0]->Width() - 1)); + RETURN_IF_ERROR(args->SetInt("border_y", src_[0]->Height() - 1)); + RETURN_IF_ERROR(args->SetInt("border_z", src_[0]->Depth() - 1)); + RETURN_IF_ERROR(args->SetFloat( "scale_factor_x", CalculateResizeScale(src_[0]->Width(), dst_[0]->Width(), attr_))); - RETURN_IF_ERROR(args_.SetFloat( + 
RETURN_IF_ERROR(args->SetFloat( "scale_factor_y", CalculateResizeScale(src_[0]->Height(), dst_[0]->Height(), attr_))); - RETURN_IF_ERROR(args_.SetFloat( + RETURN_IF_ERROR(args->SetFloat( "scale_factor_z", CalculateResizeScale(src_[0]->Depth(), dst_[0]->Depth(), attr_))); return absl::OkStatus(); diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/resize.h b/tensorflow/lite/delegates/gpu/cl/kernels/resize.h index 0349afe5664..859d750b7e0 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/resize.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/resize.h @@ -27,7 +27,7 @@ namespace cl { class Resize : public GPUOperation { public: - absl::Status BindArguments() override; + absl::Status BindArguments(ArgumentsBinder* args) override; int3 GetGridSize() const override; // Move only @@ -53,7 +53,7 @@ Resize CreateResize(const OperationDef& definition, class Resize3D : public GPUOperation { public: - absl::Status BindArguments() override; + absl::Status BindArguments(ArgumentsBinder* args) override; int3 GetGridSize() const override; // Move only diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc b/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc index e7cf72aa72a..d4d0442e61d 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc @@ -109,14 +109,14 @@ std::string Softmax1x1::GetSoftmaxKernelCode(const OperationDef& op_def) { return c; } -absl::Status Softmax1x1::BindArguments() { +absl::Status Softmax1x1::BindArguments(ArgumentsBinder* args) { float4 mask = GetMaskForLastPlane(src_[0]->Channels()); - RETURN_IF_ERROR(args_.SetFloat("mask_x", mask.x)); - RETURN_IF_ERROR(args_.SetFloat("mask_y", mask.y)); - RETURN_IF_ERROR(args_.SetFloat("mask_z", mask.z)); - RETURN_IF_ERROR(args_.SetFloat("mask_w", mask.w)); + RETURN_IF_ERROR(args->SetFloat("mask_x", mask.x)); + RETURN_IF_ERROR(args->SetFloat("mask_y", mask.y)); + RETURN_IF_ERROR(args->SetFloat("mask_z", mask.z)); + RETURN_IF_ERROR(args->SetFloat("mask_w", mask.w)); RETURN_IF_ERROR( - args_.SetInt("slices_x32", DivideRoundUp(src_[0]->Slices(), 32))); + args->SetInt("slices_x32", DivideRoundUp(src_[0]->Slices(), 32))); return absl::OkStatus(); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.h b/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.h index 5bc9278d612..202f46d2a51 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.h @@ -35,7 +35,7 @@ class Softmax1x1 : public GPUOperation { std::vector* work_groups) const override { work_groups->push_back(work_group_size_); } - absl::Status BindArguments() override; + absl::Status BindArguments(ArgumentsBinder* args) override; int3 GetGridSize() const override; // Move only diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/special/BUILD b/tensorflow/lite/delegates/gpu/cl/kernels/special/BUILD index d5ff93e6845..f601556900c 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/special/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/kernels/special/BUILD @@ -23,3 +23,30 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:types", ], ) + +cc_library( + name = "fc_fc_add", + srcs = ["fc_fc_add.cc"], + hdrs = ["fc_fc_add.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/cl:arguments", + "//tensorflow/lite/delegates/gpu/cl:buffer", + "//tensorflow/lite/delegates/gpu/cl:cl_kernel", + "//tensorflow/lite/delegates/gpu/cl:device_info", + "//tensorflow/lite/delegates/gpu/cl:linear_storage", + 
"//tensorflow/lite/delegates/gpu/cl:precision", + "//tensorflow/lite/delegates/gpu/cl:tensor", + "//tensorflow/lite/delegates/gpu/cl:tensor_type", + "//tensorflow/lite/delegates/gpu/cl:texture2d", + "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation", + "//tensorflow/lite/delegates/gpu/cl/kernels:tuning_parameters", + "//tensorflow/lite/delegates/gpu/cl/kernels:util", + "//tensorflow/lite/delegates/gpu/common:data_type", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:tensor", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + "@com_google_absl//absl/memory", + ], +) diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.cc b/tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.cc new file mode 100644 index 00000000000..a8d3d434bd9 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.cc @@ -0,0 +1,207 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.h" + +#include +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/lite/delegates/gpu/cl/arguments.h" +#include "tensorflow/lite/delegates/gpu/cl/device_info.h" +#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" +#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" +#include "tensorflow/lite/delegates/gpu/cl/linear_storage.h" +#include "tensorflow/lite/delegates/gpu/cl/precision.h" +#include "tensorflow/lite/delegates/gpu/cl/tensor.h" +#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" + +namespace tflite { +namespace gpu { +namespace cl { +namespace { +bool UseBufferForWeights(const DeviceInfo& device_info) { + return device_info.IsAdreno() || device_info.IsAMD() || device_info.IsMali(); +} +} // namespace + +FCFCAdd::FCFCAdd(const OperationDef& definition, const DeviceInfo& device_info) + : GPUOperation(definition) { + if (device_info.IsAdreno()) { + if (device_info.IsAdreno3xx()) { + work_group_size_ = int3(16, 4, 1); + } else if (device_info.IsAdreno4xx()) { + work_group_size_ = int3(32, 4, 1); + } else { + work_group_size_ = int3(32, 4, 1); + } + } else if (device_info.IsIntel()) { + work_group_size_ = int3(8, 4, 1); + } else if (device_info.IsNvidia()) { + work_group_size_ = int3(8, 4, 1); + } else if (device_info.IsPowerVR()) { + work_group_size_ = int3(8, 4, 1); + } else { + work_group_size_ = int3(16, 4, 1); + } + code_ = GetFCFCAddKernelCode(definition_, device_info); +} + +FCFCAdd::FCFCAdd(FCFCAdd&& kernel) : GPUOperation(std::move(kernel)) {} + +FCFCAdd& FCFCAdd::operator=(FCFCAdd&& kernel) { + if (this != &kernel) { + GPUOperation::operator=(std::move(kernel)); + } + 
return *this; +} + +// We split vec vec dot (every thread do vec vec dot product in basic +// vec mat mult) on 4 parts to create more threads +// tid.y thread process every 4-th element in vec vec dot +// Good results for ~1024 x 1024 sizes, for other can be written more +// optimized shaders + +std::string FCFCAdd::GetFCFCAddKernelCode(const OperationDef& op_def, + const DeviceInfo& device_info) { + AddSrcTensor("src_tensor_0", op_def.src_tensors[0]); + AddSrcTensor("src_tensor_1", op_def.src_tensors[1]); + AddDstTensor("dst_tensor", op_def.dst_tensors[0]); + + const bool weights_are_buffer = UseBufferForWeights(device_info); + + std::string c = GetCommonDefines(op_def.precision); + switch (op_def.precision) { + case CalculationsPrecision::F32: + c += "#define FLT16 float16\n"; + break; + case CalculationsPrecision::F32_F16: + case CalculationsPrecision::F16: + c += "#define FLT16 half16\n"; + break; + } + + c += "#define WG_X " + std::to_string(work_group_size_.x) + "\n"; + c += "#define WG_Y " + std::to_string(work_group_size_.y) + "\n"; + + c += R"(__kernel void main_function($0) { + int gid = get_global_id(0); + int2 tid = (int2)(get_local_id(0), get_local_id(1)); + ACCUM_FLT4 s = (ACCUM_FLT4)(0.0f); + if (gid < args.dst_tensor.Slices()) { + for (int c = tid.y; c < args.src_tensor_0.Slices(); c += WG_Y) { + FLT4 v = args.src_tensor_0.Read(0, 0, c); +)"; + if (weights_are_buffer) { + c += R"(FLT16 w = args.weights0.Read(c * args.dst_tensor.Slices() + gid); + FLT4 partial = v.s0 * w.s0123; + partial = mad(v.s1, w.s4567, partial); + partial = mad(v.s2, w.s89ab, partial); + partial = mad(v.s3, w.scdef, partial); + s += TO_ACCUM_TYPE(partial); +)"; + } else { + c += R"(FLT4 w0 = args.weights0.Read(c * 4 + 0, gid); + FLT4 w1 = args.weights0.Read(c * 4 + 1, gid); + FLT4 w2 = args.weights0.Read(c * 4 + 2, gid); + FLT4 w3 = args.weights0.Read(c * 4 + 3, gid); + FLT4 partial = v.s0 * w0; + partial = mad(v.s1, w1, partial); + partial = mad(v.s2, w2, partial); + partial = mad(v.s3, w3, partial); + s += TO_ACCUM_TYPE(partial); +)"; + } + c += R"( } + for (int c = tid.y; c < args.src_tensor_1.Slices(); c += WG_Y) { + FLT4 v = args.src_tensor_1.Read(0, 0, c); + )"; + if (weights_are_buffer) { + c += R"(FLT16 w = args.weights1.Read(c * args.dst_tensor.Slices() + gid); + FLT4 partial = v.s0 * w.s0123; + partial = mad(v.s1, w.s4567, partial); + partial = mad(v.s2, w.s89ab, partial); + partial = mad(v.s3, w.scdef, partial); + s += TO_ACCUM_TYPE(partial); +)"; + } else { + c += R"(FLT4 w0 = args.weights1.Read(c * 4 + 0, gid); + FLT4 w1 = args.weights1.Read(c * 4 + 1, gid); + FLT4 w2 = args.weights1.Read(c * 4 + 2, gid); + FLT4 w3 = args.weights1.Read(c * 4 + 3, gid); + FLT4 partial = v.s0 * w0; + partial = mad(v.s1, w1, partial); + partial = mad(v.s2, w2, partial); + partial = mad(v.s3, w3, partial); + s += TO_ACCUM_TYPE(partial); +)"; + } + c += R"( } + } + __local ACCUM_FLT4 temp[WG_X][WG_Y]; + temp[tid.x][tid.y] = s; + barrier(CLK_LOCAL_MEM_FENCE); + if (gid >= args.dst_tensor.Slices()) { + return; + } + if (tid.y == 0) { +)"; + for (int i = 1; i < work_group_size_.y; ++i) { + c += " s += temp[tid.x][" + std::to_string(i) + "];\n"; + } + c += + R"( FLT4 r0 = TO_FLT4(s) + args.biases0.Read(gid) + args.biases1.Read(gid); + args.dst_tensor.Write(r0, 0, 0, gid); + } +})"; + + return c; +} + +int3 FCFCAdd::GetGridSize() const { return int3(dst_[0]->Slices(), 1, 1); } + +FCFCAdd CreateFCFCAdd(const DeviceInfo& device_info, + const OperationDef& definition, + const FullyConnectedAttributes& attr0, + const 
FullyConnectedAttributes& attr1) { + FCFCAdd result(definition, device_info); + result.UploadWeights(attr0.weights, "weights0", + UseBufferForWeights(device_info)); + result.UploadWeights(attr1.weights, "weights1", + UseBufferForWeights(device_info)); + + TensorLinearDescriptor desc0; + desc0.storage_type = LinearStorageType::TEXTURE_2D; + desc0.element_type = definition.GetDataType(); + desc0.UploadLinearData(attr0.bias); + result.args_.AddObject( + "biases0", absl::make_unique(std::move(desc0))); + + TensorLinearDescriptor desc1; + desc1.storage_type = LinearStorageType::TEXTURE_2D; + desc1.element_type = definition.GetDataType(); + desc1.UploadLinearData(attr1.bias); + result.args_.AddObject( + "biases1", absl::make_unique(std::move(desc1))); + + return result; +} + +} // namespace cl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.h b/tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.h new file mode 100644 index 00000000000..fea9d1a4990 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.h @@ -0,0 +1,189 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_SPECIAL_FC_FC_ADD_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_SPECIAL_FC_FC_ADD_H_ + +#include + +#include +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/lite/delegates/gpu/cl/arguments.h" +#include "tensorflow/lite/delegates/gpu/cl/buffer.h" +#include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h" +#include "tensorflow/lite/delegates/gpu/cl/device_info.h" +#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" +#include "tensorflow/lite/delegates/gpu/cl/kernels/tuning_parameters.h" +#include "tensorflow/lite/delegates/gpu/cl/precision.h" +#include "tensorflow/lite/delegates/gpu/cl/texture2d.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" + +namespace tflite { +namespace gpu { +namespace cl { + +template +void RearrangeFCWeightsToIOO4I4(const tflite::gpu::Tensor& weights, + S* dst) { + const int src_channels = weights.shape.i; + const int padded_src_channels = AlignByN(src_channels, 4); + const int dst_channels = weights.shape.o; + const int padded_dst_channels = AlignByN(dst_channels, 4); + + for (int block_y = 0; 4 * block_y < padded_dst_channels; block_y++) { + for (int y_in_block = 0; y_in_block < 4; y_in_block++) { + for (int block_x = 0; 4 * block_x < padded_src_channels; block_x++) { + for (int x_in_block = 0; x_in_block < 4; x_in_block++) { + int y = 4 * block_y + y_in_block; + int x = 4 * block_x + 
x_in_block; + int dst_index = block_x * padded_dst_channels * 4 + block_y * 16 + + x_in_block * 4 + y_in_block; + if (x < src_channels && y < dst_channels) { + dst[dst_index] = weights.data[src_channels * y + x]; + } else { + dst[dst_index] = 0.0f; + } + } + } + } + } +} + +template +void RearrangeFCWeightsToOIO4I4(const tflite::gpu::Tensor& weights, + S* dst) { + const int src_channels = weights.shape.i; + const int src_depth = DivideRoundUp(src_channels, 4); + const int dst_channels = weights.shape.o; + const int dst_depth = DivideRoundUp(dst_channels, 4); + + int counter = 0; + for (int d = 0; d < dst_depth; ++d) { + for (int s = 0; s < src_depth; ++s) { + for (int i = 0; i < 4; ++i) { + const int src_ch = s * 4 + i; + for (int j = 0; j < 4; ++j) { + const int dst_ch = d * 4 + j; + if (src_ch < src_channels && dst_ch < dst_channels) { + dst[counter++] = weights.data[dst_ch * src_channels + src_ch]; + } else { + dst[counter++] = 0.0f; + } + } + } + } + } +} + +class FCFCAdd : public GPUOperation { + public: + FCFCAdd() = default; + void GetPossibleKernelWorkGroups( + TuningType tuning_type, const DeviceInfo& device_info, + const KernelInfo& kernel_info, + std::vector* work_groups) const override { + work_groups->push_back(work_group_size_); + } + int3 GetGridSize() const override; + + // Move only + FCFCAdd(FCFCAdd&& kernel); + FCFCAdd& operator=(FCFCAdd&& kernel); + FCFCAdd(const FCFCAdd&) = delete; + FCFCAdd& operator=(const FCFCAdd&) = delete; + + private: + FCFCAdd(const OperationDef& definition, const DeviceInfo& device_info); + friend FCFCAdd CreateFCFCAdd(const DeviceInfo& device_info, + const OperationDef& definition, + const FullyConnectedAttributes& attr0, + const FullyConnectedAttributes& attr1); + + template + void UploadWeights(const tflite::gpu::Tensor& weights, + const std::string& name, bool weights_are_buffer); + + std::string GetFCFCAddKernelCode(const OperationDef& op_def, + const DeviceInfo& device_info); +}; + +template +void FCFCAdd::UploadWeights(const tflite::gpu::Tensor& weights, + const std::string& name, bool weights_are_buffer) { + const int src_depth = DivideRoundUp(weights.shape.i, 4); + const int dst_depth = DivideRoundUp(weights.shape.o, 4); + + const int elements_count = src_depth * dst_depth * 4; + const bool f32_weights = definition_.precision == CalculationsPrecision::F32; + + const int float4_size = f32_weights ? 16 : 8; + + if (weights_are_buffer) { + BufferDescriptor desc; + desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16; + desc.element_size = 16; + desc.size = float4_size * elements_count; + desc.data.resize(desc.size); + + if (f32_weights) { + float* ptr = reinterpret_cast(desc.data.data()); + RearrangeFCWeightsToIOO4I4(weights, ptr); + } else { + half* ptr = reinterpret_cast(desc.data.data()); + RearrangeFCWeightsToIOO4I4(weights, ptr); + } + + args_.AddObject(name, absl::make_unique(std::move(desc))); + } else { + Texture2DDescriptor desc; + desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16; + // desc.element_type = DataType::UINT8; + // desc.normalized = true; + // desc.normalized_type = f32_weights ? 
DataType::FLOAT32 : + // DataType::FLOAT16; + desc.size = int2(src_depth * 4, dst_depth); + desc.data.resize(float4_size * elements_count); + + if (f32_weights) { + float* ptr = reinterpret_cast(desc.data.data()); + RearrangeFCWeightsToOIO4I4(weights, ptr); + } else { + half* ptr = reinterpret_cast(desc.data.data()); + RearrangeFCWeightsToOIO4I4(weights, ptr); + } + + args_.AddObject(name, + absl::make_unique(std::move(desc))); + } +} + +FCFCAdd CreateFCFCAdd(const DeviceInfo& device_info, + const OperationDef& definition, + const FullyConnectedAttributes& attr0, + const FullyConnectedAttributes& attr1); + +} // namespace cl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_SPECIAL_FC_FC_ADD_H_ diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.cc b/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.cc index b2ce0690a9c..1f8f985f3ee 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.cc @@ -154,17 +154,17 @@ std::string StridedSlice::GetStridedSliceCode(const OperationDef& op_def, return c; } -absl::Status StridedSlice::BindArguments() { +absl::Status StridedSlice::BindArguments(ArgumentsBinder* args) { int4 offset = GetOffset(attributes_, src_[0]->Width(), src_[0]->Height(), src_[0]->Channels(), src_[0]->Batch()); - RETURN_IF_ERROR(args_.SetInt("offset_x", offset.x)); - RETURN_IF_ERROR(args_.SetInt("offset_y", offset.y)); - RETURN_IF_ERROR(args_.SetInt("offset_z", offset.z)); - RETURN_IF_ERROR(args_.SetInt("offset_b", offset.w)); - RETURN_IF_ERROR(args_.SetInt("stride_x", attributes_.strides.w)); - RETURN_IF_ERROR(args_.SetInt("stride_y", attributes_.strides.h)); - RETURN_IF_ERROR(args_.SetInt("stride_z", attributes_.strides.c)); - RETURN_IF_ERROR(args_.SetInt("stride_b", attributes_.strides.b)); + RETURN_IF_ERROR(args->SetInt("offset_x", offset.x)); + RETURN_IF_ERROR(args->SetInt("offset_y", offset.y)); + RETURN_IF_ERROR(args->SetInt("offset_z", offset.z)); + RETURN_IF_ERROR(args->SetInt("offset_b", offset.w)); + RETURN_IF_ERROR(args->SetInt("stride_x", attributes_.strides.w)); + RETURN_IF_ERROR(args->SetInt("stride_y", attributes_.strides.h)); + RETURN_IF_ERROR(args->SetInt("stride_z", attributes_.strides.c)); + RETURN_IF_ERROR(args->SetInt("stride_b", attributes_.strides.b)); return absl::OkStatus(); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.h b/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.h index 5a6d8ad6047..dddff2faf35 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.h @@ -27,7 +27,7 @@ namespace cl { class StridedSlice : public GPUOperation { public: StridedSlice(const OperationDef& definition, const SliceAttributes& attr); - absl::Status BindArguments() override; + absl::Status BindArguments(ArgumentsBinder* args) override; int3 GetGridSize() const override; // Move only diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/tuning_parameters.h b/tensorflow/lite/delegates/gpu/cl/kernels/tuning_parameters.h index d6098b0cb81..c57ccade4b2 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/tuning_parameters.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/tuning_parameters.h @@ -17,7 +17,7 @@ limitations under the License. 
#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_TUNING_PARAMETERS_H_ #include "tensorflow/lite/delegates/gpu/cl/cl_command_queue.h" -#include "tensorflow/lite/delegates/gpu/cl/cl_device.h" +#include "tensorflow/lite/delegates/gpu/cl/device_info.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/util.cc b/tensorflow/lite/delegates/gpu/cl/kernels/util.cc index f0e0c412b7e..25fa60c776a 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/util.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/util.cc @@ -188,6 +188,14 @@ int GetRecommendedBlockSizeForConv(const DeviceInfo& device_info, return block_size; } +int3 GetWorkGroupsCount(const int3& grid_size, const int3& work_group_size) { + int3 work_groups_count; + work_groups_count.x = DivideRoundUp(grid_size.x, work_group_size.x); + work_groups_count.y = DivideRoundUp(grid_size.y, work_group_size.y); + work_groups_count.z = DivideRoundUp(grid_size.z, work_group_size.z); + return work_groups_count; +} + } // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/util.h b/tensorflow/lite/delegates/gpu/cl/kernels/util.h index aa9f599e4d8..69f6808146c 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/util.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/util.h @@ -17,15 +17,13 @@ limitations under the License. #define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_UTIL_H_ #include +#include #include "absl/types/span.h" #include "tensorflow/lite/delegates/gpu/cl/device_info.h" #include "tensorflow/lite/delegates/gpu/cl/precision.h" -#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h" -#include "tensorflow/lite/delegates/gpu/common/access_type.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" -#include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" #include "tensorflow/lite/delegates/gpu/common/types.h" #include "tensorflow/lite/delegates/gpu/common/util.h" @@ -58,15 +56,12 @@ void RearrangeWeightsToOHWIOGroupI4O4( absl::Span dst) { const int dst_slices = DivideRoundUp(weights.shape.o, 4); const int src_slices = DivideRoundUp(weights.shape.i, 4); - const int kernel_x = weights.shape.w; - const int kernel_y = weights.shape.h; - const int dst_groups = DivideRoundUp(dst_slices, out_group_size); int counter = 0; for (int d = 0; d < dst_groups; ++d) { - for (int y = 0; y < kernel_y; ++y) { - for (int x = 0; x < kernel_x; ++x) { + for (int y = 0; y < weights.shape.h; ++y) { + for (int x = 0; x < weights.shape.w; ++x) { for (int s = 0; s < src_slices; ++s) { for (int d_group = 0; d_group < out_group_size; ++d_group) { for (int j = 0; j < 4; ++j) { @@ -91,6 +86,118 @@ void RearrangeWeightsToOHWIOGroupI4O4( } } +template +void RearrangeWeightsToODHWIOGroupI4O4( + const tflite::gpu::Tensor& weights, int out_group_size, + absl::Span dst) { + const int dst_slices = DivideRoundUp(weights.shape.o, 4); + const int src_slices = DivideRoundUp(weights.shape.i, 4); + const int dst_groups = DivideRoundUp(dst_slices, out_group_size); + + int counter = 0; + for (int d = 0; d < dst_groups; ++d) { + for (int z = 0; z < weights.shape.d; ++z) { + for (int y = 0; y < weights.shape.h; ++y) { + for (int x = 0; x < weights.shape.w; ++x) { + for (int s = 0; s < src_slices; ++s) { + for (int d_group = 0; d_group < out_group_size; ++d_group) { + for (int j = 0; j < 4; ++j) { + T filter; + for (int i = 0; i < 4; ++i) { + const int s_ch = s * 4 + j; + const int 
d_ch = (d * out_group_size + d_group) * 4 + i; + if (s_ch < weights.shape.i && d_ch < weights.shape.o) { + const int f_index = + weights.shape.LinearIndex({d_ch, y, x, z, s_ch}); + filter[i] = weights.data[f_index]; + } else { + filter[i] = 0.0f; + } + } + dst[counter++] = filter; + } + } + } + } + } + } + } +} + +template +void RearrangeWeightsToI4HWIOOGroupO4( + const tflite::gpu::Tensor& weights, int out_group_size, + absl::Span dst) { + const int dst_slices = DivideRoundUp(weights.shape.o, 4); + const int src_slices = DivideRoundUp(weights.shape.i, 4); + const int dst_groups = DivideRoundUp(dst_slices, out_group_size); + + int counter = 0; + for (int j = 0; j < 4; ++j) { + for (int y = 0; y < weights.shape.h; ++y) { + for (int x = 0; x < weights.shape.w; ++x) { + for (int s = 0; s < src_slices; ++s) { + for (int d = 0; d < dst_groups; ++d) { + for (int d_group = 0; d_group < out_group_size; ++d_group) { + T filter; + for (int i = 0; i < 4; ++i) { + const int s_ch = s * 4 + j; + const int d_ch = (d * out_group_size + d_group) * 4 + i; + if (s_ch < weights.shape.i && d_ch < weights.shape.o) { + const int f_index = + weights.shape.LinearIndex({d_ch, y, x, s_ch}); + filter[i] = weights.data[f_index]; + } else { + filter[i] = 0.0f; + } + } + dst[counter++] = filter; + } + } + } + } + } + } +} + +template +void RearrangeWeightsToI4DHWIOOGroupO4( + const tflite::gpu::Tensor& weights, int out_group_size, + absl::Span dst) { + const int dst_slices = DivideRoundUp(weights.shape.o, 4); + const int src_slices = DivideRoundUp(weights.shape.i, 4); + const int dst_groups = DivideRoundUp(dst_slices, out_group_size); + + int counter = 0; + for (int j = 0; j < 4; ++j) { + for (int z = 0; z < weights.shape.d; ++z) { + for (int y = 0; y < weights.shape.h; ++y) { + for (int x = 0; x < weights.shape.w; ++x) { + for (int s = 0; s < src_slices; ++s) { + for (int d = 0; d < dst_groups; ++d) { + for (int d_group = 0; d_group < out_group_size; ++d_group) { + T filter; + for (int i = 0; i < 4; ++i) { + const int s_ch = s * 4 + j; + const int d_ch = (d * out_group_size + d_group) * 4 + i; + if (s_ch < weights.shape.i && d_ch < weights.shape.o) { + const int f_index = + weights.shape.LinearIndex({d_ch, y, x, z, s_ch}); + filter[i] = weights.data[f_index]; + } else { + filter[i] = 0.0f; + } + } + dst[counter++] = filter; + } + } + } + } + } + } + } +} + // Returns float4 mask for last plane(batch of 4 channels) // assumes that plane size is 4; // for example we have 7 channels, in our data structures we align it to 8 @@ -106,6 +213,8 @@ int3 GetFirstSuitableWorkGroup(const std::vector& wgs, int max_wg_size); int GetRecommendedBlockSizeForConv(const DeviceInfo& device, CalculationsPrecision precision, int task_size); + +int3 GetWorkGroupsCount(const int3& grid_size, const int3& work_group_size); } // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/winograd.cc b/tensorflow/lite/delegates/gpu/cl/kernels/winograd.cc index 0f94847f08a..1244f769b48 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/winograd.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/winograd.cc @@ -262,16 +262,16 @@ int3 Winograd4x4To36::SelectBestWorkGroup(const KernelInfo& kernel_info) const { return GetFirstSuitableWorkGroup(wgs, kernel_info.max_work_group_size); } -absl::Status Winograd4x4To36::BindArguments() { +absl::Status Winograd4x4To36::BindArguments(ArgumentsBinder* args) { const int tiles_x = DivideRoundUp( src_[0]->Width() + padding_.prepended.w + padding_.appended.w - 2, 
4); const int tiles_y = DivideRoundUp( src_[0]->Height() + padding_.prepended.h + padding_.appended.h - 2, 4); const int tiles_total = tiles_x * tiles_y; - RETURN_IF_ERROR(args_.SetInt("padding_x", -padding_.prepended.w)); - RETURN_IF_ERROR(args_.SetInt("padding_y", -padding_.prepended.h)); - RETURN_IF_ERROR(args_.SetInt("tiles_total", tiles_total)); - RETURN_IF_ERROR(args_.SetInt("tiles_x", tiles_x)); + RETURN_IF_ERROR(args->SetInt("padding_x", -padding_.prepended.w)); + RETURN_IF_ERROR(args->SetInt("padding_y", -padding_.prepended.h)); + RETURN_IF_ERROR(args->SetInt("tiles_total", tiles_total)); + RETURN_IF_ERROR(args->SetInt("tiles_x", tiles_x)); return absl::OkStatus(); } @@ -463,9 +463,9 @@ int3 Winograd36To4x4::SelectBestWorkGroup(const KernelInfo& kernel_info) const { return GetFirstSuitableWorkGroup(wgs, kernel_info.max_work_group_size); } -absl::Status Winograd36To4x4::BindArguments() { +absl::Status Winograd36To4x4::BindArguments(ArgumentsBinder* args) { const int tiles_x = DivideRoundUp(dst_[0]->Width(), 4); - RETURN_IF_ERROR(args_.SetInt("tiles_x", tiles_x)); + RETURN_IF_ERROR(args->SetInt("tiles_x", tiles_x)); return absl::OkStatus(); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/winograd.h b/tensorflow/lite/delegates/gpu/cl/kernels/winograd.h index a5da49e7939..609e38a4c9a 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/winograd.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/winograd.h @@ -36,7 +36,7 @@ class Winograd4x4To36 : public GPUOperation { Winograd4x4To36() = default; Winograd4x4To36(const OperationDef& definition, const Padding2D& padding, const DeviceInfo& device_info); - absl::Status BindArguments() override; + absl::Status BindArguments(ArgumentsBinder* args) override; int3 GetGridSize() const override; void GetPossibleKernelWorkGroups( TuningType tuning_type, const DeviceInfo& device_info, @@ -73,7 +73,7 @@ class Winograd36To4x4 : public GPUOperation { Winograd36To4x4() = default; Winograd36To4x4(const OperationDef& definition, const DeviceInfo& device_info); - absl::Status BindArguments() override; + absl::Status BindArguments(ArgumentsBinder* args) override; int3 GetGridSize() const override; void GetPossibleKernelWorkGroups( TuningType tuning_type, const DeviceInfo& device_info, diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h b/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h index 0c1be10782e..ea58ff25bc2 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h @@ -16,10 +16,11 @@ limitations under the License. 
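A quick illustration of the GetWorkGroupsCount helper added to util.cc above; the grid and work-group sizes are made up, but the rounding follows DivideRoundUp exactly as in the hunk:

// DivideRoundUp(7, 4) == 2, DivideRoundUp(5, 4) == 2, DivideRoundUp(3, 1) == 3,
// so a 7x5x3 grid launched with 4x4x1 work groups needs 2x2x3 work groups.
const int3 grid_size(7, 5, 3);
const int3 work_group_size(4, 4, 1);
const int3 groups = GetWorkGroupsCount(grid_size, work_group_size);  // (2, 2, 3)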
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_WORK_GROUP_PICKING_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_WORK_GROUP_PICKING_H_ -#include "tensorflow/lite/delegates/gpu/cl/cl_command_queue.h" +#include + #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h" +#include "tensorflow/lite/delegates/gpu/cl/device_info.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/tuning_parameters.h" -#include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/types.h" #include "tensorflow/lite/delegates/gpu/common/workgroup_selection.h" diff --git a/tensorflow/lite/delegates/gpu/cl/linear_storage.cc b/tensorflow/lite/delegates/gpu/cl/linear_storage.cc index 75920f4f8c5..8f7b314b707 100644 --- a/tensorflow/lite/delegates/gpu/cl/linear_storage.cc +++ b/tensorflow/lite/delegates/gpu/cl/linear_storage.cc @@ -204,8 +204,9 @@ absl::Status LinearStorage::CreateFromTensorLinearDescriptor( return CreateCLBuffer(context->context(), depth_ * float4_size, read_only, data_ptr, &memory_); } else { - return CreateFloatRGBAImage2D(context->context(), depth_, 1, - desc.element_type, data_ptr, &memory_); + return CreateRGBAImage2D(context->context(), depth_, 1, + DataTypeToChannelType(desc.element_type), data_ptr, + &memory_); } } diff --git a/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.cc b/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.cc index bf2fd449291..add0e2fd4e9 100644 --- a/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.cc +++ b/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.cc @@ -36,7 +36,7 @@ namespace cl { #ifdef __ANDROID__ #define LoadFunction(function) \ - if (is_pixel) { \ + if (use_wrapper) { \ function = reinterpret_cast(loadOpenCLPointer(#function)); \ } else { \ function = reinterpret_cast(dlsym(libopencl, #function)); \ @@ -53,7 +53,7 @@ namespace cl { #ifdef __WINDOWS__ void LoadOpenCLFunctions(HMODULE libopencl); #else -void LoadOpenCLFunctions(void* libopencl, bool is_pixel); +void LoadOpenCLFunctions(void* libopencl, bool use_wrapper); #endif absl::Status LoadOpenCL() { @@ -77,8 +77,11 @@ absl::Status LoadOpenCL() { // record error std::string error(dlerror()); #ifdef __ANDROID__ - // Pixel phone? + // Pixel phone or auto? 
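For readers skimming the opencl_wrapper.cc change that continues below: the loader keeps trying vendor wrapper libraries when the system OpenCL library is missing, and the old is_pixel flag becomes use_wrapper because the wrapper path now also covers Android Automotive. A hedged paraphrase of the resulting load order (the initial libOpenCL.so attempt is assumed from context; this is not the literal patch code):

void* libopencl = dlopen("libOpenCL.so", RTLD_NOW | RTLD_LOCAL);
bool use_wrapper = false;
if (!libopencl) {
  // Pixel phones and Android Automotive expose OpenCL via wrapper libraries.
  libopencl = dlopen("libOpenCL-pixel.so", RTLD_NOW | RTLD_LOCAL);
  if (!libopencl) {
    libopencl = dlopen("libOpenCL-car.so", RTLD_NOW | RTLD_LOCAL);
  }
  use_wrapper = (libopencl != nullptr);
}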
libopencl = dlopen("libOpenCL-pixel.so", RTLD_NOW | RTLD_LOCAL); + if (!libopencl) { + libopencl = dlopen("libOpenCL-car.so", RTLD_NOW | RTLD_LOCAL); + } if (libopencl) { typedef void (*enableOpenCL_t)(); enableOpenCL_t enableOpenCL = @@ -96,11 +99,11 @@ absl::Status LoadOpenCL() { #ifdef __WINDOWS__ void LoadOpenCLFunctions(HMODULE libopencl) { #else -void LoadOpenCLFunctions(void* libopencl, bool is_pixel) { +void LoadOpenCLFunctions(void* libopencl, bool use_wrapper) { #ifdef __ANDROID__ typedef void* (*loadOpenCLPointer_t)(const char* name); loadOpenCLPointer_t loadOpenCLPointer; - if (is_pixel) { + if (use_wrapper) { loadOpenCLPointer = reinterpret_cast( dlsym(libopencl, "loadOpenCLPointer")); } diff --git a/tensorflow/lite/delegates/gpu/cl/run_tests.sh b/tensorflow/lite/delegates/gpu/cl/run_tests.sh index 16d2feb8a5a..0eed264a06f 100755 --- a/tensorflow/lite/delegates/gpu/cl/run_tests.sh +++ b/tensorflow/lite/delegates/gpu/cl/run_tests.sh @@ -64,11 +64,17 @@ trap "cleanup_device" EXIT declare -a BUILD_CONFIG abi_version=$(ADB shell getprop ro.product.cpu.abi | tr -d '\r') if [[ "$abi_version" == "armeabi-v7a" ]]; then -#"32 bit" +#"32 bit ARM" BUILD_CONFIG=( --config=android_arm -c opt --copt=-fPIE --linkopt=-pie ) -else -#"64 bit" +elif [[ "$abi_version" == "arm64-v8a" ]]; then +#"64 bit ARM" BUILD_CONFIG=( --config=android_arm64 -c opt ) +elif [[ "$abi_version" == "x86_64" ]]; then +# x86_64 +BUILD_CONFIG=( --config=android_x86_64 -c opt ) +else +echo "Error: Unknown processor ABI" +exit 1 fi targets=($(bazel query 'tests('$test_target')')) diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/BUILD b/tensorflow/lite/delegates/gpu/cl/selectors/BUILD index 3e2b8855af9..8a22741f013 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/selectors/BUILD @@ -14,7 +14,6 @@ cc_library( "//tensorflow/lite/delegates/gpu/cl/kernels:conv_common", "//tensorflow/lite/delegates/gpu/cl/kernels:conv_constants", "//tensorflow/lite/delegates/gpu/cl/kernels:conv_powervr", - "//tensorflow/lite/delegates/gpu/cl/kernels:conv_texture", "//tensorflow/lite/delegates/gpu/cl/kernels:conv_weights_converter", "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation", "//tensorflow/lite/delegates/gpu/cl/kernels:work_group_picking", @@ -82,7 +81,6 @@ cc_library( deps = [ "//tensorflow/lite/delegates/gpu/cl/kernels:conv_buffer_1x1", "//tensorflow/lite/delegates/gpu/cl/kernels:conv_powervr", - "//tensorflow/lite/delegates/gpu/cl/kernels:conv_texture", "//tensorflow/lite/delegates/gpu/cl/kernels:fully_connected", "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation", "//tensorflow/lite/delegates/gpu/common:operations", @@ -110,6 +108,8 @@ cc_library( "//tensorflow/lite/delegates/gpu/cl/kernels:elementwise", "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation", "//tensorflow/lite/delegates/gpu/cl/kernels:mean_stddev_normalization", + "//tensorflow/lite/delegates/gpu/cl/kernels:reduce", + "//tensorflow/lite/delegates/gpu/cl/kernels:transpose", "//tensorflow/lite/delegates/gpu/cl/selectors:default_selector", "//tensorflow/lite/delegates/gpu/common:data_type", "//tensorflow/lite/delegates/gpu/common:model", @@ -130,6 +130,7 @@ cc_library( "//tensorflow/lite/delegates/gpu/cl/kernels:add", "//tensorflow/lite/delegates/gpu/cl/kernels:concat_xy", "//tensorflow/lite/delegates/gpu/cl/kernels:concat_z", + "//tensorflow/lite/delegates/gpu/cl/kernels:depthwise_conv", "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation", 
"//tensorflow/lite/delegates/gpu/cl/kernels:lstm", "//tensorflow/lite/delegates/gpu/cl/kernels:max_unpooling", @@ -165,6 +166,7 @@ cc_library( "//tensorflow/lite/delegates/gpu/cl:tensor_type", "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation", "//tensorflow/lite/delegates/gpu/cl/kernels/special:depthwise_conv_plus_1x1_conv", + "//tensorflow/lite/delegates/gpu/cl/kernels/special:fc_fc_add", "//tensorflow/lite/delegates/gpu/common:data_type", "//tensorflow/lite/delegates/gpu/common:model", "//tensorflow/lite/delegates/gpu/common:operations", diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc index eab957e28a6..a3282f05200 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" #include "tensorflow/lite/delegates/gpu/cl/tensor_type.h" @@ -35,11 +34,11 @@ std::unique_ptr SelectConvolutionAdreno( const DeviceInfo& device_info, const OperationDef& op_def, ModelHints hints) { if (IsConvConstantsSupported(device_info, op_def, attr)) { - ConvConstants conv = CreateConvConstants(device_info, op_def, attr); - return absl::make_unique(std::move(conv)); + GPUOperation conv = CreateConvConstants(device_info, op_def, attr); + return absl::make_unique(std::move(conv)); } else { - ConvTexture conv = CreateConvTexture(device_info, op_def, attr); - return absl::make_unique(std::move(conv)); + ConvPowerVR conv = CreateConvPowerVR(device_info, op_def, attr, &dst_shape); + return absl::make_unique(std::move(conv)); } } @@ -47,8 +46,9 @@ std::unique_ptr SelectConvolutionWinogradAdreno( const Convolution2DAttributes& attr, const BHWC& dst_shape, const DeviceInfo& device_info, const OperationDef& op_def, ModelHints hints) { - ConvTexture conv = CreateConvTextureWino4x4To6x6(device_info, op_def, attr); - return absl::make_unique(std::move(conv)); + ConvPowerVR conv = + CreateConvPowerVRWino4x4To6x6(device_info, op_def, attr, &dst_shape); + return absl::make_unique(std::move(conv)); } std::unique_ptr SelectConvolutionDynamicWeightsAdreno( @@ -66,8 +66,8 @@ std::unique_ptr SelectConvolutionNVidia( const Convolution2DAttributes& attr, const BHWC& dst_shape, const DeviceInfo& device_info, const OperationDef& op_def) { if (IsConvConstantsSupported(device_info, op_def, attr)) { - ConvConstants conv = CreateConvConstants(device_info, op_def, attr); - return absl::make_unique(std::move(conv)); + GPUOperation conv = CreateConvConstants(device_info, op_def, attr); + return absl::make_unique(std::move(conv)); } else { ConvPowerVR conv = CreateConvPowerVR(device_info, op_def, attr, &dst_shape); return absl::make_unique(std::move(conv)); diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.cc index 2d61defe64b..b04335a4d7d 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.cc +++ 
b/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.cc @@ -33,8 +33,8 @@ std::unique_ptr SelectDWConvolutionAdreno( return absl::make_unique( CreateDepthwiseConv3x3(device_info, op_def, attr)); } else { - return absl::make_unique( - CreateDepthwiseConvolution(device_info, op_def, attr)); + return absl::make_unique( + CreateDepthwiseConvolution2D(device_info, op_def, attr)); } } @@ -45,8 +45,8 @@ std::unique_ptr SelectDWConvolutionPowerVR( return absl::make_unique( CreateDepthwiseConv3x3(device_info, op_def, attr)); } else { - return absl::make_unique( - CreateDepthwiseConvolution(device_info, op_def, attr)); + return absl::make_unique( + CreateDepthwiseConvolution2D(device_info, op_def, attr)); } } @@ -62,8 +62,8 @@ std::unique_ptr SelectDWConvolutionMali( return absl::make_unique( CreateDepthwiseConv3x3(device_info, op_def, attr)); } else { - return absl::make_unique( - CreateDepthwiseConvolution(device_info, op_def, attr)); + return absl::make_unique( + CreateDepthwiseConvolution2D(device_info, op_def, attr)); } } } // namespace diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.cc index 24c48d52f2a..6c6ee044cdd 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.cc @@ -18,7 +18,6 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/status.h" @@ -31,8 +30,9 @@ std::unique_ptr SelectFullyConnectedGeneric( const FullyConnectedAttributes& attr, const DeviceInfo& device_info, const OperationDef& op_def, int batch_size) { if (op_def.IsBatchSupported()) { - ConvTexture conv = CreateConvTexture(device_info, op_def, attr); - return absl::make_unique(std::move(conv)); + BHWC dst_shape = BHWC(batch_size, 1, 1, attr.weights.shape.o); + ConvPowerVR conv = CreateConvPowerVR(device_info, op_def, attr, &dst_shape); + return absl::make_unique(std::move(conv)); } else { FullyConnected fc = CreateFullyConnected(device_info, op_def, attr); return absl::make_unique(std::move(fc)); @@ -43,8 +43,9 @@ std::unique_ptr SelectFullyConnectedAdreno( const FullyConnectedAttributes& attr, const DeviceInfo& device_info, const OperationDef& op_def, int batch_size) { if (op_def.IsBatchSupported()) { - ConvTexture conv = CreateConvTexture(device_info, op_def, attr); - return absl::make_unique(std::move(conv)); + BHWC dst_shape = BHWC(batch_size, 1, 1, attr.weights.shape.o); + ConvPowerVR conv = CreateConvPowerVR(device_info, op_def, attr, &dst_shape); + return absl::make_unique(std::move(conv)); } else { FullyConnected fc = CreateFullyConnected(device_info, op_def, attr); return absl::make_unique(std::move(fc)); @@ -71,8 +72,10 @@ std::unique_ptr SelectFullyConnectedMali( ConvBuffer1x1 conv = CreateConvBuffer1x1(device_info, op_def, attr); return absl::make_unique(std::move(conv)); } else { - ConvTexture conv = CreateConvTexture(device_info, op_def, attr); - return absl::make_unique(std::move(conv)); + BHWC dst_shape = BHWC(batch_size, 1, 1, attr.weights.shape.o); + ConvPowerVR conv = + 
CreateConvPowerVR(device_info, op_def, attr, &dst_shape); + return absl::make_unique(std::move(conv)); } } else { FullyConnected fc = CreateFullyConnected(device_info, op_def, attr); diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc index 5e8e4a9fea7..f7981fc67bb 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc @@ -20,6 +20,8 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/cl/cl_device.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/elementwise.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization.h" +#include "tensorflow/lite/delegates/gpu/cl/kernels/reduce.h" +#include "tensorflow/lite/delegates/gpu/cl/kernels/transpose.h" #include "tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h" #include "tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.h" #include "tensorflow/lite/delegates/gpu/cl/selectors/default_selector.h" @@ -164,6 +166,80 @@ absl::Status GPUOperationFromNode(const DeviceInfo& device_info, return absl::UnimplementedError(absl::StrCat( "No support of ", node.operation.type, " with this parameters")); } + case OperationType::BATCHED_MATMUL: { + // Currently only batch = 1 is supported. + // Matmul replaced with this sequence: + // 1) Transpose second tensor(weights). (1x1xHxW)->(Wx1x1xH) + // 2) Convert second tensor(weights) from 1) to Convolution weights + // 3) Run usual convolution + auto second_shape = inputs[1]->tensor.shape; + auto dst_shape = outputs[0]->tensor.shape; + if (dst_shape.b != 1) { + return absl::UnimplementedError( + "Currently only batch = 1 supported for BATCHED_MATMUL."); + } + BHWC weights_shape(second_shape.c, 1, 1, second_shape.w); + Convolution2DAttributes attr; + attr.strides = HW(1, 1); + attr.dilations = HW(1, 1); + attr.padding.appended = HW(0, 0); + attr.padding.prepended = HW(0, 0); + attr.bias.shape = Linear(weights_shape.b); + attr.bias.data.resize(weights_shape.b, 0.0f); + + TensorDescriptor transposed_desc = {op_def.src_tensors[1].data_type, + op_def.src_tensors[1].storage_type, + Layout::BHWC}; + transposed_desc.storage_type = SelectBestStorageType( + device_info, weights_shape, transposed_desc.storage_type, + transposed_desc.data_type, transposed_desc.layout); + TensorDescriptor weights_desc = {op_def.src_tensors[1].data_type, + TensorStorageType::BUFFER, Layout::BHWC}; + gpu_subgraph->operations.clear(); + gpu_subgraph->operations.resize(3); + auto& transpose_op = gpu_subgraph->operations[0]; + auto& converter_op = gpu_subgraph->operations[1]; + auto& conv_op = gpu_subgraph->operations[2]; + conv_op.input_ids = {static_cast(inputs[0]->id), -1}; + conv_op.output_ids = {static_cast(outputs[0]->id)}; + OperationDef conv_def = op_def; + conv_def.src_tensors[1] = weights_desc; + ConvWeightsDescription conv_weights_desc; + conv_op.operation = SelectConvolutionWithDynamicWeights( + attr, weights_shape, dst_shape, device_info, conv_def, hints, + &conv_weights_desc); + + int aligned_output = + AlignByN(weights_shape.b, conv_weights_desc.output_group_size * 4); + int aligned_input = AlignByN(weights_shape.c, 4); + gpu_subgraph->new_tensors = {{BHWC(1, 1, 1, + aligned_output * aligned_input * + weights_shape.h * weights_shape.w), + weights_desc}, + {weights_shape, transposed_desc}}; + OperationDef converter_def; + converter_def.precision = op_def.precision; + 
converter_def.src_tensors.push_back(transposed_desc); + converter_def.dst_tensors.push_back(weights_desc); + + converter_op.input_ids = {-2}; + converter_op.output_ids = {-1}; + converter_op.operation = + SelectConverterToConvWeights(conv_weights_desc, converter_def, hints); + + OperationDef transpose_def; + transpose_def.precision = op_def.precision; + transpose_def.src_tensors.push_back(op_def.src_tensors[1]); + transpose_def.dst_tensors.push_back(transposed_desc); + + transpose_op.input_ids = {static_cast(inputs[1]->id)}; + transpose_op.output_ids = {-2}; + TransposeAttributes transpose_attr; + transpose_attr.perm = BHWC(3, 0, 1, 2); + transpose_op.operation = absl::make_unique( + CreateTranspose(transpose_def, transpose_attr)); + return absl::OkStatus(); + } case OperationType::CONCAT: { auto attr = absl::any_cast(node.operation.attributes); std::vector channels(inputs.size()); @@ -190,6 +266,10 @@ absl::Status GPUOperationFromNode(const DeviceInfo& device_info, } } else { auto weights_shape = inputs[1]->tensor.shape; + if (attr.bias.data.empty()) { + attr.bias.shape = Linear(weights_shape.b); + attr.bias.data.resize(weights_shape.b, 0.0f); + } TensorDescriptor weights_desc = {op_def.src_tensors[1].data_type, TensorStorageType::BUFFER, Layout::BHWC}; @@ -235,7 +315,16 @@ absl::Status GPUOperationFromNode(const DeviceInfo& device_info, case OperationType::DEPTHWISE_CONVOLUTION: { auto attr = absl::any_cast( node.operation.attributes); - *gpu_op = SelectDWConvolution(attr, device_info, op_def); + if (inputs.size() == 1) { + *gpu_op = SelectDWConvolution(attr, device_info, op_def); + } else { + if (inputs[1]->tensor.shape.b != 1) { + return absl::UnimplementedError( + "No support of depthwise runtime weights with channel multiplier " + "!= 1"); + } + *gpu_op = SelectDWConvolutionDynamicWeights(attr, device_info, op_def); + } return absl::OkStatus(); } case OperationType::FULLY_CONNECTED: { @@ -260,8 +349,8 @@ absl::Status GPUOperationFromNode(const DeviceInfo& device_info, return SelectMean(attr, op_def, device_info, gpu_op); } case OperationType::MEAN_STDDEV_NORMALIZATION: { - MeanStdDevNormalization operation = - CreateMeanStdDevNormalization(op_def, device_info); + MeanStdDevNormalization operation = CreateMeanStdDevNormalization( + op_def, device_info, (inputs[0]->tensor.shape.c + 3) / 4); *gpu_op = absl::make_unique(std::move(operation)); return absl::OkStatus(); @@ -331,6 +420,7 @@ absl::Status GPUOperationFromNode(const DeviceInfo& device_info, case OperationType::EXP: case OperationType::HARD_SWISH: case OperationType::LOG: + case OperationType::NEG: case OperationType::RSQRT: case OperationType::SIGMOID: case OperationType::SIN: @@ -342,9 +432,15 @@ absl::Status GPUOperationFromNode(const DeviceInfo& device_info, return absl::OkStatus(); } case OperationType::DIV: + case OperationType::EQUAL: + case OperationType::GREATER: + case OperationType::GREATER_EQUAL: + case OperationType::LESS: + case OperationType::LESS_EQUAL: case OperationType::MAXIMUM: case OperationType::MINIMUM: case OperationType::MUL: + case OperationType::NOT_EQUAL: case OperationType::POW: case OperationType::SQUARED_DIFF: case OperationType::SUB: { @@ -364,6 +460,19 @@ absl::Status GPUOperationFromNode(const DeviceInfo& device_info, return absl::UnimplementedError(absl::StrCat( "No support of ", node.operation.type, " with this parameters")); } + case OperationType::REDUCE_MAXIMUM: + case OperationType::REDUCE_MINIMUM: + case OperationType::REDUCE_PRODUCT: + case OperationType::REDUCE_SUM: { + auto attr = 
absl::any_cast(node.operation.attributes); + if (attr.axis != Axis::CHANNELS) { + return absl::UnimplementedError( + "Currently we can reduce only in channels dimension."); + } + GPUOperation operation = CreateReduce(op_def, attr, op_type); + *gpu_op = absl::make_unique(std::move(operation)); + return absl::OkStatus(); + } default: return SelectDefault(device_info, op_def, hints, inputs, outputs, node, gpu_subgraph); diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc index 4dbb1ffd734..713892f9902 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/cl/kernels/add.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/concat_z.h" +#include "tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/lstm.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/mean.h" @@ -110,6 +111,13 @@ absl::Status SelectConcat(const ConcatAttributes& attr, } } +std::unique_ptr SelectDWConvolutionDynamicWeights( + const DepthwiseConvolution2DAttributes& attr, const DeviceInfo& device_info, + const OperationDef& op_def) { + return absl::make_unique( + CreateDepthwiseConvolution2DDynamicWeights(device_info, op_def, attr)); +} + void SelectReshape(int src_channels, int dst_channels, const OperationDef& op_def, std::unique_ptr* ptr) { diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h index c6c604da982..084298442e3 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h +++ b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h @@ -57,6 +57,10 @@ absl::Status SelectConcat(const ConcatAttributes& attr, const DeviceInfo& device_info, std::unique_ptr* ptr); +std::unique_ptr SelectDWConvolutionDynamicWeights( + const DepthwiseConvolution2DAttributes& attr, const DeviceInfo& device_info, + const OperationDef& op_def); + void SelectReshape(int src_channels, int dst_channels, const OperationDef& op_def, std::unique_ptr* ptr); diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.cc index 31480f231b0..631eabc4569 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.cc @@ -18,6 +18,7 @@ limitations under the License. 
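To make the BATCHED_MATMUL lowering in operation_selector.cc above easier to follow, here is the shape bookkeeping with arbitrary example dimensions (H = 64, W = 256 are assumptions for illustration only):

// Second input (the matmul "weights") arrives as BHWC 1 x 1 x 64 x 256.
BHWC second_shape(1, 1, 64, 256);
// Step 1: transpose with perm = BHWC(3, 0, 1, 2)  ->  256 x 1 x 1 x 64,
// which is exactly weights_shape below.
BHWC weights_shape(second_shape.c, 1, 1, second_shape.w);
// Step 2: convert that transposed tensor into convolution weights.
// Step 3: run a 1x1 convolution mapping 64 input channels to 256 output
// channels, i.e. the batch-1 matmul.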
#include "absl/types/any.h" #include "tensorflow/lite/delegates/gpu/cl/cl_device.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/special/depthwise_conv_plus_1x1_conv.h" +#include "tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.h" #include "tensorflow/lite/delegates/gpu/cl/tensor_type.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -39,6 +40,10 @@ absl::Status TryDepthwiseConvPlus1x1Conv( OperationType::DEPTHWISE_CONVOLUTION) { return absl::NotFoundError("DepthwiseConvPlus1x1Conv not suitable."); } + auto dw_inputs = graph.FindInputs(dw_node->id); + if (dw_inputs.size() != 1) { + return absl::NotFoundError("DepthwiseConvPlus1x1Conv not suitable."); + } auto dw_outputs = graph.FindOutputs(dw_node->id); auto consumers = graph.FindConsumers(dw_outputs[0]->id); if (consumers.size() != 1) { @@ -59,7 +64,6 @@ absl::Status TryDepthwiseConvPlus1x1Conv( dw_node->operation.attributes); auto conv_attr = absl::any_cast(conv_node->operation.attributes); - auto dw_inputs = graph.FindInputs(dw_node->id); auto conv_outputs = graph.FindOutputs(conv_node->id); OperationDef op_def; op_def.precision = precision; @@ -82,22 +86,108 @@ absl::Status TryDepthwiseConvPlus1x1Conv( consumed_nodes->insert(conv_node->id); return absl::OkStatus(); } + +// fully connected + fully connected + add +absl::Status TryFCFCAdd( + const DeviceInfo& device_info, CalculationsPrecision precision, + const GraphFloat32& graph, NodeId first_node_id, + const std::map& tensor_descriptors, + std::set* consumed_nodes, GPUOperationsSubgraph* gpu_subgraph) { + auto* fc0_node = graph.GetNode(first_node_id); + if (OperationTypeFromString(fc0_node->operation.type) != + OperationType::FULLY_CONNECTED) { + return absl::NotFoundError("FCFCAdd not suitable."); + } + auto fc0_inputs = graph.FindInputs(fc0_node->id); + if (fc0_inputs.size() != 1) { + return absl::NotFoundError("FCFCAdd not suitable."); + } + auto fc0_output_id = graph.FindOutputs(fc0_node->id)[0]->id; + auto consumers = graph.FindConsumers(fc0_output_id); + if (consumers.size() != 1) { + return absl::NotFoundError("FCFCAdd not suitable."); + } + auto* add_node = consumers[0]; + if (consumed_nodes->find(add_node->id) != consumed_nodes->end()) { + return absl::NotFoundError("FCFCAdd not suitable."); + } + if (OperationTypeFromString(add_node->operation.type) != OperationType::ADD) { + return absl::NotFoundError("FCFCAdd not suitable."); + } + auto add_inputs = graph.FindInputs(add_node->id); + if (add_inputs.size() != 2) { + return absl::NotFoundError("FCFCAdd not suitable."); + } + auto fc1_output_id = add_inputs[0]->id + add_inputs[1]->id - fc0_output_id; + auto* fc1_node = graph.FindProducer(fc1_output_id); + if (OperationTypeFromString(fc1_node->operation.type) != + OperationType::FULLY_CONNECTED) { + return absl::NotFoundError("FCFCAdd not suitable."); + } + if (consumed_nodes->find(fc1_node->id) != consumed_nodes->end()) { + return absl::NotFoundError("FCFCAdd not suitable."); + } + auto fc1_inputs = graph.FindInputs(fc1_node->id); + if (fc1_inputs.size() != 1) { + return absl::NotFoundError("FCFCAdd not suitable."); + } + auto fc0_attr = + absl::any_cast(fc0_node->operation.attributes); + auto fc1_attr = + absl::any_cast(fc1_node->operation.attributes); + if (fc0_attr.weights.shape.o != fc1_attr.weights.shape.o) { + return absl::NotFoundError("FCFCAdd not suitable."); + } + auto add_outputs = graph.FindOutputs(add_node->id); + + OperationDef op_def; + op_def.precision = precision; + 
auto it = tensor_descriptors.find(fc0_inputs[0]->id); + if (it != tensor_descriptors.end()) { + op_def.src_tensors.push_back(it->second); + } + it = tensor_descriptors.find(fc1_inputs[0]->id); + if (it != tensor_descriptors.end()) { + op_def.src_tensors.push_back(it->second); + } + it = tensor_descriptors.find(add_outputs[0]->id); + if (it != tensor_descriptors.end()) { + op_def.dst_tensors.push_back(it->second); + } + + for (int i = 0; i < fc1_inputs.size(); ++i) { + fc0_inputs.push_back(fc1_inputs[i]); + } + std::unique_ptr* gpu_op = + InitSingleOpSubgraph(fc0_inputs, add_outputs, gpu_subgraph); + FCFCAdd fc = CreateFCFCAdd(device_info, op_def, fc0_attr, fc1_attr); + *gpu_op = absl::make_unique(std::move(fc)); + consumed_nodes->insert(fc0_node->id); + consumed_nodes->insert(fc1_node->id); + consumed_nodes->insert(add_node->id); + return absl::OkStatus(); +} } // namespace absl::Status GPUSubgraphFromGraph( const DeviceInfo& device_info, CalculationsPrecision precision, const GraphFloat32& graph, NodeId first_node_id, const std::map& tensor_descriptors, - std::set* consumed_nodes, GPUOperationsSubgraph* gpu_subgraph) { - if (!device_info.IsNvidia()) { - return absl::NotFoundError( - "Experimental feature, enabled for NVidia only, but device is not " - "nvidia gpu."); - } - if (TryDepthwiseConvPlus1x1Conv(precision, graph, first_node_id, + std::set* consumed_nodes, GPUOperationsSubgraph* gpu_subgraph, + std::string* name) { + if ((device_info.IsAdreno() || device_info.IsNvidia()) && + TryDepthwiseConvPlus1x1Conv(precision, graph, first_node_id, tensor_descriptors, consumed_nodes, gpu_subgraph) .ok()) { + *name = "depthwise_conv_plus_1x1_conv"; + return absl::OkStatus(); + } + if ((device_info.IsIntel() || device_info.IsNvidia()) && + TryFCFCAdd(device_info, precision, graph, first_node_id, + tensor_descriptors, consumed_nodes, gpu_subgraph) + .ok()) { + *name = "fully_connected_x2_and_add"; return absl::OkStatus(); } return absl::NotFoundError("No special combination."); diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h b/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h index 3ea99b2515a..6091415e14c 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h +++ b/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h @@ -34,7 +34,8 @@ absl::Status GPUSubgraphFromGraph( const DeviceInfo& device_info, CalculationsPrecision precision, const GraphFloat32& graph, NodeId first_node_id, const std::map& tensor_descriptors, - std::set* consumed_nodes, GPUOperationsSubgraph* gpu_subgraph); + std::set* consumed_nodes, GPUOperationsSubgraph* gpu_subgraph, + std::string* name); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/serialization.cc b/tensorflow/lite/delegates/gpu/cl/serialization.cc new file mode 100644 index 00000000000..3b52fc40bdf --- /dev/null +++ b/tensorflow/lite/delegates/gpu/cl/serialization.cc @@ -0,0 +1,1049 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
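One detail of TryFCFCAdd above worth spelling out: the output id of the second fully connected node is recovered arithmetically rather than by searching the graph. The ADD node has exactly two inputs, one of which is fc0's output, so the other follows by subtraction; with hypothetical ids:

// add_inputs[0]->id == 17, add_inputs[1]->id == 42, fc0_output_id == 17
// => fc1_output_id == 17 + 42 - 17 == 42  (the id produced by fc1)
auto fc1_output_id = add_inputs[0]->id + add_inputs[1]->id - fc0_output_id;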
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/cl/serialization.h" + +#include + +#include "tensorflow/lite/delegates/gpu/cl/arguments.h" +#include "tensorflow/lite/delegates/gpu/cl/buffer.h" +#include "tensorflow/lite/delegates/gpu/cl/gpu_object.h" +#include "tensorflow/lite/delegates/gpu/cl/inference_context.h" +#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" +#include "tensorflow/lite/delegates/gpu/cl/linear_storage.h" +#include "tensorflow/lite/delegates/gpu/cl/precision.h" +#include "tensorflow/lite/delegates/gpu/cl/serialization_generated.h" +#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h" +#include "tensorflow/lite/delegates/gpu/cl/texture2d.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" + +namespace tflite { +namespace gpu { +namespace cl { +namespace { +data::AccessType ToFB(AccessType type) { + switch (type) { + case AccessType::READ: + return data::AccessType::READ; + case AccessType::WRITE: + return data::AccessType::WRITE; + case AccessType::READ_WRITE: + return data::AccessType::READ_WRITE; + default: + return data::AccessType::READ_WRITE; + } +} + +data::DataType ToFB(DataType type) { + switch (type) { + case DataType::FLOAT16: + return data::DataType::FLOAT16; + case DataType::FLOAT32: + return data::DataType::FLOAT32; + default: + return data::DataType::UNKNOWN; + } +} + +data::MemoryType ToFB(MemoryType type) { + switch (type) { + case MemoryType::CONSTANT: + return data::MemoryType::CONSTANT; + case MemoryType::GLOBAL: + return data::MemoryType::GLOBAL; + case MemoryType::LOCAL: + return data::MemoryType::LOCAL; + } +} + +data::LinearStorageType ToFB(LinearStorageType type) { + switch (type) { + case LinearStorageType::BUFFER: + return data::LinearStorageType::BUFFER; + case LinearStorageType::TEXTURE_2D: + return data::LinearStorageType::TEXTURE_2D; + } +} + +data::TensorStorageType ToFB(TensorStorageType type) { + switch (type) { + case TensorStorageType::BUFFER: + return data::TensorStorageType::BUFFER; + case TensorStorageType::IMAGE_BUFFER: + return data::TensorStorageType::IMAGE_BUFFER; + case TensorStorageType::TEXTURE_2D: + return data::TensorStorageType::TEXTURE_2D; + case TensorStorageType::TEXTURE_ARRAY: + return data::TensorStorageType::TEXTURE_ARRAY; + case TensorStorageType::TEXTURE_3D: + return data::TensorStorageType::TEXTURE_3D; + case TensorStorageType::SINGLE_TEXTURE_2D: + return data::TensorStorageType::SINGLE_TEXTURE_2D; + case TensorStorageType::UNKNOWN: + return data::TensorStorageType::UNKNOWN; + } +} + +data::Layout ToFB(Layout type) { + switch (type) { + case Layout::HWC: + return data::Layout::HWC; + case Layout::BHWC: + return data::Layout::BHWC; + case Layout::HWDC: + return data::Layout::HWDC; + case Layout::BHWDC: + return data::Layout::BHWDC; + default: + return data::Layout::UNKNOWN; + } +} + +data::CalculationsPrecision ToFB(CalculationsPrecision type) { + switch (type) { + case CalculationsPrecision::F32: + return data::CalculationsPrecision::F32; + case CalculationsPrecision::F32_F16: + return data::CalculationsPrecision::F32_F16; + case CalculationsPrecision::F16: + return data::CalculationsPrecision::F16; + } +} + +data::TensorToGrid ToFB(TensorToGrid type) { + switch (type) { + case TensorToGrid::kCustom: + return data::TensorToGrid::CUSTOM; + case TensorToGrid::kWBToX_HDToY_SToZ: + return 
data::TensorToGrid::WB_TO_X_HD_TO_Y_S_TO_Z; + case TensorToGrid::kWBToX_HDToY_ZIs1: + return data::TensorToGrid::WB_TO_X_HD_TO_Y_Z_IS_1; + case TensorToGrid::kWBToX_HToY_DToZ: + return data::TensorToGrid::WB_TO_X_H_TO_Y_D_TO_Z; + case TensorToGrid::kBToX_YIs1_ZIs1: + return data::TensorToGrid::B_TO_X_Y_IS_1_Z_IS_1; + } +} + +data::CompilerOptions ToFB(CompilerOptions type) { + switch (type) { + case CompilerOptions::ADRENO_FULL_SIMD_LINE: + return data::CompilerOptions::ADRENO_FULL_SIMD_LINE; + case CompilerOptions::ADRENO_MORE_WAVES: + return data::CompilerOptions::ADRENO_MORE_WAVES; + case CompilerOptions::POWERVR_FP16: + return data::CompilerOptions::POWERVR_FP16; + case CompilerOptions::CL_OPT_DISABLE: + return data::CompilerOptions::CL_OPT_DISABLE; + case CompilerOptions::CL_2_0: + return data::CompilerOptions::CL_2_0; + case CompilerOptions::CL_3_0: + return data::CompilerOptions::CL_3_0; + } +} + +DataType ToEnum(data::DataType type) { + switch (type) { + case data::DataType::FLOAT16: + return DataType::FLOAT16; + case data::DataType::FLOAT32: + return DataType::FLOAT32; + default: + return DataType::UNKNOWN; + } +} + +AccessType ToEnum(data::AccessType type) { + switch (type) { + case data::AccessType::READ: + return AccessType::READ; + case data::AccessType::WRITE: + return AccessType::WRITE; + case data::AccessType::READ_WRITE: + return AccessType::READ_WRITE; + } +} + +MemoryType ToEnum(data::MemoryType type) { + switch (type) { + case data::MemoryType::CONSTANT: + return MemoryType::CONSTANT; + case data::MemoryType::GLOBAL: + return MemoryType::GLOBAL; + case data::MemoryType::LOCAL: + return MemoryType::LOCAL; + } +} + +LinearStorageType ToEnum(data::LinearStorageType type) { + switch (type) { + case data::LinearStorageType::BUFFER: + return LinearStorageType::BUFFER; + case data::LinearStorageType::TEXTURE_2D: + return LinearStorageType::TEXTURE_2D; + } +} + +TensorStorageType ToEnum(data::TensorStorageType type) { + switch (type) { + case data::TensorStorageType::BUFFER: + return TensorStorageType::BUFFER; + case data::TensorStorageType::IMAGE_BUFFER: + return TensorStorageType::IMAGE_BUFFER; + case data::TensorStorageType::TEXTURE_2D: + return TensorStorageType::TEXTURE_2D; + case data::TensorStorageType::TEXTURE_ARRAY: + return TensorStorageType::TEXTURE_ARRAY; + case data::TensorStorageType::TEXTURE_3D: + return TensorStorageType::TEXTURE_3D; + case data::TensorStorageType::SINGLE_TEXTURE_2D: + return TensorStorageType::SINGLE_TEXTURE_2D; + case data::TensorStorageType::UNKNOWN: + return TensorStorageType::UNKNOWN; + } +} + +Layout ToEnum(data::Layout type) { + switch (type) { + case data::Layout::HWC: + return Layout::HWC; + case data::Layout::BHWC: + return Layout::BHWC; + case data::Layout::HWDC: + return Layout::HWDC; + case data::Layout::BHWDC: + return Layout::BHWDC; + default: + return Layout::UNKNOWN; + } +} + +CalculationsPrecision ToEnum(data::CalculationsPrecision type) { + switch (type) { + case data::CalculationsPrecision::F32: + return CalculationsPrecision::F32; + case data::CalculationsPrecision::F32_F16: + return CalculationsPrecision::F32_F16; + case data::CalculationsPrecision::F16: + return CalculationsPrecision::F16; + } +} + +TensorToGrid ToEnum(data::TensorToGrid type) { + switch (type) { + case data::TensorToGrid::CUSTOM: + return TensorToGrid::kCustom; + case data::TensorToGrid::WB_TO_X_HD_TO_Y_S_TO_Z: + return TensorToGrid::kWBToX_HDToY_SToZ; + case data::TensorToGrid::WB_TO_X_HD_TO_Y_Z_IS_1: + return TensorToGrid::kWBToX_HDToY_ZIs1; + case 
data::TensorToGrid::WB_TO_X_H_TO_Y_D_TO_Z: + return TensorToGrid::kWBToX_HToY_DToZ; + case data::TensorToGrid::B_TO_X_Y_IS_1_Z_IS_1: + return TensorToGrid::kBToX_YIs1_ZIs1; + } +} + +CompilerOptions ToEnum(data::CompilerOptions type) { + switch (type) { + case data::CompilerOptions::ADRENO_FULL_SIMD_LINE: + return CompilerOptions::ADRENO_FULL_SIMD_LINE; + case data::CompilerOptions::ADRENO_MORE_WAVES: + return CompilerOptions::ADRENO_MORE_WAVES; + case data::CompilerOptions::POWERVR_FP16: + return CompilerOptions::POWERVR_FP16; + case data::CompilerOptions::CL_OPT_DISABLE: + return CompilerOptions::CL_OPT_DISABLE; + case data::CompilerOptions::CL_2_0: + return CompilerOptions::CL_2_0; + case data::CompilerOptions::CL_3_0: + return CompilerOptions::CL_3_0; + } +} + +} // namespace + +flatbuffers::Offset Encode( + const int2& v, flatbuffers::FlatBufferBuilder* builder) { + data::Int2Builder int2_builder(*builder); + int2_builder.add_x(v.x); + int2_builder.add_y(v.y); + return int2_builder.Finish(); +} + +flatbuffers::Offset Encode( + const int3& v, flatbuffers::FlatBufferBuilder* builder) { + data::Int3Builder int3_builder(*builder); + int3_builder.add_x(v.x); + int3_builder.add_y(v.y); + int3_builder.add_z(v.z); + return int3_builder.Finish(); +} + +flatbuffers::Offset Encode( + const GPUObjectDescriptor& desc, flatbuffers::FlatBufferBuilder* builder) { + std::vector> state_vars_fb; + for (auto& v0 : desc.state_vars_) { + auto key_fb = builder->CreateString(v0.first); + auto value_fb = builder->CreateString(v0.second); + data::StateVariableBuilder state_builder(*builder); + state_builder.add_key(key_fb); + state_builder.add_value(value_fb); + state_vars_fb.push_back(state_builder.Finish()); + } + auto state_vars_fb_vec = builder->CreateVector(state_vars_fb); + data::GPUObjectDescriptorBuilder obj_builder(*builder); + obj_builder.add_state_vars(state_vars_fb_vec); + obj_builder.add_access_type(ToFB(desc.access_type_)); + return obj_builder.Finish(); +} + +void Decode(const data::GPUObjectDescriptor* fb_obj, GPUObjectDescriptor* obj) { + obj->access_type_ = ToEnum(fb_obj->access_type()); + for (auto state_fb : *fb_obj->state_vars()) { + std::string key(state_fb->key()->c_str(), state_fb->key()->size()); + std::string value(state_fb->value()->c_str(), state_fb->value()->size()); + obj->state_vars_[key] = value; + } +} + +flatbuffers::Offset Encode( + const BufferDescriptor& desc, flatbuffers::FlatBufferBuilder* builder) { + auto obj_fb = + Encode(*static_cast(&desc), builder); + + std::vector> attributes_fb; + for (auto& attr : desc.attributes) { + attributes_fb.push_back(builder->CreateString(attr)); + } + auto attributes_fb_vec = builder->CreateVector(attributes_fb); + auto data_fb = builder->CreateVector(desc.data); + data::BufferDescriptorBuilder buf_builder(*builder); + buf_builder.add_base_obj(obj_fb); + buf_builder.add_element_type(ToFB(desc.element_type)); + buf_builder.add_element_size(desc.element_size); + buf_builder.add_memory_type(ToFB(desc.memory_type)); + buf_builder.add_attributes(attributes_fb_vec); + buf_builder.add_size(desc.size); + buf_builder.add_data(data_fb); + return buf_builder.Finish(); +} + +void Decode(const data::BufferDescriptor* fb_desc, BufferDescriptor* desc) { + Decode(fb_desc->base_obj(), desc); + desc->element_type = ToEnum(fb_desc->element_type()); + desc->element_size = fb_desc->element_size(); + desc->memory_type = ToEnum(fb_desc->memory_type()); + for (auto attr_fb : *fb_desc->attributes()) { + std::string attr(attr_fb->c_str(), attr_fb->size()); + 
desc->attributes.push_back(attr); + } + desc->size = fb_desc->size(); + desc->data = + std::vector(fb_desc->data()->data(), + fb_desc->data()->data() + fb_desc->data()->size()); +} + +flatbuffers::Offset Encode( + const Texture2DDescriptor& desc, flatbuffers::FlatBufferBuilder* builder) { + auto obj_fb = + Encode(*static_cast(&desc), builder); + + auto data_fb = builder->CreateVector(desc.data); + auto size_fb = Encode(desc.size, builder); + data::Texture2DDescriptorBuilder tex_builder(*builder); + tex_builder.add_base_obj(obj_fb); + tex_builder.add_element_type(ToFB(desc.element_type)); + tex_builder.add_normalized(desc.normalized); + tex_builder.add_normalized_type(ToFB(desc.normalized_type)); + tex_builder.add_size(size_fb); + tex_builder.add_data(data_fb); + return tex_builder.Finish(); +} + +void Decode(const data::Texture2DDescriptor* fb_desc, + Texture2DDescriptor* desc) { + Decode(fb_desc->base_obj(), desc); + desc->element_type = ToEnum(fb_desc->element_type()); + desc->normalized = fb_desc->normalized(); + desc->normalized_type = ToEnum(fb_desc->normalized_type()); + desc->size.x = fb_desc->size()->x(); + desc->size.y = fb_desc->size()->y(); + desc->data = + std::vector(fb_desc->data()->data(), + fb_desc->data()->data() + fb_desc->data()->size()); +} + +flatbuffers::Offset Encode( + const TensorLinearDescriptor& desc, + flatbuffers::FlatBufferBuilder* builder) { + auto obj_fb = + Encode(*static_cast(&desc), builder); + + auto data_fb = builder->CreateVector(desc.data); + data::TensorLinearDescriptorBuilder tensor_builder(*builder); + tensor_builder.add_base_obj(obj_fb); + tensor_builder.add_element_type(ToFB(desc.element_type)); + tensor_builder.add_storage_type(ToFB(desc.storage_type)); + tensor_builder.add_memory_type(ToFB(desc.memory_type)); + tensor_builder.add_size(desc.size); + tensor_builder.add_data(data_fb); + return tensor_builder.Finish(); +} + +void Decode(const data::TensorLinearDescriptor* fb_desc, + TensorLinearDescriptor* desc) { + Decode(fb_desc->base_obj(), desc); + desc->element_type = ToEnum(fb_desc->element_type()); + desc->storage_type = ToEnum(fb_desc->storage_type()); + desc->memory_type = ToEnum(fb_desc->memory_type()); + desc->size = fb_desc->size(); + desc->data = + std::vector(fb_desc->data()->data(), + fb_desc->data()->data() + fb_desc->data()->size()); +} + +flatbuffers::Offset Encode( + const TensorDescriptor& desc, flatbuffers::FlatBufferBuilder* builder) { + auto obj_fb = + Encode(*static_cast(&desc), builder); + + data::BHWDCBuilder shape_builder(*builder); + shape_builder.add_b(desc.shape.b); + shape_builder.add_h(desc.shape.h); + shape_builder.add_w(desc.shape.w); + shape_builder.add_d(desc.shape.d); + shape_builder.add_c(desc.shape.c); + auto shape_fb = shape_builder.Finish(); + + auto data_fb = builder->CreateVector(desc.data); + data::TensorDescriptorBuilder tensor_builder(*builder); + tensor_builder.add_base_obj(obj_fb); + tensor_builder.add_data_type(ToFB(desc.data_type)); + tensor_builder.add_storage_type(ToFB(desc.storage_type)); + tensor_builder.add_layout(ToFB(desc.layout)); + tensor_builder.add_shape(shape_fb); + tensor_builder.add_data(data_fb); + return tensor_builder.Finish(); +} + +void Decode(const data::TensorDescriptor* fb_desc, TensorDescriptor* desc) { + Decode(fb_desc->base_obj(), desc); + desc->data_type = ToEnum(fb_desc->data_type()); + desc->storage_type = ToEnum(fb_desc->storage_type()); + desc->layout = ToEnum(fb_desc->layout()); + desc->shape.b = fb_desc->shape()->b(); + desc->shape.h = fb_desc->shape()->h(); + 
desc->shape.w = fb_desc->shape()->w(); + desc->shape.d = fb_desc->shape()->d(); + desc->shape.c = fb_desc->shape()->c(); + desc->data = + std::vector(fb_desc->data()->data(), + fb_desc->data()->data() + fb_desc->data()->size()); +} + +flatbuffers::Offset Encode( + const OperationDef& def, flatbuffers::FlatBufferBuilder* builder) { + std::vector> src_tensors_fb; + for (auto& desc : def.src_tensors) { + auto desc_fb = Encode(desc, builder); + src_tensors_fb.push_back(desc_fb); + } + + std::vector> dst_tensors_fb; + for (auto& desc : def.dst_tensors) { + auto desc_fb = Encode(desc, builder); + dst_tensors_fb.push_back(desc_fb); + } + + auto src_tensors_fb_vec = builder->CreateVector(src_tensors_fb); + auto dst_tensors_fb_vec = builder->CreateVector(dst_tensors_fb); + + data::OperationDefBuilder def_builder(*builder); + def_builder.add_precision(ToFB(def.precision)); + def_builder.add_src_tensors(src_tensors_fb_vec); + def_builder.add_dst_tensors(dst_tensors_fb_vec); + return def_builder.Finish(); +} + +void Decode(const data::OperationDef* fb_def, OperationDef* def) { + for (auto src_fb : *fb_def->src_tensors()) { + TensorDescriptor desc; + Decode(src_fb, &desc); + def->src_tensors.push_back(std::move(desc)); + } + for (auto dst_fb : *fb_def->dst_tensors()) { + TensorDescriptor desc; + Decode(dst_fb, &desc); + def->dst_tensors.push_back(std::move(desc)); + } + def->precision = ToEnum(fb_def->precision()); +} + +flatbuffers::Offset Encode( + const TensorDescriptor& desc, const ValueId& id, + flatbuffers::FlatBufferBuilder* builder) { + auto desc_fb = Encode(desc, builder); + data::TensorDescWithIdBuilder desc_builder(*builder); + desc_builder.add_desc(desc_fb); + desc_builder.add_id(id); + return desc_builder.Finish(); +} + +void Decode(const data::TensorDescWithId* fb_desc, TensorDescriptor* desc, + ValueId* id) { + Decode(fb_desc->desc(), desc); + *id = fb_desc->id(); +} + +absl::Status Decode(CLContext* context, const data::Arguments* fb_args, + Arguments* args) { + args->shared_int4s_data_ = std::vector( + fb_args->shared_int4s()->data(), + fb_args->shared_int4s()->data() + fb_args->shared_int4s()->size()); + + args->shared_float4s_data_ = std::vector( + fb_args->shared_float4s()->data(), + fb_args->shared_float4s()->data() + fb_args->shared_float4s()->size()); + + std::vector tmp = std::vector( + fb_args->shared_half4s()->data(), + fb_args->shared_half4s()->data() + fb_args->shared_half4s()->size()); + + args->shared_half4s_data_.resize(tmp.size()); + for (int i = 0; i < tmp.size(); ++i) { + args->shared_half4s_data_[i] = tmp[i]; + } + + args->int_values_.clear(); + for (auto int_values_fb : *fb_args->int_values()) { + Arguments::IntValue value; + value.value = int_values_fb->value(); + value.offset = int_values_fb->offset(); + value.active = int_values_fb->active(); + std::string name(int_values_fb->name()->c_str(), + int_values_fb->name()->size()); + args->int_values_[name] = value; + } + + args->float_values_.clear(); + for (auto float_values_fb : *fb_args->float_values()) { + Arguments::FloatValue value; + value.value = float_values_fb->value(); + value.offset = float_values_fb->offset(); + value.active = float_values_fb->active(); + std::string name(float_values_fb->name()->c_str(), + float_values_fb->name()->size()); + args->float_values_[name] = value; + } + + args->half_values_.clear(); + for (auto half_values_fb : *fb_args->half_values()) { + Arguments::HalfValue value; + value.value = half_values_fb->value(); + value.offset = half_values_fb->offset(); + value.active = 
half_values_fb->active(); + value.store_as_f32 = half_values_fb->store_as_f32(); + std::string name(half_values_fb->name()->c_str(), + half_values_fb->name()->size()); + args->half_values_[name] = value; + } + + for (auto buffer_pair_fb : *fb_args->buffer_objects()) { + std::string key(buffer_pair_fb->key()->c_str(), + buffer_pair_fb->key()->size()); + BufferDescriptor desc; + Decode(buffer_pair_fb->value(), &desc); + args->AddObject(key, absl::make_unique(std::move(desc))); + } + + for (auto texture_pair_fb : *fb_args->texture2d_objects()) { + std::string key(texture_pair_fb->key()->c_str(), + texture_pair_fb->key()->size()); + Texture2DDescriptor desc; + Decode(texture_pair_fb->value(), &desc); + args->AddObject(key, + absl::make_unique(std::move(desc))); + } + + for (auto tensor_pair_fb : *fb_args->tensor_linear_objects()) { + std::string key(tensor_pair_fb->key()->c_str(), + tensor_pair_fb->key()->size()); + TensorLinearDescriptor desc; + Decode(tensor_pair_fb->value(), &desc); + args->AddObject(key, + absl::make_unique(std::move(desc))); + } + + for (auto tensor_pair_fb : *fb_args->tensor_objects()) { + std::string key(tensor_pair_fb->key()->c_str(), + tensor_pair_fb->key()->size()); + TensorDescriptor desc; + Decode(tensor_pair_fb->value(), &desc); + args->AddObject(key, absl::make_unique(std::move(desc))); + } + + for (auto buffer_pair_fb : *fb_args->buffer_refs()) { + std::string key(buffer_pair_fb->key()->c_str(), + buffer_pair_fb->key()->size()); + BufferDescriptor desc; + Decode(buffer_pair_fb->value(), &desc); + auto access_type = desc.GetAccess(); + args->AddObjectRef(key, access_type, + absl::make_unique(std::move(desc))); + } + + for (auto texture_pair_fb : *fb_args->texture2d_refs()) { + std::string key(texture_pair_fb->key()->c_str(), + texture_pair_fb->key()->size()); + Texture2DDescriptor desc; + Decode(texture_pair_fb->value(), &desc); + auto access_type = desc.GetAccess(); + args->AddObjectRef(key, access_type, + absl::make_unique(std::move(desc))); + } + + for (auto tensor_pair_fb : *fb_args->tensor_linear_refs()) { + std::string key(tensor_pair_fb->key()->c_str(), + tensor_pair_fb->key()->size()); + TensorLinearDescriptor desc; + Decode(tensor_pair_fb->value(), &desc); + auto access_type = desc.GetAccess(); + args->AddObjectRef( + key, access_type, + absl::make_unique(std::move(desc))); + } + + for (auto tensor_pair_fb : *fb_args->tensor_refs()) { + std::string key(tensor_pair_fb->key()->c_str(), + tensor_pair_fb->key()->size()); + TensorDescriptor desc; + Decode(tensor_pair_fb->value(), &desc); + auto access_type = desc.GetAccess(); + args->AddObjectRef(key, access_type, + absl::make_unique(std::move(desc))); + } + + RETURN_IF_ERROR(args->AllocateObjects(context)); + RETURN_IF_ERROR(args->AddObjectArgs()); + return absl::OkStatus(); +} + +flatbuffers::Offset Encode( + const Arguments& args, flatbuffers::FlatBufferBuilder* builder) { + std::vector> int_values_fb; + for (auto& value : args.int_values_) { + auto name_fb = builder->CreateString(value.first); + data::IntValueBuilder value_builder(*builder); + value_builder.add_name(name_fb); + value_builder.add_value(value.second.value); + value_builder.add_offset(value.second.offset); + value_builder.add_active(value.second.active); + int_values_fb.push_back(value_builder.Finish()); + } + + std::vector> float_values_fb; + for (auto& value : args.float_values_) { + auto name_fb = builder->CreateString(value.first); + data::FloatValueBuilder value_builder(*builder); + value_builder.add_name(name_fb); + 
value_builder.add_value(value.second.value); + value_builder.add_offset(value.second.offset); + value_builder.add_active(value.second.active); + float_values_fb.push_back(value_builder.Finish()); + } + + std::vector> half_values_fb; + for (auto& value : args.half_values_) { + auto name_fb = builder->CreateString(value.first); + data::HalfValueBuilder value_builder(*builder); + value_builder.add_name(name_fb); + value_builder.add_value(value.second.value); + value_builder.add_offset(value.second.offset); + value_builder.add_active(value.second.active); + value_builder.add_store_as_f32(value.second.store_as_f32); + half_values_fb.push_back(value_builder.Finish()); + } + + std::vector> + buffer_objs_fb; + for (auto& value : args.objects_) { + const auto* buffer_desc = + dynamic_cast(value.second.descriptor.get()); + if (!buffer_desc) continue; + auto desc_fb = Encode(*buffer_desc, builder); + auto key_fb = builder->CreateString(value.first); + data::BufferDescriptorMapValueBuilder buf_map_builder(*builder); + buf_map_builder.add_key(key_fb); + buf_map_builder.add_value(desc_fb); + buffer_objs_fb.push_back(buf_map_builder.Finish()); + } + std::vector> + texture2d_objs_fb; + for (auto& value : args.objects_) { + const auto* texture_desc = + dynamic_cast(value.second.descriptor.get()); + if (!texture_desc) continue; + auto desc_fb = Encode(*texture_desc, builder); + auto key_fb = builder->CreateString(value.first); + data::Texture2DDescriptorMapValueBuilder tex_map_builder(*builder); + tex_map_builder.add_key(key_fb); + tex_map_builder.add_value(desc_fb); + texture2d_objs_fb.push_back(tex_map_builder.Finish()); + } + std::vector> + tensor_linear_objs_fb; + for (auto& value : args.objects_) { + const auto* tensor_desc = dynamic_cast( + value.second.descriptor.get()); + if (!tensor_desc) continue; + auto desc_fb = Encode(*tensor_desc, builder); + auto key_fb = builder->CreateString(value.first); + data::TensorLinearDescriptorMapValueBuilder ten_map_builder(*builder); + ten_map_builder.add_key(key_fb); + ten_map_builder.add_value(desc_fb); + tensor_linear_objs_fb.push_back(ten_map_builder.Finish()); + } + std::vector> + tensor_objs_fb; + for (auto& value : args.objects_) { + const auto* tensor_desc = + dynamic_cast(value.second.descriptor.get()); + if (!tensor_desc) continue; + auto desc_fb = Encode(*tensor_desc, builder); + auto key_fb = builder->CreateString(value.first); + data::TensorDescriptorMapValueBuilder ten_map_builder(*builder); + ten_map_builder.add_key(key_fb); + ten_map_builder.add_value(desc_fb); + tensor_objs_fb.push_back(ten_map_builder.Finish()); + } + + std::vector> + buffer_refs_fb; + for (auto& value : args.object_refs_) { + const auto* buffer_desc = + dynamic_cast(value.second.descriptor.get()); + if (!buffer_desc) continue; + auto desc_fb = Encode(*buffer_desc, builder); + auto key_fb = builder->CreateString(value.first); + data::BufferDescriptorMapValueBuilder buf_map_builder(*builder); + buf_map_builder.add_key(key_fb); + buf_map_builder.add_value(desc_fb); + buffer_refs_fb.push_back(buf_map_builder.Finish()); + } + std::vector> + texture2d_refs_fb; + for (auto& value : args.object_refs_) { + const auto* texture_desc = + dynamic_cast(value.second.descriptor.get()); + if (!texture_desc) continue; + auto desc_fb = Encode(*texture_desc, builder); + auto key_fb = builder->CreateString(value.first); + data::Texture2DDescriptorMapValueBuilder tex_map_builder(*builder); + tex_map_builder.add_key(key_fb); + tex_map_builder.add_value(desc_fb); + 
texture2d_refs_fb.push_back(tex_map_builder.Finish()); + } + std::vector> + tensor_linear_refs_fb; + for (auto& value : args.object_refs_) { + const auto* tensor_desc = dynamic_cast( + value.second.descriptor.get()); + if (!tensor_desc) continue; + auto desc_fb = Encode(*tensor_desc, builder); + auto key_fb = builder->CreateString(value.first); + data::TensorLinearDescriptorMapValueBuilder ten_map_builder(*builder); + ten_map_builder.add_key(key_fb); + ten_map_builder.add_value(desc_fb); + tensor_linear_refs_fb.push_back(ten_map_builder.Finish()); + } + std::vector> + tensor_refs_fb; + for (auto& value : args.object_refs_) { + const auto* tensor_desc = + dynamic_cast(value.second.descriptor.get()); + if (!tensor_desc) continue; + auto desc_fb = Encode(*tensor_desc, builder); + auto key_fb = builder->CreateString(value.first); + data::TensorDescriptorMapValueBuilder ten_map_builder(*builder); + ten_map_builder.add_key(key_fb); + ten_map_builder.add_value(desc_fb); + tensor_refs_fb.push_back(ten_map_builder.Finish()); + } + + auto shared_int4s_data_fb = builder->CreateVector(args.shared_int4s_data_); + auto shared_float4s_data_fb = + builder->CreateVector(args.shared_float4s_data_); + std::vector tmp(args.shared_half4s_data_.size()); + for (int i = 0; i < tmp.size(); ++i) { + tmp[i] = args.shared_half4s_data_[i]; + } + auto shared_half4s_data_fb = builder->CreateVector(tmp); + auto int_values_fb_vec = builder->CreateVector(int_values_fb); + auto float_values_fb_vec = builder->CreateVector(float_values_fb); + auto half_values_fb_vec = builder->CreateVector(half_values_fb); + auto buffer_objs_fb_vec = builder->CreateVector(buffer_objs_fb); + auto texture2d_objs_fb_vec = builder->CreateVector(texture2d_objs_fb); + auto tensor_linear_objs_fb_vec = builder->CreateVector(tensor_linear_objs_fb); + auto tensor_objs_fb_vec = builder->CreateVector(tensor_objs_fb); + auto buffer_refs_fb_vec = builder->CreateVector(buffer_refs_fb); + auto texture2d_refs_fb_vec = builder->CreateVector(texture2d_refs_fb); + auto tensor_linear_refs_fb_vec = builder->CreateVector(tensor_linear_refs_fb); + auto tensor_refs_fb_vec = builder->CreateVector(tensor_refs_fb); + data::ArgumentsBuilder arguments_builder(*builder); + arguments_builder.add_shared_int4s(shared_int4s_data_fb); + arguments_builder.add_shared_float4s(shared_float4s_data_fb); + arguments_builder.add_shared_half4s(shared_half4s_data_fb); + arguments_builder.add_int_values(int_values_fb_vec); + arguments_builder.add_float_values(float_values_fb_vec); + arguments_builder.add_half_values(half_values_fb_vec); + arguments_builder.add_buffer_objects(buffer_objs_fb_vec); + arguments_builder.add_texture2d_objects(texture2d_objs_fb_vec); + arguments_builder.add_tensor_linear_objects(tensor_linear_objs_fb_vec); + arguments_builder.add_tensor_objects(tensor_objs_fb_vec); + arguments_builder.add_buffer_refs(buffer_refs_fb_vec); + arguments_builder.add_texture2d_refs(texture2d_refs_fb_vec); + arguments_builder.add_tensor_linear_refs(tensor_linear_refs_fb_vec); + arguments_builder.add_tensor_refs(tensor_refs_fb_vec); + return arguments_builder.Finish(); +} + +absl::Status Decode(CLContext* context, const data::GPUOperation* fb_op, + GPUOperation* op) { + RETURN_IF_ERROR(Decode(context, fb_op->arguments(), &op->args_)); + op->code_ = std::string(fb_op->code()->c_str(), fb_op->code()->size()); + op->work_group_size_.x = fb_op->work_group_size()->x(); + op->work_group_size_.y = fb_op->work_group_size()->y(); + op->work_group_size_.z = fb_op->work_group_size()->z(); + for 
(auto option_fb : *fb_op->compiler_options()) { + op->compiler_options_.push_back(ToEnum(option_fb->option())); + } + op->tensor_to_grid_ = ToEnum(fb_op->tensor_to_grid()); + op->elementwise_ = fb_op->elementwise(); + op->linkable_ = fb_op->linkable(); + op->check_src_channels_size_ = fb_op->check_src_channels_size(); + Decode(fb_op->definition(), &op->definition_); + op->grid_dimension_ = fb_op->grid_dimension(); + op->work_group_launch_order_.x = fb_op->work_group_launch_order()->x(); + op->work_group_launch_order_.y = fb_op->work_group_launch_order()->y(); + op->work_group_launch_order_.z = fb_op->work_group_launch_order()->z(); + op->grid_size_.x = fb_op->grid_size()->x(); + op->grid_size_.y = fb_op->grid_size()->y(); + op->grid_size_.z = fb_op->grid_size()->z(); + for (auto name_fb : *fb_op->src_tensors_names()) { + std::string name(name_fb->c_str(), name_fb->size()); + op->src_tensors_names_.push_back(std::move(name)); + } + for (auto name_fb : *fb_op->dst_tensors_names()) { + std::string name(name_fb->c_str(), name_fb->size()); + op->dst_tensors_names_.push_back(std::move(name)); + } + op->work_groups_count_.x = fb_op->work_groups_count()->x(); + op->work_groups_count_.y = fb_op->work_groups_count()->y(); + op->work_groups_count_.z = fb_op->work_groups_count()->z(); + op->linkable_count_ = fb_op->linkable_count(); + op->elementwise_code_ = std::string(fb_op->elementwise_code()->c_str(), + fb_op->elementwise_code()->size()); + return absl::OkStatus(); +} + +flatbuffers::Offset Encode( + const GPUOperation& op, flatbuffers::FlatBufferBuilder* builder) { + auto args_fb = Encode(op.args_, builder); + auto code_fb = builder->CreateString(op.code_); + auto work_group_size_fb = Encode(op.work_group_size_, builder); + std::vector> compiler_options_fb; + for (int i = 0; i < op.compiler_options_.size(); ++i) { + data::CompilerOptionBuilder option_builder(*builder); + option_builder.add_option(ToFB(op.compiler_options_[i])); + compiler_options_fb.push_back(option_builder.Finish()); + } + auto compiler_options_fb_vec = builder->CreateVector(compiler_options_fb); + + auto def_fb = Encode(op.definition_, builder); + auto work_group_launch_order_fb = + Encode(op.work_group_launch_order_, builder); + auto grid_size_fb = Encode(op.grid_size_, builder); + auto work_groups_count_fb = Encode(op.work_groups_count_, builder); + + std::vector> src_names_fb; + for (auto& name : op.src_tensors_names_) { + src_names_fb.push_back(builder->CreateString(name)); + } + auto src_names_fb_vec = builder->CreateVector(src_names_fb); + + std::vector> dst_names_fb; + for (auto& name : op.dst_tensors_names_) { + dst_names_fb.push_back(builder->CreateString(name)); + } + auto dst_names_fb_vec = builder->CreateVector(dst_names_fb); + + auto elementwise_code_fb = builder->CreateString(op.elementwise_code_); + + data::GPUOperationBuilder op_builder(*builder); + op_builder.add_arguments(args_fb); + op_builder.add_code(code_fb); + op_builder.add_work_group_size(work_group_size_fb); + op_builder.add_compiler_options(compiler_options_fb_vec); + op_builder.add_tensor_to_grid(ToFB(op.tensor_to_grid_)); + op_builder.add_elementwise(op.elementwise_); + op_builder.add_linkable(op.linkable_); + op_builder.add_check_src_channels_size(op.check_src_channels_size_); + op_builder.add_definition(def_fb); + op_builder.add_grid_dimension(op.grid_dimension_); + op_builder.add_work_group_launch_order(work_group_launch_order_fb); + op_builder.add_grid_size(grid_size_fb); + op_builder.add_src_tensors_names(src_names_fb_vec); + 
op_builder.add_dst_tensors_names(dst_names_fb_vec); + op_builder.add_work_groups_count(work_groups_count_fb); + op_builder.add_linkable_count(op.linkable_count_); + op_builder.add_elementwise_code(elementwise_code_fb); + return op_builder.Finish(); +} + +flatbuffers::Offset Encode( + const CLNode& node, flatbuffers::FlatBufferBuilder* builder) { + auto op_fb = Encode(*node.operation, builder); + std::vector in_ids(node.inputs.size()); + for (int i = 0; i < in_ids.size(); ++i) { + in_ids[i] = node.inputs[i]; + } + std::vector out_ids(node.outputs.size()); + for (int i = 0; i < out_ids.size(); ++i) { + out_ids[i] = node.outputs[i]; + } + auto in_ids_fb = builder->CreateVector(in_ids); + auto out_ids_fb = builder->CreateVector(out_ids); + auto name_fb = builder->CreateString(node.name); + data::CLNodeBuilder node_builder(*builder); + node_builder.add_gpu_op(op_fb); + node_builder.add_input_ids(in_ids_fb); + node_builder.add_output_ids(out_ids_fb); + node_builder.add_name(name_fb); + return node_builder.Finish(); +} + +absl::Status Decode(CLContext* context, const data::CLNode* fb_node, + CLNode* node) { + GPUOperation op; + RETURN_IF_ERROR(Decode(context, fb_node->gpu_op(), &op)); + node->operation = absl::make_unique(std::move(op)); + for (auto in_fb : *fb_node->input_ids()) { + node->inputs.push_back(in_fb); + } + for (auto out_fb : *fb_node->output_ids()) { + node->outputs.push_back(out_fb); + } + node->name = std::string(fb_node->name()->c_str(), fb_node->name()->size()); + + return absl::OkStatus(); +} + +flatbuffers::Offset Encode( + const InferenceContext& inference, + flatbuffers::FlatBufferBuilder* builder) { + std::vector in_ids(inference.input_ids_.size()); + for (int i = 0; i < in_ids.size(); ++i) { + in_ids[i] = inference.input_ids_[i]; + } + std::vector out_ids(inference.output_ids_.size()); + for (int i = 0; i < out_ids.size(); ++i) { + out_ids[i] = inference.output_ids_[i]; + } + auto in_ids_fb = builder->CreateVector(in_ids); + auto out_ids_fb = builder->CreateVector(out_ids); + + std::vector> nodes_fb; + for (int i = 0; i < inference.nodes_.size(); ++i) { + auto node_fb = Encode(inference.nodes_[i], builder); + nodes_fb.push_back(node_fb); + } + auto nodes_fb_vec = builder->CreateVector(nodes_fb); + + std::vector> tensors_fb; + auto tensors = inference.tensor_reserver_.GetTensorDescs(); + for (auto& tensor : tensors) { + auto tensor_fb = Encode(tensor.second, tensor.first, builder); + tensors_fb.push_back(tensor_fb); + } + auto tensors_fb_vec = builder->CreateVector(tensors_fb); + + std::vector> + variable_ids_and_refs_fb; + for (auto& pair : inference.variable_ids_and_refs_) { + data::PairOfValueIdsBuilder pair_builder(*builder); + pair_builder.add_first(pair.first); + pair_builder.add_second(pair.second); + variable_ids_and_refs_fb.push_back(pair_builder.Finish()); + } + auto variable_ids_and_refs_fb_vec = + builder->CreateVector(variable_ids_and_refs_fb); + + data::InferenceContextBuilder inf_builder(*builder); + inf_builder.add_need_flush(inference.need_flush_); + inf_builder.add_flush_periodically(inference.flush_periodically_); + inf_builder.add_flush_period(inference.flush_period_); + inf_builder.add_need_manual_release(inference.need_manual_release_); + inf_builder.add_precision(ToFB(inference.precision_)); + inf_builder.add_storage_type(ToFB(inference.storage_type_)); + inf_builder.add_nodes(nodes_fb_vec); + inf_builder.add_tensors(tensors_fb_vec); + inf_builder.add_input_ids(in_ids_fb); + inf_builder.add_output_ids(out_ids_fb); + 
inf_builder.add_variable_ids_and_refs(variable_ids_and_refs_fb_vec); + return inf_builder.Finish(); +} + +absl::Status Decode(CLContext* context, + const data::InferenceContext* fb_inference, + InferenceContext* inference) { + inference->need_flush_ = fb_inference->need_flush(); + inference->flush_periodically_ = fb_inference->flush_periodically(); + inference->flush_period_ = fb_inference->flush_period(); + inference->need_manual_release_ = fb_inference->need_manual_release(); + inference->precision_ = ToEnum(fb_inference->precision()); + inference->storage_type_ = ToEnum(fb_inference->storage_type()); + + inference->nodes_.resize(fb_inference->nodes()->size()); + int counter = 0; + for (auto node_fb : *fb_inference->nodes()) { + RETURN_IF_ERROR(Decode(context, node_fb, &inference->nodes_[counter])); + counter++; + } + + std::vector> tensors; + for (auto tensor_fb : *fb_inference->tensors()) { + TensorDescriptor desc; + Decode(tensor_fb->desc(), &desc); + tensors.push_back({tensor_fb->id(), std::move(desc)}); + } + inference->tensor_reserver_.Add(tensors); + for (auto in_fb : *fb_inference->input_ids()) { + inference->input_ids_.push_back(in_fb); + } + for (auto out_fb : *fb_inference->output_ids()) { + inference->output_ids_.push_back(out_fb); + } + + for (auto variable_id : *fb_inference->variable_ids_and_refs()) { + inference->variable_ids_and_refs_[variable_id->first()] = + variable_id->second(); + } + return absl::OkStatus(); +} + +} // namespace cl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/serialization.fbs b/tensorflow/lite/delegates/gpu/cl/serialization.fbs new file mode 100644 index 00000000000..0c0d2241b5a --- /dev/null +++ b/tensorflow/lite/delegates/gpu/cl/serialization.fbs @@ -0,0 +1,278 @@ +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
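+// This schema captures a compiled OpenCL inference graph: its CLNodes and
+// GPUOperations together with their Arguments and tensor descriptors. The
+// Encode()/Decode() helpers in serialization.cc convert between these tables
+// and the C++ runtime structures; InferenceContext is the root table.
+//
+// Rough round-trip sketch (uses the flatc-generated serialization_generated.h;
+// variable names below are illustrative only):
+//
+//   flatbuffers::FlatBufferBuilder builder;
+//   auto encoded = Encode(inference_context, &builder);
+//   builder.Finish(encoded);
+//   const auto* fb = flatbuffers::GetRoot<data::InferenceContext>(
+//       builder.GetBufferPointer());
+//   InferenceContext restored;
+//   RETURN_IF_ERROR(Decode(&cl_context, fb, &restored));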
+ +namespace tflite.gpu.cl.data; + +table Int4 { + x:int32; + y:int32; + z:int32; + w:int32; +} + +table Int3 { + x:int32; + y:int32; + z:int32; +} + +table Int2 { + x:int32; + y:int32; +} + +table IntValue { + name:string; + value:int32; + active:bool; + offset:uint32; +} + +table FloatValue { + name:string; + value:float; + active:bool; + offset:uint32; +} + +table HalfValue { + name:string; + value:float; + active:bool; + store_as_f32:bool; + offset:uint32; +} + +enum AccessType : byte { + READ = 0, + WRITE = 1, + READ_WRITE = 2, +} + +enum DataType : byte { + UNKNOWN = 0, + FLOAT32 = 1, + FLOAT16 = 2, +} + +enum MemoryType : byte { + GLOBAL = 0, + CONSTANT = 1, + LOCAL = 2, +} + +table StateVariable { + key:string; + value:string; +} + +table GPUObjectDescriptor { + state_vars:[StateVariable]; + access_type:AccessType; +} + +table BufferDescriptor { + base_obj:GPUObjectDescriptor; + element_type:DataType; + element_size:int32; + memory_type:MemoryType; + attributes:[string]; + size:int32; + data:[uint8]; +} + +table Texture2DDescriptor { + base_obj:GPUObjectDescriptor; + element_type:DataType; + normalized:bool; + normalized_type:DataType; + size:Int2; + data:[uint8]; +} + +enum LinearStorageType : byte { + BUFFER = 0, + TEXTURE_2D = 1, +} + +table TensorLinearDescriptor { + base_obj:GPUObjectDescriptor; + storage_type:LinearStorageType; + element_type:DataType; + memory_type:MemoryType; + size:int32; + data:[uint8]; +} + +enum TensorStorageType : byte { + UNKNOWN = 0, + BUFFER = 1, + IMAGE_BUFFER = 2, + TEXTURE_2D = 3, + TEXTURE_3D = 4, + TEXTURE_ARRAY = 5, + SINGLE_TEXTURE_2D = 6, +} + +enum Layout : byte { + UNKNOWN = 0, + HWC = 1, + BHWC = 2, + HWDC = 3, + BHWDC = 4, +} + +table BHWDC { + b:int32; + h:int32; + w:int32; + d:int32; + c:int32; +} + +table TensorDescriptor { + base_obj:GPUObjectDescriptor; + data_type:DataType; + storage_type:TensorStorageType; + layout:Layout; + shape:BHWDC; + data:[uint8]; +} + +table BufferDescriptorMapValue { + key:string; + value:BufferDescriptor; +} + +table Texture2DDescriptorMapValue { + key:string; + value:Texture2DDescriptor; +} + +table TensorLinearDescriptorMapValue { + key:string; + value:TensorLinearDescriptor; +} + +table TensorDescriptorMapValue { + key:string; + value:TensorDescriptor; +} + +table Arguments { + int_values:[IntValue]; + shared_int4s:[int32]; + + float_values:[FloatValue]; + shared_float4s:[float]; + + half_values:[HalfValue]; + shared_half4s:[float]; + + buffer_refs:[BufferDescriptorMapValue]; + texture2d_refs:[Texture2DDescriptorMapValue]; + tensor_linear_refs:[TensorLinearDescriptorMapValue]; + tensor_refs:[TensorDescriptorMapValue]; + + buffer_objects:[BufferDescriptorMapValue]; + texture2d_objects:[Texture2DDescriptorMapValue]; + tensor_linear_objects:[TensorLinearDescriptorMapValue]; + tensor_objects:[TensorDescriptorMapValue]; +} + +enum CalculationsPrecision : byte { + F32 = 0, + F32_F16 = 1, + F16 = 2, +} + +enum TensorToGrid : byte { + CUSTOM = 0, + WB_TO_X_HD_TO_Y_S_TO_Z = 1, + WB_TO_X_HD_TO_Y_Z_IS_1 = 2, + WB_TO_X_H_TO_Y_D_TO_Z = 3, + B_TO_X_Y_IS_1_Z_IS_1 = 4, +} + +enum CompilerOptions : byte { + ADRENO_FULL_SIMD_LINE = 0, + ADRENO_MORE_WAVES = 1, + POWERVR_FP16 = 2, + CL_OPT_DISABLE = 3, + CL_2_0 = 4, + CL_3_0 = 5, +} + +table OperationDef { + precision:CalculationsPrecision; + src_tensors:[TensorDescriptor]; + dst_tensors:[TensorDescriptor]; +} + +table CompilerOption { + option:CompilerOptions; +} + +table GPUOperation { + arguments:Arguments; + code:string; + work_group_size:Int3; + 
compiler_options:[CompilerOption]; + tensor_to_grid:TensorToGrid; + elementwise:bool; + linkable:bool; + check_src_channels_size:bool; + definition:OperationDef; + grid_dimension:int32; + work_group_launch_order:Int3; + grid_size:Int3; + src_tensors_names:[string]; + dst_tensors_names:[string]; + work_groups_count:Int3; + linkable_count:int32; + elementwise_code:string; +} + +table TensorDescWithId { + desc:TensorDescriptor; + id:int32; +} + +table CLNode { + gpu_op:GPUOperation; + input_ids:[int32]; + output_ids:[int32]; + name:string; +} + +table PairOfValueIds { + first:int32; + second:int32; +} + +table InferenceContext { + need_flush:bool; + flush_periodically:bool; + flush_period:int32; + need_manual_release:bool; + precision:CalculationsPrecision; + storage_type:TensorStorageType; + nodes:[CLNode]; + tensors:[TensorDescWithId]; + input_ids:[int32]; + variable_ids_and_refs:[PairOfValueIds]; + output_ids:[int32]; +} + +root_type InferenceContext; diff --git a/tensorflow/lite/delegates/gpu/cl/serialization.h b/tensorflow/lite/delegates/gpu/cl/serialization.h new file mode 100644 index 00000000000..1273e62a100 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/cl/serialization.h @@ -0,0 +1,42 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_SERIALIZATION_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_CL_SERIALIZATION_H_ + +#include "absl/types/span.h" +#include "tensorflow/lite/delegates/gpu/cl/cl_context.h" +#include "tensorflow/lite/delegates/gpu/cl/inference_context.h" +#include "tensorflow/lite/delegates/gpu/cl/serialization_generated.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" + +namespace tflite { +namespace gpu { +namespace cl { + +class InferenceContext; + +flatbuffers::Offset Encode( + const InferenceContext& inference, flatbuffers::FlatBufferBuilder* builder); + +absl::Status Decode(CLContext* context, + const data::InferenceContext* fb_inference, + InferenceContext* inference); + +} // namespace cl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_SERIALIZATION_H_ diff --git a/tensorflow/lite/delegates/gpu/cl/tensor.cc b/tensorflow/lite/delegates/gpu/cl/tensor.cc index 72c53c5b1ac..c35554b875b 100644 --- a/tensorflow/lite/delegates/gpu/cl/tensor.cc +++ b/tensorflow/lite/delegates/gpu/cl/tensor.cc @@ -605,8 +605,11 @@ absl::Status Tensor::CreateFromDescriptor(const TensorDescriptor& desc, descriptor_.layout = desc.layout; memory_owner_ = true; CLMemory memory; - RETURN_IF_ERROR(AllocateTensorMemory(*context, shape_, descriptor_, - desc.data.data(), &memory)); + uint8_t* data_ptr = desc.data.empty() + ? 
nullptr + : const_cast(desc.data.data()); + RETURN_IF_ERROR( + AllocateTensorMemory(*context, shape_, descriptor_, data_ptr, &memory)); memory_ = memory.Release(); if (desc.storage_type == TensorStorageType::IMAGE_BUFFER) { RETURN_IF_ERROR(CreateImageBufferFromBuffer( diff --git a/tensorflow/lite/delegates/gpu/cl/tensor_type.cc b/tensorflow/lite/delegates/gpu/cl/tensor_type.cc index 7bd5de6e31e..f31df43539e 100644 --- a/tensorflow/lite/delegates/gpu/cl/tensor_type.cc +++ b/tensorflow/lite/delegates/gpu/cl/tensor_type.cc @@ -771,6 +771,46 @@ void TensorDescriptor::UploadData(absl::Span src) { } } +bool TensorDescriptor::SupportsZeroClamp(const Axis& axis) const { + switch (storage_type) { + case TensorStorageType::UNKNOWN: + return false; + case TensorStorageType::BUFFER: + case TensorStorageType::IMAGE_BUFFER: + return false; + case TensorStorageType::TEXTURE_ARRAY: + case TensorStorageType::TEXTURE_2D: + case TensorStorageType::SINGLE_TEXTURE_2D: + return axis == Axis::WIDTH || axis == Axis::HEIGHT; + case TensorStorageType::TEXTURE_3D: + return axis == Axis::WIDTH || axis == Axis::HEIGHT || axis == Axis::DEPTH; + } +} + +bool TensorDescriptor::CanReadOutOfBorder(const Axis& axis) const { + switch (storage_type) { + case TensorStorageType::UNKNOWN: + return false; + case TensorStorageType::BUFFER: + return false; + case TensorStorageType::IMAGE_BUFFER: + case TensorStorageType::TEXTURE_2D: + case TensorStorageType::TEXTURE_3D: + case TensorStorageType::SINGLE_TEXTURE_2D: + case TensorStorageType::TEXTURE_ARRAY: + return true; + } +} + +bool TensorDescriptor::IsLinear() const { + return storage_type == TensorStorageType::BUFFER || + storage_type == TensorStorageType::IMAGE_BUFFER; +} + +bool TensorDescriptor::ReturnsZeroForNegOneRead() const { + return storage_type == TensorStorageType::IMAGE_BUFFER; +} + namespace { int GetLinearIndex(const TensorDescriptor& desc, const BHWDC& shape, int b, int x, int y, int d, int s, int sub_c) { diff --git a/tensorflow/lite/delegates/gpu/cl/tensor_type.h b/tensorflow/lite/delegates/gpu/cl/tensor_type.h index 094e3905966..2157bf05543 100644 --- a/tensorflow/lite/delegates/gpu/cl/tensor_type.h +++ b/tensorflow/lite/delegates/gpu/cl/tensor_type.h @@ -82,6 +82,16 @@ struct TensorDescriptor : public GPUObjectDescriptor { void UploadData(const tflite::gpu::Tensor& src); void UploadData(const tflite::gpu::Tensor& src); + bool SupportsZeroClamp(const Axis& axis) const; + bool CanReadOutOfBorder(const Axis& axis) const; + bool IsLinear() const; + + // applicable only for types that: IsLinear -> true. 
+ // In this case for address we have 1d component - addr (int) + // If for addr == -1 this linear storage type returns FLT4(0.0), this function + // returns true, otherwise false + bool ReturnsZeroForNegOneRead() const; + DataType data_type = DataType::UNKNOWN; TensorStorageType storage_type = TensorStorageType::UNKNOWN; // This field describes logical layout, actual(physical) GPU layout can be diff --git a/tensorflow/lite/delegates/gpu/cl/testing/BUILD b/tensorflow/lite/delegates/gpu/cl/testing/BUILD index c82190ca0e6..a14dfd72cfd 100644 --- a/tensorflow/lite/delegates/gpu/cl/testing/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/testing/BUILD @@ -3,20 +3,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -cc_binary( - name = "performance_profiling", - srcs = ["performance_profiling.cc"], - deps = [ - "//tensorflow/lite/delegates/gpu/cl:environment", - "//tensorflow/lite/delegates/gpu/cl:inference_context", - "//tensorflow/lite/delegates/gpu/common:model", - "//tensorflow/lite/delegates/gpu/common:status", - "//tensorflow/lite/delegates/gpu/common/testing:tflite_model_reader", - "//tensorflow/lite/kernels:builtin_ops", - "@com_google_absl//absl/time", - ], -) - cc_binary( name = "delegate_testing", srcs = ["delegate_testing.cc"], @@ -34,3 +20,38 @@ cc_binary( "@com_google_absl//absl/time", ], ) + +cc_binary( + name = "internal_api_samples", + srcs = ["internal_api_samples.cc"], + tags = [ + "nobuilder", + "notap", + ], + deps = [ + "//tensorflow/lite/delegates/gpu:api", + "//tensorflow/lite/delegates/gpu/cl:api", + "//tensorflow/lite/delegates/gpu/cl:environment", + "//tensorflow/lite/delegates/gpu/cl:inference_context", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common/testing:tflite_model_reader", + "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/kernels:kernel_util", + "@com_google_absl//absl/time", + ], +) + +cc_binary( + name = "performance_profiling", + srcs = ["performance_profiling.cc"], + deps = [ + "//tensorflow/lite/delegates/gpu/cl:environment", + "//tensorflow/lite/delegates/gpu/cl:inference_context", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common/testing:tflite_model_reader", + "//tensorflow/lite/kernels:builtin_ops", + "@com_google_absl//absl/time", + ], +) diff --git a/tensorflow/lite/delegates/gpu/cl/testing/delegate_testing.cc b/tensorflow/lite/delegates/gpu/cl/testing/delegate_testing.cc index 10b7ac34404..3a618e55c06 100644 --- a/tensorflow/lite/delegates/gpu/cl/testing/delegate_testing.cc +++ b/tensorflow/lite/delegates/gpu/cl/testing/delegate_testing.cc @@ -132,6 +132,7 @@ int main(int argc, char** argv) { options.inference_priority1 = TFLITE_GPU_INFERENCE_PRIORITY_MIN_LATENCY; options.inference_priority2 = TFLITE_GPU_INFERENCE_PRIORITY_MIN_MEMORY_USAGE; options.inference_priority3 = TFLITE_GPU_INFERENCE_PRIORITY_MAX_PRECISION; + options.experimental_flags = TFLITE_GPU_EXPERIMENTAL_FLAGS_NONE; options.max_delegated_partitions = 1; auto* gpu_delegate = TfLiteGpuDelegateV2Create(&options); status = gpu_inference->ModifyGraphWithDelegate(gpu_delegate); diff --git a/tensorflow/lite/delegates/gpu/cl/testing/internal_api_samples.cc b/tensorflow/lite/delegates/gpu/cl/testing/internal_api_samples.cc new file mode 100644 index 00000000000..be297546709 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/cl/testing/internal_api_samples.cc @@ -0,0 +1,453 @@ +/* Copyright 2020 The 
TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include // NOLINT(build/c++11) +#include +#include + +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "tensorflow/lite/delegates/gpu/api.h" +#include "tensorflow/lite/delegates/gpu/cl/api.h" +#include "tensorflow/lite/delegates/gpu/cl/environment.h" +#include "tensorflow/lite/delegates/gpu/cl/inference_context.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/testing/tflite_model_reader.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/register.h" + +namespace tflite { +namespace gpu { +namespace cl { +namespace { +void FillInputTensors(tflite::Interpreter* interpreter) { + for (int k = 0; k < interpreter->inputs().size(); ++k) { + TfLiteTensor* tensor_ptr = interpreter->tensor(interpreter->inputs()[k]); + const auto tensor_elements_count = tflite::NumElements(tensor_ptr); + if (tensor_ptr->type == kTfLiteFloat32) { + float* p = interpreter->typed_input_tensor(k); + for (int i = 0; i < tensor_elements_count; ++i) { + p[i] = std::sin(i); + } + } else { + std::cout << "Non-Float32 input/output tensors are not supported" + << std::endl; + } + } +} + +void CompareCPUGPUResults(tflite::Interpreter* cpu, + const std::vector& outputs, + const std::vector>& gpu, + float eps) { + for (int i = 0; i < gpu.size(); ++i) { + TfLiteTensor* tensor_ptr = cpu->tensor(outputs[i]); + const float* cpu_out = tensor_ptr->data.f; + const float* gpu_out = gpu[i].data(); + const int kMaxPrint = 10; + int printed = 0; + int total_different = 0; + for (int k = 0; k < tensor_ptr->bytes / 4; ++k) { + const float abs_diff = fabs(cpu_out[k] - gpu_out[k]); + if (abs_diff > eps) { + total_different++; + if (printed < kMaxPrint) { + std::cout << "Output #" << i << ": element #" << k << ": CPU value - " + << cpu_out[k] << ", GPU value - " << gpu_out[k] + << ", abs diff - " << abs_diff << std::endl; + printed++; + } + if (printed == kMaxPrint) { + std::cout << "Printed " << kMaxPrint + << " different elements, threshold - " << eps + << ", next different elements skipped" << std::endl; + printed++; + } + } + } + std::cout << "Total " << total_different + << " different elements, for output #" << i << ", threshold - " + << eps << std::endl; + } +} +} // namespace + +absl::Status RunModelSampleWithInternalAPISerializedKernels( + const std::string& model_name, const std::vector& kernel_cache); + +absl::Status RunModelSampleWithInternalAPISerialized( + tflite::Interpreter* cpu, const std::vector& in_refs, + const std::vector& out_refs, + const std::vector& kernel_cache, + const std::vector& serialized_model); + +// Runs Jet with the OpenCL internal API and compares correctness with TFLite CPU. +absl::Status RunModelSampleWithInternalAPI(const std::string& model_name, + std::vector* kernel_cache) { + auto 
flatbuffer = tflite::FlatBufferModel::BuildFromFile(model_name.c_str()); + + ops::builtin::BuiltinOpResolver op_resolver; + InterpreterBuilder tfl_builder(*flatbuffer, op_resolver); + + // CPU. + std::unique_ptr cpu_inference; + tfl_builder(&cpu_inference); + if (!cpu_inference) { + return absl::InternalError("Failed to build CPU inference."); + } + auto status = cpu_inference->AllocateTensors(); + if (status != kTfLiteOk) { + return absl::InternalError("Failed to AllocateTensors for CPU inference."); + } + for (int k = 0; k < cpu_inference->inputs().size(); ++k) { + TfLiteTensor* tensor_ptr = + cpu_inference->tensor(cpu_inference->inputs()[k]); + if (tensor_ptr->type != kTfLiteFloat32) { + return absl::InvalidArgumentError( + "Internal api supports only F32 input tensors"); + } + } + for (int k = 0; k < cpu_inference->outputs().size(); ++k) { + TfLiteTensor* tensor_ptr = + cpu_inference->tensor(cpu_inference->outputs()[k]); + if (tensor_ptr->type != kTfLiteFloat32) { + return absl::InvalidArgumentError( + "Internal api supports only F32 output tensors"); + } + } + FillInputTensors(cpu_inference.get()); + status = cpu_inference->Invoke(); + if (status != kTfLiteOk) { + return absl::InternalError("Failed to Invoke CPU inference."); + } + + const auto start = std::chrono::high_resolution_clock::now(); + GraphFloat32 graph_cl; + RETURN_IF_ERROR(BuildFromFlatBuffer(*flatbuffer, op_resolver, &graph_cl)); + + auto inputs = graph_cl.inputs(); + auto outputs = graph_cl.outputs(); + std::vector in_refs(inputs.size()); + std::vector out_refs(outputs.size()); + for (int i = 0; i < inputs.size(); ++i) { + in_refs[i] = inputs[i]->tensor.ref; + } + for (int i = 0; i < outputs.size(); ++i) { + out_refs[i] = outputs[i]->tensor.ref; + } + + Environment env; + RETURN_IF_ERROR(CreateEnvironment(&env)); + + std::unique_ptr inf_env; + // Initializes environment. + InferenceEnvironmentOptions env_options; + env_options.device = env.device().id(); + env_options.context = env.context().context(); + env_options.command_queue = env.queue()->queue(); + RETURN_IF_ERROR(NewInferenceEnvironment(env_options, &inf_env, nullptr)); + + std::unique_ptr builder; + // Initializes builder. + InferenceOptions options; + options.priority1 = InferencePriority::MIN_LATENCY; + options.priority2 = InferencePriority::MIN_MEMORY_USAGE; + options.priority3 = InferencePriority::MAX_PRECISION; + options.usage = InferenceUsage::SUSTAINED_SPEED; + + RETURN_IF_ERROR( + inf_env->NewInferenceBuilder(options, std::move(graph_cl), &builder)); + + // Sets input/output object def for builder_. + ObjectDef obj_def; + obj_def.data_type = DataType::FLOAT32; + obj_def.data_layout = DataLayout::BHWC; + obj_def.object_type = ObjectType::CPU_MEMORY; + obj_def.user_provided = true; + for (int i = 0; i < in_refs.size(); ++i) { + RETURN_IF_ERROR(builder->SetInputObjectDef(i, obj_def)); + } + for (int i = 0; i < out_refs.size(); ++i) { + RETURN_IF_ERROR(builder->SetOutputObjectDef(i, obj_def)); + } + + std::unique_ptr<::tflite::gpu::InferenceRunner> runner; + // Builds runner. + RETURN_IF_ERROR(builder->Build(&runner)); + + const auto end = std::chrono::high_resolution_clock::now(); + std::cout << "Initialization total time - " << (end - start).count() * 1e-6f + << "ms" << std::endl; + + if (kernel_cache) { + *kernel_cache = inf_env->GetSerializedBinaryCache(); + std::cout << "Kernel cache size - " << kernel_cache->size() << std::endl; + } + + // Sets the input/output object. 
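+ // The ObjectDef above declared user-provided CPU_MEMORY, so the CpuMemory
+ // handles below simply point the GPU runner at the TFLite interpreter's own
+ // input/output buffers; the caller keeps ownership of that memory.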
+ for (int i = 0; i < in_refs.size(); ++i) { + TfLiteTensor* tensor_ptr = cpu_inference->tensor(in_refs[i]); + RETURN_IF_ERROR(runner->SetInputObject( + i, CpuMemory{tensor_ptr->data.data, tensor_ptr->bytes})); + } + + std::vector> output_tensors(out_refs.size()); + for (int i = 0; i < out_refs.size(); ++i) { + TfLiteTensor* tensor_ptr = cpu_inference->tensor(out_refs[i]); + output_tensors[i].resize(tensor_ptr->bytes / 4); + RETURN_IF_ERROR(runner->SetOutputObject( + i, CpuMemory{output_tensors[i].data(), tensor_ptr->bytes})); + } + + RETURN_IF_ERROR(runner->Run()); + + CompareCPUGPUResults(cpu_inference.get(), out_refs, output_tensors, 1e-4f); + + return absl::OkStatus(); +} + +absl::Status RunModelSampleWithInternalAPISerializedKernels( + const std::string& model_name, const std::vector& kernel_cache) { + auto flatbuffer = tflite::FlatBufferModel::BuildFromFile(model_name.c_str()); + + ops::builtin::BuiltinOpResolver op_resolver; + InterpreterBuilder tfl_builder(*flatbuffer, op_resolver); + + // CPU. + std::unique_ptr cpu_inference; + tfl_builder(&cpu_inference); + if (!cpu_inference) { + return absl::InternalError("Failed to build CPU inference."); + } + auto status = cpu_inference->AllocateTensors(); + if (status != kTfLiteOk) { + return absl::InternalError("Failed to AllocateTensors for CPU inference."); + } + for (int k = 0; k < cpu_inference->inputs().size(); ++k) { + TfLiteTensor* tensor_ptr = + cpu_inference->tensor(cpu_inference->inputs()[k]); + if (tensor_ptr->type != kTfLiteFloat32) { + return absl::InvalidArgumentError( + "Internal api supports only F32 input tensors"); + } + } + for (int k = 0; k < cpu_inference->outputs().size(); ++k) { + TfLiteTensor* tensor_ptr = + cpu_inference->tensor(cpu_inference->outputs()[k]); + if (tensor_ptr->type != kTfLiteFloat32) { + return absl::InvalidArgumentError( + "Internal api supports only F32 output tensors"); + } + } + FillInputTensors(cpu_inference.get()); + status = cpu_inference->Invoke(); + if (status != kTfLiteOk) { + return absl::InternalError("Failed to Invoke CPU inference."); + } + + const auto start = std::chrono::high_resolution_clock::now(); + GraphFloat32 graph_cl; + RETURN_IF_ERROR(BuildFromFlatBuffer(*flatbuffer, op_resolver, &graph_cl)); + + auto inputs = graph_cl.inputs(); + auto outputs = graph_cl.outputs(); + std::vector in_refs(inputs.size()); + std::vector out_refs(outputs.size()); + for (int i = 0; i < inputs.size(); ++i) { + in_refs[i] = inputs[i]->tensor.ref; + } + for (int i = 0; i < outputs.size(); ++i) { + out_refs[i] = outputs[i]->tensor.ref; + } + + Environment env; + RETURN_IF_ERROR(CreateEnvironment(&env)); + + std::unique_ptr inf_env; + // Initializes environment. 
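+ // Unlike the first run, serialized_binary_cache is filled in below with the
+ // kernel cache captured earlier, so previously compiled OpenCL programs can
+ // typically be reused instead of being recompiled from source.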
+ InferenceEnvironmentOptions env_options; + env_options.device = env.device().id(); + env_options.context = env.context().context(); + env_options.command_queue = env.queue()->queue(); + env_options.serialized_binary_cache = + absl::MakeSpan(kernel_cache.data(), kernel_cache.size()); + RETURN_IF_ERROR(NewInferenceEnvironment(env_options, &inf_env, nullptr)); + + InferenceOptions options; + options.priority1 = InferencePriority::MIN_LATENCY; + options.priority2 = InferencePriority::MIN_MEMORY_USAGE; + options.priority3 = InferencePriority::MAX_PRECISION; + options.usage = InferenceUsage::SUSTAINED_SPEED; + + std::vector serialized_model; + RETURN_IF_ERROR(inf_env->BuildSerializedModel(options, std::move(graph_cl), + &serialized_model)); + std::unique_ptr builder; + RETURN_IF_ERROR(inf_env->NewInferenceBuilder(serialized_model, &builder)); + + // Sets input/output object def for builder_. + ObjectDef obj_def; + obj_def.data_type = DataType::FLOAT32; + obj_def.data_layout = DataLayout::BHWC; + obj_def.object_type = ObjectType::CPU_MEMORY; + obj_def.user_provided = true; + for (int i = 0; i < in_refs.size(); ++i) { + RETURN_IF_ERROR(builder->SetInputObjectDef(i, obj_def)); + } + for (int i = 0; i < out_refs.size(); ++i) { + RETURN_IF_ERROR(builder->SetOutputObjectDef(i, obj_def)); + } + + std::unique_ptr<::tflite::gpu::InferenceRunner> runner; + // Builds runner. + RETURN_IF_ERROR(builder->Build(&runner)); + + const auto end = std::chrono::high_resolution_clock::now(); + std::cout << "Initialization total time(with kernel cache) - " + << (end - start).count() * 1e-6f << "ms" << std::endl; + + // Sets the input/output object. + for (int i = 0; i < in_refs.size(); ++i) { + TfLiteTensor* tensor_ptr = cpu_inference->tensor(in_refs[i]); + RETURN_IF_ERROR(runner->SetInputObject( + i, CpuMemory{tensor_ptr->data.data, tensor_ptr->bytes})); + } + + std::vector> output_tensors(out_refs.size()); + for (int i = 0; i < out_refs.size(); ++i) { + TfLiteTensor* tensor_ptr = cpu_inference->tensor(out_refs[i]); + output_tensors[i].resize(tensor_ptr->bytes / 4); + RETURN_IF_ERROR(runner->SetOutputObject( + i, CpuMemory{output_tensors[i].data(), tensor_ptr->bytes})); + } + + RETURN_IF_ERROR(runner->Run()); + + CompareCPUGPUResults(cpu_inference.get(), out_refs, output_tensors, 1e-4f); + + RETURN_IF_ERROR(RunModelSampleWithInternalAPISerialized( + cpu_inference.get(), in_refs, out_refs, kernel_cache, serialized_model)); + + return absl::OkStatus(); +} + +absl::Status RunModelSampleWithInternalAPISerialized( + tflite::Interpreter* cpu, const std::vector& in_refs, + const std::vector& out_refs, + const std::vector& kernel_cache, + const std::vector& serialized_model) { + FillInputTensors(cpu); + auto status = cpu->Invoke(); + if (status != kTfLiteOk) { + return absl::InternalError("Failed to Invoke CPU inference."); + } + + const auto start = std::chrono::high_resolution_clock::now(); + + Environment env; + RETURN_IF_ERROR(CreateEnvironment(&env)); + + std::unique_ptr inf_env; + // Initializes environment. + InferenceEnvironmentOptions env_options; + env_options.device = env.device().id(); + env_options.context = env.context().context(); + env_options.command_queue = env.queue()->queue(); + env_options.serialized_binary_cache = + absl::MakeSpan(kernel_cache.data(), kernel_cache.size()); + RETURN_IF_ERROR(NewInferenceEnvironment(env_options, &inf_env, nullptr)); + + std::unique_ptr builder; + RETURN_IF_ERROR(inf_env->NewInferenceBuilder(serialized_model, &builder)); + + // Sets input/output object def for builder_. 
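+ // ObjectDef describes how the caller exchanges data with the runner: FP32
+ // values in BHWC layout held in plain CPU memory, with user_provided = true
+ // meaning the buffers are supplied later via SetInputObject/SetOutputObject.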
+ ObjectDef obj_def; + obj_def.data_type = DataType::FLOAT32; + obj_def.data_layout = DataLayout::BHWC; + obj_def.object_type = ObjectType::CPU_MEMORY; + obj_def.user_provided = true; + for (int i = 0; i < in_refs.size(); ++i) { + RETURN_IF_ERROR(builder->SetInputObjectDef(i, obj_def)); + } + for (int i = 0; i < out_refs.size(); ++i) { + RETURN_IF_ERROR(builder->SetOutputObjectDef(i, obj_def)); + } + + std::unique_ptr<::tflite::gpu::InferenceRunner> runner; + // Builds runner. + RETURN_IF_ERROR(builder->Build(&runner)); + + const auto end = std::chrono::high_resolution_clock::now(); + std::cout << "Serialized initialization total time - " + << (end - start).count() * 1e-6f << "ms" << std::endl; + + // Sets the input/output object. + for (int i = 0; i < in_refs.size(); ++i) { + TfLiteTensor* tensor_ptr = cpu->tensor(in_refs[i]); + RETURN_IF_ERROR(runner->SetInputObject( + i, CpuMemory{tensor_ptr->data.data, tensor_ptr->bytes})); + } + + std::vector> output_tensors(out_refs.size()); + for (int i = 0; i < out_refs.size(); ++i) { + TfLiteTensor* tensor_ptr = cpu->tensor(out_refs[i]); + output_tensors[i].resize(tensor_ptr->bytes / 4); + RETURN_IF_ERROR(runner->SetOutputObject( + i, CpuMemory{output_tensors[i].data(), tensor_ptr->bytes})); + } + + RETURN_IF_ERROR(runner->Run()); + + std::cout << "Comparing results second time:" << std::endl; + + CompareCPUGPUResults(cpu, out_refs, output_tensors, 1e-4f); + + return absl::OkStatus(); +} + +} // namespace cl +} // namespace gpu +} // namespace tflite + +int main(int argc, char** argv) { + if (argc <= 1) { + std::cerr << "Expected model path as second argument."; + return -1; + } + + auto load_status = tflite::gpu::cl::LoadOpenCL(); + if (!load_status.ok()) { + std::cerr << load_status.message(); + return -1; + } + + std::vector kernel_cache; + auto run_status = + tflite::gpu::cl::RunModelSampleWithInternalAPI(argv[1], &kernel_cache); + if (!run_status.ok()) { + std::cerr << run_status.message(); + return -1; + } + run_status = tflite::gpu::cl::RunModelSampleWithInternalAPISerializedKernels( + argv[1], kernel_cache); + if (!run_status.ok()) { + std::cerr << run_status.message(); + return -1; + } + + return EXIT_SUCCESS; +} diff --git a/tensorflow/lite/delegates/gpu/cl/testing/performance_profiling.cc b/tensorflow/lite/delegates/gpu/cl/testing/performance_profiling.cc index ab2e52f14ed..540004ad746 100644 --- a/tensorflow/lite/delegates/gpu/cl/testing/performance_profiling.cc +++ b/tensorflow/lite/delegates/gpu/cl/testing/performance_profiling.cc @@ -43,7 +43,7 @@ absl::Status RunModelSample(const std::string& model_name) { create_info.precision = env.IsSupported(CalculationsPrecision::F16) ? 
CalculationsPrecision::F16 : CalculationsPrecision::F32; - create_info.storage_type = GetFastestStorageType(env.device()); + create_info.storage_type = GetFastestStorageType(env.device().GetInfo()); create_info.hints.Add(ModelHints::kAllowSpecialKernels); std::cout << "Precision: " << ToString(create_info.precision) << std::endl; std::cout << "Storage type: " << ToString(create_info.storage_type) diff --git a/tensorflow/lite/delegates/gpu/cl/testing/run_delegate_testing.sh b/tensorflow/lite/delegates/gpu/cl/testing/run_delegate_testing.sh index 7b86407dbad..70d2a5cf3dc 100755 --- a/tensorflow/lite/delegates/gpu/cl/testing/run_delegate_testing.sh +++ b/tensorflow/lite/delegates/gpu/cl/testing/run_delegate_testing.sh @@ -78,11 +78,17 @@ ADB push "$model_path" "$OPENCL_DIR" declare -a BUILD_CONFIG abi_version=$(ADB shell getprop ro.product.cpu.abi | tr -d '\r') if [[ "$abi_version" == "armeabi-v7a" ]]; then -#"32 bit" +#"32 bit ARM" BUILD_CONFIG=( --config=android_arm -c opt --copt=-fPIE --linkopt=-pie ) -else -#"64 bit" +elif [[ "$abi_version" == "arm64-v8a" ]]; then +#"64 bit ARM" BUILD_CONFIG=( --config=android_arm64 -c opt ) +elif [[ "$abi_version" == "x86_64" ]]; then +# x86_64 +BUILD_CONFIG=( --config=android_x86_64 -c opt ) +else +echo "Error: Unknown processor ABI" +exit 1 fi bazel build "${BUILD_CONFIG[@]}" //$SHELL_DIR:$BINARY_NAME diff --git a/tensorflow/lite/delegates/gpu/cl/testing/run_internal_api_samples.sh b/tensorflow/lite/delegates/gpu/cl/testing/run_internal_api_samples.sh new file mode 100755 index 00000000000..21900c55875 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/cl/testing/run_internal_api_samples.sh @@ -0,0 +1,101 @@ +#!/bin/bash +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +shopt -s expand_aliases # to work with command aliases in .sh + +description="Example of internal API usage: +How to use: +[-h or --help, print instructions] +[-m or --model_path, path to the model in .tflite format] +[-d or --device, select device](optional, if you have several connected devices)" + +model_path="" +alias ADB='adb' +host="" + +while [[ "$1" != "" ]]; do + case $1 in + -m | --model_path) + shift + model_path=$1 + ;; + -d | --device) + shift + if [[ "$1" == "HOST" ]] + then + host="HOST" + fi + alias ADB='adb -s '$1'' + ;; + -h | --help) + echo "$description" + exit + ;; + esac + shift +done + +if [ "$model_path" = "" ] +then +echo "No model provided." 
+echo "$description" +exit +fi + +SHELL_DIR=$(dirname "$0") +BINARY_NAME=internal_api_samples + +if [[ "$host" == "HOST" ]] +then +bazel build -c opt --copt -DCL_DELEGATE_NO_GL //"$SHELL_DIR":"$BINARY_NAME" +chmod +x bazel-bin/"$SHELL_DIR"/"$BINARY_NAME" +./bazel-bin/"$SHELL_DIR"/"$BINARY_NAME" "$model_path" +exit +fi + +model_name=${model_path##*/} # finds last token after '/' + +OPENCL_DIR=/data/local/tmp/internal_api_samples/ + +ADB shell mkdir -p $OPENCL_DIR + +ADB push "$model_path" "$OPENCL_DIR" + +declare -a BUILD_CONFIG +abi_version=$(ADB shell getprop ro.product.cpu.abi | tr -d '\r') +if [[ "$abi_version" == "armeabi-v7a" ]]; then +#"32 bit ARM" +BUILD_CONFIG=( --config=android_arm -c opt --copt=-fPIE --linkopt=-pie ) +elif [[ "$abi_version" == "arm64-v8a" ]]; then +#"64 bit ARM" +BUILD_CONFIG=( --config=android_arm64 -c opt ) +elif [[ "$abi_version" == "x86_64" ]]; then +# x86_64 +BUILD_CONFIG=( --config=android_x86_64 -c opt ) +else +echo "Error: Unknown processor ABI" +exit 1 +fi + +bazel build "${BUILD_CONFIG[@]}" --copt -DCL_DELEGATE_NO_GL //$SHELL_DIR:$BINARY_NAME + +ADB push bazel-bin/$SHELL_DIR/$BINARY_NAME $OPENCL_DIR + +ADB shell chmod +x $OPENCL_DIR/$BINARY_NAME +ADB shell "cd $OPENCL_DIR && ./$BINARY_NAME $model_name" + +# clean up files from device +ADB shell rm -rf $OPENCL_DIR diff --git a/tensorflow/lite/delegates/gpu/cl/testing/run_performance_profiling.sh b/tensorflow/lite/delegates/gpu/cl/testing/run_performance_profiling.sh index 0fd2d33de14..56d1e1010ed 100755 --- a/tensorflow/lite/delegates/gpu/cl/testing/run_performance_profiling.sh +++ b/tensorflow/lite/delegates/gpu/cl/testing/run_performance_profiling.sh @@ -83,11 +83,17 @@ ADB push "$model_path" "$OPENCL_DIR" declare -a BUILD_CONFIG abi_version=$(ADB shell getprop ro.product.cpu.abi | tr -d '\r') if [[ "$abi_version" == "armeabi-v7a" ]]; then -#"32 bit" +#"32 bit ARM" BUILD_CONFIG=( --config=android_arm -c opt --copt=-fPIE --linkopt=-pie ) -else -#"64 bit" +elif [[ "$abi_version" == "arm64-v8a" ]]; then +#"64 bit ARM" BUILD_CONFIG=( --config=android_arm64 -c opt ) +elif [[ "$abi_version" == "x86_64" ]]; then +# x86_64 +BUILD_CONFIG=( --config=android_x86_64 -c opt ) +else +echo "Error: Unknown processor ABI" +exit 1 fi bazel build "${BUILD_CONFIG[@]}" //$SHELL_DIR:$BINARY_NAME diff --git a/tensorflow/lite/delegates/gpu/cl/texture2d.cc b/tensorflow/lite/delegates/gpu/cl/texture2d.cc index 28d26f03260..77cc7c9353c 100644 --- a/tensorflow/lite/delegates/gpu/cl/texture2d.cc +++ b/tensorflow/lite/delegates/gpu/cl/texture2d.cc @@ -24,10 +24,9 @@ namespace { absl::Status CreateTexture2D(int width, int height, DataType type, void* data, CLContext* context, Texture2D* result) { cl_mem texture; - RETURN_IF_ERROR(CreateFloatRGBAImage2D(context->context(), width, height, - type, data, &texture)); - cl_channel_type channel_type = - type == DataType::FLOAT32 ? 
CL_FLOAT : CL_HALF_FLOAT; + cl_channel_type channel_type = DataTypeToChannelType(type); + RETURN_IF_ERROR(CreateRGBAImage2D(context->context(), width, height, + channel_type, data, &texture)); *result = Texture2D(texture, width, height, channel_type); return absl::OkStatus(); @@ -37,6 +36,8 @@ absl::Status CreateTexture2D(int width, int height, DataType type, void* data, Texture2DDescriptor::Texture2DDescriptor(Texture2DDescriptor&& desc) : GPUObjectDescriptor(std::move(desc)), element_type(desc.element_type), + normalized(desc.normalized), + normalized_type(desc.normalized_type), size(desc.size), data(std::move(desc.data)) {} @@ -44,6 +45,8 @@ Texture2DDescriptor& Texture2DDescriptor::operator=( Texture2DDescriptor&& desc) { if (this != &desc) { std::swap(element_type, desc.element_type); + std::swap(normalized, desc.normalized); + std::swap(normalized_type, desc.normalized_type); std::swap(size, desc.size); data = std::move(desc.data); GPUObjectDescriptor::operator=(std::move(desc)); @@ -80,8 +83,38 @@ absl::Status Texture2DDescriptor::PerformReadSelector( absl::StrCat("Texture2DDescriptor Read require two arguments, but ", args.size(), " was passed")); } - const std::string read = - element_type == DataType::FLOAT16 ? "read_imageh" : "read_imagef"; + std::string read; + switch (element_type) { + case DataType::FLOAT32: + read = "read_imagef"; + break; + case DataType::FLOAT16: + read = "read_imageh"; + break; + case DataType::INT8: + case DataType::INT16: + case DataType::INT32: + if (normalized) { + read = normalized_type == DataType::FLOAT16 ? "read_imageh" + : "read_imagef"; + } else { + read = "read_imagei"; + } + break; + case DataType::UINT8: + case DataType::UINT16: + case DataType::UINT32: + if (normalized) { + read = normalized_type == DataType::FLOAT16 ? "read_imageh" + : "read_imagef"; + } else { + read = "read_imageui"; + } + break; + default: + read = "unknown_type"; + break; + } *result = absl::StrCat(read, "(tex2d, smp_none, (int2)(", args[0], ", " + args[1] + "))"); return absl::OkStatus(); @@ -145,13 +178,12 @@ absl::Status Texture2D::CreateFromTexture2DDescriptor( const Texture2DDescriptor& desc, CLContext* context) { width_ = desc.size.x; height_ = desc.size.y; - channel_type_ = - desc.element_type == DataType::FLOAT32 ? CL_FLOAT : CL_HALF_FLOAT; + channel_type_ = DataTypeToChannelType(desc.element_type, desc.normalized); uint8_t* data_ptr = desc.data.empty() ? nullptr : const_cast(desc.data.data()); - return CreateFloatRGBAImage2D(context->context(), desc.size.x, desc.size.y, - desc.element_type, data_ptr, &texture_); + return CreateRGBAImage2D(context->context(), desc.size.x, desc.size.y, + channel_type_, data_ptr, &texture_); } // Creates new 4-channel 2D texture with f32 elements diff --git a/tensorflow/lite/delegates/gpu/cl/texture2d.h b/tensorflow/lite/delegates/gpu/cl/texture2d.h index 51e0fc7e42c..15864305f21 100644 --- a/tensorflow/lite/delegates/gpu/cl/texture2d.h +++ b/tensorflow/lite/delegates/gpu/cl/texture2d.h @@ -32,7 +32,11 @@ namespace gpu { namespace cl { struct Texture2DDescriptor : public GPUObjectDescriptor { - DataType element_type; // FLOAT32 or FLOAT16 + DataType element_type; + bool normalized = false; // used with INT data types, if normalized, we read + // in kernel float data. 
+ DataType normalized_type; // can be FLOAT32 or FLOAT16, using with normalized + // = true // optional int2 size = int2(0, 0); diff --git a/tensorflow/lite/delegates/gpu/cl/util.cc b/tensorflow/lite/delegates/gpu/cl/util.cc index 199e0129968..d0e65537519 100644 --- a/tensorflow/lite/delegates/gpu/cl/util.cc +++ b/tensorflow/lite/delegates/gpu/cl/util.cc @@ -184,8 +184,32 @@ absl::Status CreateCLBuffer(cl_context context, int size_in_bytes, return absl::OkStatus(); } -absl::Status CreateFloatRGBAImage2D(cl_context context, int width, int height, - DataType type, void* data, cl_mem* result) { +cl_channel_type DataTypeToChannelType(DataType type, bool normalized) { + switch (type) { + case DataType::FLOAT32: + return CL_FLOAT; + case DataType::FLOAT16: + return CL_HALF_FLOAT; + case DataType::INT8: + return normalized ? CL_SNORM_INT8 : CL_SIGNED_INT8; + case DataType::UINT8: + return normalized ? CL_UNORM_INT8 : CL_UNSIGNED_INT8; + case DataType::INT16: + return normalized ? CL_SNORM_INT16 : CL_SIGNED_INT16; + case DataType::UINT16: + return normalized ? CL_UNORM_INT16 : CL_UNSIGNED_INT16; + case DataType::INT32: + return CL_SIGNED_INT32; + case DataType::UINT32: + return CL_UNSIGNED_INT32; + default: + return CL_FLOAT; + } +} + +absl::Status CreateRGBAImage2D(cl_context context, int width, int height, + cl_channel_type channel_type, void* data, + cl_mem* result) { cl_image_desc desc; desc.image_type = CL_MEM_OBJECT_IMAGE2D; desc.image_width = width; @@ -199,8 +223,7 @@ absl::Status CreateFloatRGBAImage2D(cl_context context, int width, int height, cl_image_format format; format.image_channel_order = CL_RGBA; - format.image_channel_data_type = - type == DataType::FLOAT32 ? CL_FLOAT : CL_HALF_FLOAT; + format.image_channel_data_type = channel_type; cl_mem_flags flags = CL_MEM_READ_WRITE; if (data) { diff --git a/tensorflow/lite/delegates/gpu/cl/util.h b/tensorflow/lite/delegates/gpu/cl/util.h index 8e22c017fe7..54a6c74a3ff 100644 --- a/tensorflow/lite/delegates/gpu/cl/util.h +++ b/tensorflow/lite/delegates/gpu/cl/util.h @@ -52,8 +52,10 @@ void CopyLinearFLT4(const tflite::gpu::Tensor& src, absl::Status CreateCLBuffer(cl_context context, int size_in_bytes, bool read_only, void* data, cl_mem* result); -absl::Status CreateFloatRGBAImage2D(cl_context context, int width, int height, - DataType type, void* data, cl_mem* result); +cl_channel_type DataTypeToChannelType(DataType type, bool normalized = false); +absl::Status CreateRGBAImage2D(cl_context context, int width, int height, + cl_channel_type channel_type, void* data, + cl_mem* result); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/common/BUILD b/tensorflow/lite/delegates/gpu/common/BUILD index 60a0fda422c..99d915f0ed2 100644 --- a/tensorflow/lite/delegates/gpu/common/BUILD +++ b/tensorflow/lite/delegates/gpu/common/BUILD @@ -10,6 +10,7 @@ cc_library( srcs = ["convert.cc"], hdrs = ["convert.h"], deps = [ + ":data_type", ":shape", ":status", ":tensor", @@ -22,7 +23,10 @@ cc_library( ) exports_files( - ["custom_parsers.h"], + [ + "custom_parsers.h", + "custom_transformations.h", + ], visibility = ["//tensorflow/lite/delegates/gpu/common:__subpackages__"], ) @@ -73,6 +77,7 @@ cc_library( ":types", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", ], ) @@ -81,11 +86,11 @@ cc_library( srcs = ["model.cc"], hdrs = ["model.h"], deps = [ - ":data_type", ":shape", ":status", ":tensor", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", 
"@com_google_absl//absl/strings", "@com_google_absl//absl/types:any", "@com_google_absl//absl/types:optional", @@ -97,16 +102,44 @@ cc_test( srcs = ["model_test.cc"], deps = [ ":model", + "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", ], ) +cc_library( + name = "lstm_parser", + srcs = ["lstm_parser.cc"], + hdrs = ["lstm_parser.h"], + deps = [ + ":data_type", + ":model", + ":model_builder_helper", + ":object_reader", + ":operations", + ":shape", + ":status", + ":tensor", + "//tensorflow/lite:string", + "//tensorflow/lite/c:common", + "//tensorflow/lite/kernels:lstm_shared", + "//tensorflow/lite/kernels/internal:quantization_util", + "//tensorflow/lite/kernels/internal:tensor", + "//tensorflow/lite/kernels/internal:tensor_utils", + "//tensorflow/lite/kernels/internal:types", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:any", + ], +) + cc_library( name = "model_builder", srcs = ["model_builder.cc"], hdrs = ["model_builder.h"], deps = [ ":data_type", + ":lstm_parser", ":model", ":model_builder_helper", ":model_transformer", @@ -115,14 +148,15 @@ cc_library( ":shape", ":status", ":tensor", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "//tensorflow/lite/delegates:utils", - "//tensorflow/lite:context", "//tensorflow/lite:kernel_api", "//tensorflow/lite:util", "//tensorflow/lite/c:common", - "//tensorflow/lite/delegates/gpu/common/transformations:general_transformations", + "//tensorflow/lite/delegates/gpu/common/transformations:model_transformations", "//tensorflow/lite/kernels:kernel_util", "//tensorflow/lite/kernels/internal:reference_base", "//tensorflow/lite/kernels/internal:tensor", @@ -133,10 +167,14 @@ cc_test( name = "model_builder_test", srcs = ["model_builder_test.cc"], deps = [ + ":data_type", ":model_builder", + ":shape", + ":tensor", "//tensorflow/lite:framework", "//tensorflow/lite:kernel_api", "//tensorflow/lite/c:common", + "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", ], ) @@ -152,10 +190,8 @@ cc_library( ":shape", ":status", ":tensor", - "//tensorflow/lite:context", "//tensorflow/lite:kernel_api", "//tensorflow/lite/c:common", - "//tensorflow/lite/delegates:utils", "//tensorflow/lite/kernels:kernel_util", "//tensorflow/lite/kernels/internal:reference_base", "//tensorflow/lite/kernels/internal:tensor", @@ -186,10 +222,12 @@ cc_library( ":model", ":model_builder_helper", ":status", + ":tensor", "//tensorflow/lite/c:common", "//tensorflow/lite/delegates:utils", "//tensorflow/lite/kernels:kernel_util", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", ], ) @@ -199,9 +237,9 @@ cc_library( hdrs = ["operations.h"], deps = [ ":data_type", - ":model", ":shape", ":status", + ":tensor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/types:variant", ], @@ -213,11 +251,12 @@ cc_library( hdrs = ["quantization_util.h"], deps = [ ":status", - "//tensorflow/lite:kernel_api", "//tensorflow/lite/c:common", "//tensorflow/lite/kernels/internal:optimized_base", + "//tensorflow/lite/kernels/internal:tensor", "//tensorflow/lite/kernels/internal:types", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", ], ) @@ -227,6 +266,9 @@ cc_test( deps = [ ":quantization_util", "//tensorflow/lite:util", + "//tensorflow/lite/c:common", + 
"@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", ], ) @@ -237,10 +279,7 @@ cc_library( name = "shape", srcs = ["shape.cc"], hdrs = ["shape.h"], - deps = [ - "@com_google_absl//absl/hash", - "@com_google_absl//absl/strings", - ], + deps = ["@com_google_absl//absl/strings"], ) cc_test( @@ -288,6 +327,9 @@ cc_test( srcs = ["memory_management_test.cc"], deps = [ ":memory_management", + ":shape", + ":types", + "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", ], ) @@ -334,9 +376,5 @@ cc_library( name = "workgroup_selection", srcs = ["workgroup_selection.cc"], hdrs = ["workgroup_selection.h"], - deps = [ - ":status", - ":types", - ":util", - ], + deps = [":util"], ) diff --git a/tensorflow/lite/delegates/gpu/common/convert.cc b/tensorflow/lite/delegates/gpu/common/convert.cc index fb0caf9f167..3920692bdca 100644 --- a/tensorflow/lite/delegates/gpu/common/convert.cc +++ b/tensorflow/lite/delegates/gpu/common/convert.cc @@ -15,9 +15,19 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/convert.h" +#include +#include + +#include +#include + #include #include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" #include "tensorflow/lite/delegates/gpu/common/types.h" #include "tensorflow/lite/delegates/gpu/common/util.h" diff --git a/tensorflow/lite/delegates/gpu/common/convert.h b/tensorflow/lite/delegates/gpu/common/convert.h index 3aba9c913c5..c7a6c17380a 100644 --- a/tensorflow/lite/delegates/gpu/common/convert.h +++ b/tensorflow/lite/delegates/gpu/common/convert.h @@ -16,9 +16,12 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CONVERT_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CONVERT_H_ +#include + #include #include "absl/types/span.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" diff --git a/tensorflow/lite/delegates/gpu/common/custom_parsers.h b/tensorflow/lite/delegates/gpu/common/custom_parsers.h index d70e5849315..2644864cb58 100644 --- a/tensorflow/lite/delegates/gpu/common/custom_parsers.h +++ b/tensorflow/lite/delegates/gpu/common/custom_parsers.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CUSTOM_PARSERS_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CUSTOM_PARSERS_H_ -#include +#include #include "absl/strings/string_view.h" #include "absl/types/any.h" diff --git a/tensorflow/lite/delegates/gpu/common/custom_transformations.h b/tensorflow/lite/delegates/gpu/common/custom_transformations.h new file mode 100644 index 00000000000..3ca73a0d245 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/custom_transformations.h @@ -0,0 +1,29 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CUSTOM_TRANSFORMATIONS_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CUSTOM_TRANSFORMATIONS_H_ + +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" + +namespace tflite { +namespace gpu { + +// Applies all implemented custom model transformations. +bool ApplyCustomTransformations(ModelTransformer* transformer); + +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CUSTOM_TRANSFORMATIONS_H_ diff --git a/tensorflow/lite/delegates/gpu/common/default/BUILD b/tensorflow/lite/delegates/gpu/common/default/BUILD index b085f68fcfb..91ce7e6c028 100644 --- a/tensorflow/lite/delegates/gpu/common/default/BUILD +++ b/tensorflow/lite/delegates/gpu/common/default/BUILD @@ -14,3 +14,12 @@ cc_library( "@com_google_absl//absl/types:any", ], ) + +cc_library( + name = "custom_transformations", + srcs = ["custom_transformations.cc"], + hdrs = ["//tensorflow/lite/delegates/gpu/common:custom_transformations.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:model_transformer", + ], +) diff --git a/tensorflow/lite/delegates/gpu/common/default/custom_parsers.cc b/tensorflow/lite/delegates/gpu/common/default/custom_parsers.cc index 5aa1303d55c..a4981a9d459 100644 --- a/tensorflow/lite/delegates/gpu/common/default/custom_parsers.cc +++ b/tensorflow/lite/delegates/gpu/common/default/custom_parsers.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/custom_parsers.h" +#include + #include #include "absl/strings/str_cat.h" diff --git a/tensorflow/lite/delegates/gpu/common/default/custom_transformations.cc b/tensorflow/lite/delegates/gpu/common/default/custom_transformations.cc new file mode 100644 index 00000000000..c57b9276068 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/default/custom_transformations.cc @@ -0,0 +1,26 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/common/custom_transformations.h" + +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" + +namespace tflite { +namespace gpu { + +bool ApplyCustomTransformations(ModelTransformer* transformer) { return true; } + +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/gpu_info.cc b/tensorflow/lite/delegates/gpu/common/gpu_info.cc index 14fb48a2d2d..b56745df971 100644 --- a/tensorflow/lite/delegates/gpu/common/gpu_info.cc +++ b/tensorflow/lite/delegates/gpu/common/gpu_info.cc @@ -15,8 +15,6 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/gpu_info.h" -#include -#include #include #include "absl/strings/ascii.h" diff --git a/tensorflow/lite/delegates/gpu/common/lstm_parser.cc b/tensorflow/lite/delegates/gpu/common/lstm_parser.cc new file mode 100644 index 00000000000..bd84559fd54 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/lstm_parser.cc @@ -0,0 +1,551 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/common/lstm_parser.h" + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/str_cat.h" +#include "absl/types/any.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h" +#include "tensorflow/lite/delegates/gpu/common/object_reader.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/lite/kernels/internal/types.h" +#include "tensorflow/lite/kernels/lstm_shared.h" +#include "tensorflow/lite/string_type.h" + +namespace tflite { +namespace gpu { +namespace { + +Value* CreateNewSimilarValue(GraphFloat32* graph, const Value* old_value) { + Value* new_value = graph->NewValue(); + new_value->quant_params = old_value->quant_params; + new_value->tensor.shape = old_value->tensor.shape; + new_value->tensor.type = old_value->tensor.type; + new_value->tensor.ref = -1; + return new_value; +} + +absl::Status SetFullyConnectedWeights(int weights_tensor_id, + ObjectReader* reader, + FullyConnectedAttributes* attr) { + Tensor weights; + RETURN_IF_ERROR(reader->ReadTensor(weights_tensor_id, &weights)); + attr->weights.data = std::move(weights.data); + attr->weights.id = weights.id; + 
attr->weights.shape.o = weights.shape.h; + attr->weights.shape.h = 1; + attr->weights.shape.w = 1; + attr->weights.shape.i = weights.shape.w; + return absl::OkStatus(); +} + +bool HasTensor(const TfLiteNode* node, const int index) { + return (index < node->inputs->size) && + (node->inputs->data[index] != kTfLiteOptionalTensor); +} + +bool HasCifg(const TfLiteNode* node) { + return !HasTensor( + node, tflite::ops::builtin::lstm::full::kInputToInputWeightsTensor); +} + +bool HasPeephole(const TfLiteNode* node) { + // Use forget weights to detect peephole instead of input weights as input + // weights may be missing for cifg. + return HasTensor( + node, tflite::ops::builtin::lstm::full::kCellToForgetWeightsTensor); +} + +bool HasNormalization(const TfLiteNode* node) { + return HasTensor( + node, + tflite::ops::builtin::lstm::full::kForgetLayerNormCoefficientsTensor); +} + +bool HasProjection(const TfLiteNode* node) { + return HasTensor(node, + tflite::ops::builtin::lstm::full::kProjectionWeightsTensor); +} + +// Builds subgraph for a single LSTM gate. +// Returns a Value representing the gate's output. +// High-level parameters: +// - Has normalization (if true: provide normalization weights). +// - Has peephole connection (if true: provide peephole weights). +// - Which activation function to use. +// Note: no support for aux input. +// +// Implements the following: +// (*: matrix multiply, .*: elementwise multiply, +: elementwise add): +// temp = input_weights * input_tensor + recurrent_weights * output_state; +// if (peephole): +// temp += peephole_weights .* cell_state; +// if (layer normalization): +// gate = activate(normalization_weights .* mean_stddev_norm(temp) + bias); +// else: +// gate = activate(temp + bias); +// +absl::Status BuildLstmGate(GraphFloat32* graph, ObjectReader* reader, + Value* output_state, Value* cell_state, + int input_weight_id, int recurrent_weight_id, + int cell_weight_id, int bias_id, + int normalization_weight_id, + const TfLiteFusedActivation activation, + bool has_peephole, bool has_normalization, + Value** gate_out) { + Value* input_times_weights = CreateNewSimilarValue(graph, cell_state); + { + // #1 matrix multiplication: input_weights * input_tensor + // If has no normalization, also adds bias. 
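// (With layer normalization the bias cannot be folded in here: it has to be
// applied after the mean/stddev normalization and re-scaling, which happens
// in node #8 below.)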
+ Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::FULLY_CONNECTED); + FullyConnectedAttributes fc_attr; + RETURN_IF_ERROR( + SetFullyConnectedWeights(input_weight_id, reader, &fc_attr)); + if (!has_normalization) { + RETURN_IF_ERROR(reader->ReadTensor(bias_id, &(fc_attr.bias))); + } + node->operation.attributes = std::move(fc_attr); + RETURN_IF_ERROR( + reader->AddInput(node, tflite::ops::builtin::lstm::full::kInputTensor)); + RETURN_IF_ERROR(graph->SetProducer(node->id, input_times_weights->id)); + } + + Value* output_state_times_weights = CreateNewSimilarValue(graph, cell_state); + { + // #2 matrix multiplication: recurrent_weights * output_state + Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::FULLY_CONNECTED); + FullyConnectedAttributes fc_attr; + RETURN_IF_ERROR( + SetFullyConnectedWeights(recurrent_weight_id, reader, &fc_attr)); + node->operation.attributes = std::move(fc_attr); + RETURN_IF_ERROR(graph->AddConsumer(node->id, output_state->id)); + RETURN_IF_ERROR( + graph->SetProducer(node->id, output_state_times_weights->id)); + } + + Value* cell_state_times_weights; + if (has_peephole) { + // #3 elementwise multiplication: cell_weight .* cell_state + cell_state_times_weights = CreateNewSimilarValue(graph, cell_state); + Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::MUL); + ElementwiseAttributes attr; + Tensor weights; + RETURN_IF_ERROR(reader->ReadTensor(cell_weight_id, &weights)); + attr.param = std::move(weights); + node->operation.attributes = std::move(attr); + RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_state->id)); + RETURN_IF_ERROR(graph->SetProducer(node->id, cell_state_times_weights->id)); + } + + Value* gate_before_normalization = CreateNewSimilarValue(graph, cell_state); + Node* add_node = graph->NewNode(); + { + // #4 elementwise addition: #1 + #2 + #3 + add_node->operation.type = ToString(OperationType::ADD); + RETURN_IF_ERROR(graph->AddConsumer(add_node->id, input_times_weights->id)); + RETURN_IF_ERROR( + graph->AddConsumer(add_node->id, output_state_times_weights->id)); + if (has_peephole) { + RETURN_IF_ERROR( + graph->AddConsumer(add_node->id, cell_state_times_weights->id)); + } + RETURN_IF_ERROR( + graph->SetProducer(add_node->id, gate_before_normalization->id)); + } + + if (!has_normalization) { + // #5 Activation function: activate(temp + bias) + // Bias is added in node #1. 
+ RETURN_IF_ERROR(MaybeFuseActivation(activation, graph, add_node)); + *gate_out = gate_before_normalization; + return absl::OkStatus(); + } + + Value* normalized_gate = + CreateNewSimilarValue(graph, gate_before_normalization); + { + // #6 Normalization: normalize(temp) + Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::MEAN_STDDEV_NORMALIZATION); + RETURN_IF_ERROR( + graph->AddConsumer(node->id, gate_before_normalization->id)); + RETURN_IF_ERROR(graph->SetProducer(node->id, normalized_gate->id)); + } + Value* reweighted_normalized_gate = + CreateNewSimilarValue(graph, normalized_gate); + { + // #7 Elementwise multiplication: norm_weights .* #6 + Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::MUL); + ElementwiseAttributes attr; + Tensor norm_weights; + RETURN_IF_ERROR(reader->ReadTensor(normalization_weight_id, &norm_weights)); + attr.param = std::move(norm_weights); + node->operation.attributes = std::move(attr); + RETURN_IF_ERROR(graph->AddConsumer(node->id, normalized_gate->id)); + RETURN_IF_ERROR( + graph->SetProducer(node->id, reweighted_normalized_gate->id)); + } + Value* gate = CreateNewSimilarValue(graph, reweighted_normalized_gate); + { + // #8 Elementwise add: #7 + bias + Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::ADD); + ElementwiseAttributes attr; + Tensor bias; + RETURN_IF_ERROR(reader->ReadTensor(bias_id, &bias)); + attr.param = std::move(bias); + node->operation.attributes = std::move(attr); + RETURN_IF_ERROR( + graph->AddConsumer(node->id, reweighted_normalized_gate->id)); + RETURN_IF_ERROR(graph->SetProducer(node->id, gate->id)); + + // #9: Activation function + RETURN_IF_ERROR(MaybeFuseActivation(activation, graph, node)); + } + *gate_out = gate; + return absl::OkStatus(); +} + +// Builds subgraph for LSTM cell state update. +// Returns a Value representing the updated cell state. +// High-level parameters: +// - clip: if > 0, clamp the resulting cell state to [-clip, +clip]. +// +// Implements the following: +// (*: matrix multiply, .*: elementwise multiply, +: elementwise add): +// +// cell_state_new = clip(forget_gate .* cell_state + input_gate .* cell_gate); +// +absl::Status BuildCellStateUpdate(GraphFloat32* graph, ObjectReader* reader, + Value* forget_gate, Value* input_gate, + Value* cell_gate, float cell_clip, + Value** cell_state_new) { + Value* cell_state; + RETURN_IF_ERROR(reader->ReadValue( + tflite::ops::builtin::lstm::full::kCellStateTensor, &cell_state)); + Value* cell_state_contrib = CreateNewSimilarValue(graph, cell_gate); + { + // #1 elementwise multiplication: forget_gate .* cell_state + Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::MUL); + RETURN_IF_ERROR(graph->AddConsumer(node->id, forget_gate->id)); + RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_state->id)); + RETURN_IF_ERROR(graph->SetProducer(node->id, cell_state_contrib->id)); + } + Value* cell_gate_contrib = CreateNewSimilarValue(graph, cell_gate); + { + // #2 elementwise multiplication: input_gate .* cell_gate + // Note, with CIFG input_gate is equal to 1-forget_gate. 
+ Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::MUL); + RETURN_IF_ERROR(graph->AddConsumer(node->id, input_gate->id)); + RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_gate->id)); + RETURN_IF_ERROR(graph->SetProducer(node->id, cell_gate_contrib->id)); + } + Value* new_cell_state = CreateNewSimilarValue(graph, cell_gate); + { + // #3 elementwise add: #1 + #2 + Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::ADD); + RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_state_contrib->id)); + RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_gate_contrib->id)); + RETURN_IF_ERROR(graph->SetProducer(node->id, new_cell_state->id)); + } + + if (cell_clip <= 0.0f) { + *cell_state_new = new_cell_state; + return absl::OkStatus(); + } + + Value* max_clipped_state = CreateNewSimilarValue(graph, new_cell_state); + { + // #4 elementwise minimum: min(#3, clip) + Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::MINIMUM); + ElementwiseAttributes attr; + attr.param = cell_clip; + node->operation.attributes = std::move(attr); + RETURN_IF_ERROR(graph->AddConsumer(node->id, new_cell_state->id)); + RETURN_IF_ERROR(graph->SetProducer(node->id, max_clipped_state->id)); + } + Value* clipped_cell_state = CreateNewSimilarValue(graph, max_clipped_state); + { + // #5 elementwise maximum: max(#4, -clip) + Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::MAXIMUM); + ElementwiseAttributes attr; + attr.param = -cell_clip; + node->operation.attributes = std::move(attr); + RETURN_IF_ERROR(graph->AddConsumer(node->id, max_clipped_state->id)); + RETURN_IF_ERROR(graph->SetProducer(node->id, clipped_cell_state->id)); + } + *cell_state_new = clipped_cell_state; + return absl::OkStatus(); +} + +// Build subgraph for LSTM output state update. +// Returns value representing the updated output state. +// High-level parameters: +// - Has projection (if true, provide projection_weights). +// - Has projection bias (only with projection). +// - clip: clamp the projection output to [-clip, clip]. +// - Which activation function to use. +// Note the updated output state does not depend on the old output state +// directly, only through the output gate. 
+// +// Implements the following: +// (*: matrix multiply, .*: elementwise multiply, +: elementwise add): +// +// temp = output_gate .* activate(cell_state); +// if (projection): +// output_state_new = clip(projection_weights * temp + projection_bias); +// else: +// output_state_new = temp; +// +absl::Status BuildOutputStateUpdate(GraphFloat32* graph, ObjectReader* reader, + Value* output_state, Value* output_gate, + Value* cell_state, + TfLiteFusedActivation activation, + bool has_projection, float proj_clip, + Value** output_state_new) { + Value* activated_state = CreateNewSimilarValue(graph, cell_state); + { + // #1 activation: activate(cell_state) + Node* node = graph->NewNode(); + switch (activation) { + case kTfLiteActTanh: + node->operation.type = ToString(OperationType::TANH); + break; + case kTfLiteActSigmoid: + node->operation.type = ToString(OperationType::SIGMOID); + break; + default: + return absl::InvalidArgumentError( + absl::StrCat("Unsupported activation: ", activation)); + } + RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_state->id)); + RETURN_IF_ERROR(graph->SetProducer(node->id, activated_state->id)); + } + + Value* new_output_state = CreateNewSimilarValue(graph, cell_state); + { + // #2 elementwise multiplication: output_gate .* #1 + Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::MUL); + RETURN_IF_ERROR(graph->AddConsumer(node->id, activated_state->id)); + RETURN_IF_ERROR(graph->AddConsumer(node->id, output_gate->id)); + RETURN_IF_ERROR(graph->SetProducer(node->id, new_output_state->id)); + } + + if (!has_projection) { + *output_state_new = new_output_state; + return absl::OkStatus(); + } + + Value* projected_output_state = CreateNewSimilarValue(graph, output_state); + { + // #3 matrix multiplication: projection_weights * #2 + projection_bias + Node* node = graph->NewNode(); + FullyConnectedAttributes fc_attr; + RETURN_IF_ERROR(SetFullyConnectedWeights( + tflite::ops::builtin::lstm::full::kProjectionWeightsTensor, reader, + &fc_attr)); + // Projection bias is optional + reader + ->ReadTensor(tflite::ops::builtin::lstm::full::kProjectionBiasTensor, + &(fc_attr.bias)) + .IgnoreError(); + node->operation.attributes = std::move(fc_attr); + node->operation.type = ToString(OperationType::FULLY_CONNECTED); + RETURN_IF_ERROR(graph->AddConsumer(node->id, new_output_state->id)); + RETURN_IF_ERROR(graph->SetProducer(node->id, projected_output_state->id)); + } + + if (proj_clip <= 0.0f) { + *output_state_new = projected_output_state; + return absl::OkStatus(); + } + + Value* max_clipped_state = + CreateNewSimilarValue(graph, projected_output_state); + { + // #4 elementwise minimum: min(#3, clip) + Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::MINIMUM); + ElementwiseAttributes attr; + attr.param = proj_clip; + node->operation.attributes = std::move(attr); + RETURN_IF_ERROR(graph->AddConsumer(node->id, projected_output_state->id)); + RETURN_IF_ERROR(graph->SetProducer(node->id, max_clipped_state->id)); + } + Value* clipped_output_state = CreateNewSimilarValue(graph, max_clipped_state); + { + // #5 elementwise maximum: max(#4, -clip) + Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::MAXIMUM); + ElementwiseAttributes attr; + attr.param = -proj_clip; + node->operation.attributes = std::move(attr); + RETURN_IF_ERROR(graph->AddConsumer(node->id, max_clipped_state->id)); + RETURN_IF_ERROR(graph->SetProducer(node->id, clipped_output_state->id)); + } + *output_state_new = 
clipped_output_state; + return absl::OkStatus(); +} + +} // namespace + +// Build subgraph for a single LSTM OP. +// Returns a mapping for the used variable tensors' updated Values. +// +// High-level parameters: +// - Has CIFG: +// If false, calculate input_gate regularly. +// If true, calculate input_gate to 1-forget_gate. +// - Has peephole: see BuildLstmGate. Applies to all gates. +// - Has normalization: see BuildLstmGate. Applies to all gates. +// - Has projection, projection_bias, proj_clip: see BuildOutputStateUpdate +// - Which activation to use: +// Applies to only cell gate and output state update. +// Other gates always use Sigmoid. +// +absl::Status ParseLSTMAttributes( + const TfLiteNode* tflite_node, const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader, const TfLiteLSTMParams* params, + absl::flat_hash_map* new_variable_input_values) { + const bool has_cifg = HasCifg(tflite_node); + const bool has_peephole = HasPeephole(tflite_node); + const bool has_normalization = HasNormalization(tflite_node); + const bool has_projection = HasProjection(tflite_node); + + Value* old_cell_state; + RETURN_IF_ERROR(reader->ReadValue( + tflite::ops::builtin::lstm::full::kCellStateTensor, &old_cell_state)); + + if (old_cell_state->tensor.shape.b != 1) { + return absl::InvalidArgumentError( + "Batched execution is not supported for LSTM"); + } + + Value* old_output_state; + RETURN_IF_ERROR(reader->ReadValue( + tflite::ops::builtin::lstm::full::kOutputStateTensor, &old_output_state)); + + Value* forget_gate; + RETURN_IF_ERROR(BuildLstmGate( + graph, reader, old_output_state, old_cell_state, + tflite::ops::builtin::lstm::full::kInputToForgetWeightsTensor, + tflite::ops::builtin::lstm::full::kRecurrentToForgetWeightsTensor, + tflite::ops::builtin::lstm::full::kCellToForgetWeightsTensor, + tflite::ops::builtin::lstm::full::kForgetGateBiasTensor, + tflite::ops::builtin::lstm::full::kForgetLayerNormCoefficientsTensor, + kTfLiteActSigmoid, has_peephole, has_normalization, &forget_gate)); + + Value* input_gate; + if (has_cifg) { + // When using cifg, input_gate is computed as (1 - forget_gate). 
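// The SUB node below uses the scalar constant 1.0f as the first operand;
// runtime_tensor_is_second = true marks the runtime tensor (forget_gate) as
// the second operand, so the node computes 1.0f - forget_gate rather than
// forget_gate - 1.0f.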
+ Node* node = graph->NewNode(); + input_gate = CreateNewSimilarValue(graph, forget_gate); + + node->operation.type = ToString(OperationType::SUB); + ElementwiseAttributes attr; + attr.param = 1.0f; + attr.runtime_tensor_is_second = true; + node->operation.attributes = std::move(attr); + RETURN_IF_ERROR(graph->AddConsumer(node->id, forget_gate->id)); + RETURN_IF_ERROR(graph->SetProducer(node->id, input_gate->id)); + } else { + RETURN_IF_ERROR(BuildLstmGate( + graph, reader, old_output_state, old_cell_state, + tflite::ops::builtin::lstm::full::kInputToInputWeightsTensor, + tflite::ops::builtin::lstm::full::kRecurrentToInputWeightsTensor, + tflite::ops::builtin::lstm::full::kCellToInputWeightsTensor, + tflite::ops::builtin::lstm::full::kInputGateBiasTensor, + tflite::ops::builtin::lstm::full::kInputLayerNormCoefficientsTensor, + kTfLiteActSigmoid, has_peephole, has_normalization, &input_gate)); + } + + // Cell state will not have peephole connections to itself + Value* cell_gate; + RETURN_IF_ERROR(BuildLstmGate( + graph, reader, old_output_state, old_cell_state, + tflite::ops::builtin::lstm::full::kInputToCellWeightsTensor, + tflite::ops::builtin::lstm::full::kRecurrentToCellWeightsTensor, + /*cell_weight_id=*/-1, + tflite::ops::builtin::lstm::full::kCellGateBiasTensor, + tflite::ops::builtin::lstm::full::kCellLayerNormCoefficientsTensor, + params->activation, /*has_peephole=*/false, has_normalization, + &cell_gate)); + + Value* new_cell_state; + RETURN_IF_ERROR(BuildCellStateUpdate(graph, reader, forget_gate, input_gate, + cell_gate, params->cell_clip, + &new_cell_state)); + + Value* output_gate; + RETURN_IF_ERROR(BuildLstmGate( + graph, reader, old_output_state, new_cell_state, + tflite::ops::builtin::lstm::full::kInputToOutputWeightsTensor, + tflite::ops::builtin::lstm::full::kRecurrentToOutputWeightsTensor, + tflite::ops::builtin::lstm::full::kCellToOutputWeightsTensor, + tflite::ops::builtin::lstm::full::kOutputGateBiasTensor, + tflite::ops::builtin::lstm::full::kOutputLayerNormCoefficientsTensor, + kTfLiteActSigmoid, has_peephole, has_normalization, &output_gate)); + + Value* new_output_state; + RETURN_IF_ERROR(BuildOutputStateUpdate(graph, reader, old_output_state, + output_gate, new_cell_state, + params->activation, has_projection, + params->proj_clip, &new_output_state)); + + { + // Copy updated output state to output. + Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::COPY); + RETURN_IF_ERROR(graph->AddConsumer(node->id, new_output_state->id)); + RETURN_IF_ERROR(reader->AddOutput( + node, tflite::ops::builtin::lstm::full::kOutputTensor)); + } + + new_variable_input_values->clear(); + new_variable_input_values->emplace( + tflite::ops::builtin::lstm::full::kCellStateTensor, new_cell_state->id); + new_variable_input_values->emplace( + tflite::ops::builtin::lstm::full::kOutputStateTensor, + new_output_state->id); + return absl::OkStatus(); +} + +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/lstm_parser.h b/tensorflow/lite/delegates/gpu/common/lstm_parser.h new file mode 100644 index 00000000000..b7c32371abc --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/lstm_parser.h @@ -0,0 +1,34 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_LSTM_PARSER_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_LSTM_PARSER_H_ + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/object_reader.h" + +namespace tflite { +namespace gpu { + +absl::Status ParseLSTMAttributes( + const TfLiteNode* tflite_node, const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader, const TfLiteLSTMParams* params, + absl::flat_hash_map* new_variable_input_values); +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_LSTM_PARSER_H_ diff --git a/tensorflow/lite/delegates/gpu/common/memory_management.cc b/tensorflow/lite/delegates/gpu/common/memory_management.cc index d7e6a060eb2..2a637d54016 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management.cc +++ b/tensorflow/lite/delegates/gpu/common/memory_management.cc @@ -15,17 +15,21 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/memory_management.h" -#include -#include +#include #include -#include -#include -#include +#include #include +#include "tensorflow/lite/delegates/gpu/common/memory_management/equality_assignment.h" #include "tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.h" #include "tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.h" +#include "tensorflow/lite/delegates/gpu/common/memory_management/greedy_in_order_assignment.h" +#include "tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.h" +#include "tensorflow/lite/delegates/gpu/common/memory_management/naive_assignment.h" +#include "tensorflow/lite/delegates/gpu/common/memory_management/types.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/common/memory_management.h b/tensorflow/lite/delegates/gpu/common/memory_management.h index 7df4947ee3d..9f1adcebd7f 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management.h +++ b/tensorflow/lite/delegates/gpu/common/memory_management.h @@ -16,16 +16,12 @@ limitations under the License. 
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_H_ -#include -#include +#include + #include #include "absl/memory/memory.h" #include "tensorflow/lite/delegates/gpu/common/memory_management/equality_assignment.h" -#include "tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.h" -#include "tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.h" -#include "tensorflow/lite/delegates/gpu/common/memory_management/greedy_in_order_assignment.h" -#include "tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.h" #include "tensorflow/lite/delegates/gpu/common/memory_management/naive_assignment.h" #include "tensorflow/lite/delegates/gpu/common/memory_management/types.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/equality_assignment.h b/tensorflow/lite/delegates/gpu/common/memory_management/equality_assignment.h index fdccce5159f..018e5a95b51 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/equality_assignment.h +++ b/tensorflow/lite/delegates/gpu/common/memory_management/equality_assignment.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_EQUALITY_ASSIGNMENT_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_EQUALITY_ASSIGNMENT_H_ +#include + +#include #include #include diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.cc b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.cc index 2c138b4c14c..b07ab61a1a5 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.cc +++ b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.cc @@ -16,12 +16,13 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.h" #include -#include -#include +#include #include #include +#include "absl/status/status.h" #include "tensorflow/lite/delegates/gpu/common/memory_management/internal.h" +#include "tensorflow/lite/delegates/gpu/common/memory_management/types.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.h b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.h index 47035229920..e207ab323b5 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.h +++ b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.h @@ -16,7 +16,8 @@ limitations under the License. 
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_GREEDY_BY_BREADTH_ASSIGNMENT_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_GREEDY_BY_BREADTH_ASSIGNMENT_H_ -#include +#include + #include #include "tensorflow/lite/delegates/gpu/common/memory_management/types.h" diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc index 76309ce8f1b..130f27152cd 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc +++ b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc @@ -16,8 +16,13 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.h" #include +#include +#include +#include +#include "absl/status/status.h" #include "tensorflow/lite/delegates/gpu/common/memory_management/internal.h" +#include "tensorflow/lite/delegates/gpu/common/memory_management/types.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.h b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.h index b0ad9d18911..198a25c7a57 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.h +++ b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_GREEDY_BY_SIZE_ASSIGNMENT_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_GREEDY_BY_SIZE_ASSIGNMENT_H_ +#include + #include #include "tensorflow/lite/delegates/gpu/common/memory_management/types.h" diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_in_order_assignment.h b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_in_order_assignment.h index 8c3719e4a8b..048ed389700 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_in_order_assignment.h +++ b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_in_order_assignment.h @@ -16,7 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_GREEDY_IN_ORDER_ASSIGNMENT_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_GREEDY_IN_ORDER_ASSIGNMENT_H_ +#include + #include +#include +#include #include #include #include diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/internal.cc b/tensorflow/lite/delegates/gpu/common/memory_management/internal.cc index bbcd373287f..27126aa929f 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/internal.cc +++ b/tensorflow/lite/delegates/gpu/common/memory_management/internal.cc @@ -16,6 +16,11 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/common/memory_management/internal.h" #include +#include +#include + +#include "tensorflow/lite/delegates/gpu/common/memory_management/types.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/internal.h b/tensorflow/lite/delegates/gpu/common/memory_management/internal.h index 702fd2992cc..4d48f75da9f 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/internal.h +++ b/tensorflow/lite/delegates/gpu/common/memory_management/internal.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_INTERNAL_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_INTERNAL_H_ -#include +#include + #include -#include #include #include "absl/memory/memory.h" diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/internal_test.cc b/tensorflow/lite/delegates/gpu/common/memory_management/internal_test.cc index 757cb89b366..ed83e3c5109 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/internal_test.cc +++ b/tensorflow/lite/delegates/gpu/common/memory_management/internal_test.cc @@ -15,8 +15,12 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/memory_management/internal.h" +#include +#include + #include #include +#include "tensorflow/lite/delegates/gpu/common/memory_management/types.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.cc b/tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.cc index 059c23fab33..c56ac2e391b 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.cc +++ b/tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.cc @@ -16,11 +16,15 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.h" #include +#include +#include #include -#include +#include #include +#include "absl/status/status.h" #include "tensorflow/lite/delegates/gpu/common/memory_management/internal.h" +#include "tensorflow/lite/delegates/gpu/common/memory_management/types.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.h b/tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.h index 1284c12c5c2..df734ad9ea4 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.h +++ b/tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_MIN_COST_FLOW_ASSIGNMENT_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_MIN_COST_FLOW_ASSIGNMENT_H_ +#include + #include #include "tensorflow/lite/delegates/gpu/common/memory_management/types.h" diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/naive_assignment.h b/tensorflow/lite/delegates/gpu/common/memory_management/naive_assignment.h index 8a00c67d853..d700f62006c 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/naive_assignment.h +++ b/tensorflow/lite/delegates/gpu/common/memory_management/naive_assignment.h @@ -16,6 +16,8 @@ limitations under the License. 
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_NAIVE_ASSIGNMENT_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_NAIVE_ASSIGNMENT_H_ +#include + #include #include "tensorflow/lite/delegates/gpu/common/memory_management/internal.h" diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/types.cc b/tensorflow/lite/delegates/gpu/common/memory_management/types.cc index 5cec0cab4c4..101ca5316f1 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/types.cc +++ b/tensorflow/lite/delegates/gpu/common/memory_management/types.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/memory_management/types.h" #include -#include +#include #include #include diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/types.h b/tensorflow/lite/delegates/gpu/common/memory_management/types.h index a511152ed0b..f3257fcf5f8 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/types.h +++ b/tensorflow/lite/delegates/gpu/common/memory_management/types.h @@ -16,8 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_TYPES_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_TYPES_H_ -#include -#include +#include + +#include #include namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/types_test.cc b/tensorflow/lite/delegates/gpu/common/memory_management/types_test.cc index 0312dc27877..22558ec8b94 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/types_test.cc +++ b/tensorflow/lite/delegates/gpu/common/memory_management/types_test.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/memory_management/types.h" +#include + #include #include diff --git a/tensorflow/lite/delegates/gpu/common/memory_management_test.cc b/tensorflow/lite/delegates/gpu/common/memory_management_test.cc index 12f5b6ebe6c..ba951354d17 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management_test.cc +++ b/tensorflow/lite/delegates/gpu/common/memory_management_test.cc @@ -15,8 +15,15 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/memory_management.h" +#include +#include + #include #include +#include "absl/status/status.h" +#include "tensorflow/lite/delegates/gpu/common/memory_management/types.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/common/model.cc b/tensorflow/lite/delegates/gpu/common/model.cc index a2f9da428ba..696a747a817 100644 --- a/tensorflow/lite/delegates/gpu/common/model.cc +++ b/tensorflow/lite/delegates/gpu/common/model.cc @@ -15,7 +15,21 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/common/model.h" +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" namespace tflite { namespace gpu { @@ -32,6 +46,11 @@ std::vector GraphFloat32::inputs() const { return FilterValues([](const ValueDef& v) { return v.producer == nullptr; }); } +std::vector GraphFloat32::variable_inputs() const { + return FilterValues( + [](const ValueDef& v) { return v.value->tensor.is_variable_input; }); +} + std::vector GraphFloat32::outputs() const { return FilterValues([](const ValueDef& v) { return v.consumers.empty(); }); } @@ -397,19 +416,19 @@ absl::Status RemoveFollowingNode(GraphFloat32* graph, const Node* to_remove, return graph->DeleteNode(to_remove->id); } -absl::Status RemoveOneInputOneOutputNode(GraphFloat32* graph, - const Node* to_remove) { - auto inputs = graph->FindInputs(to_remove->id); - auto outputs = graph->FindOutputs(to_remove->id); +absl::Status RemoveSimpleNodeKeepInput(GraphFloat32* graph, + const Node* simple_node) { + const auto inputs = graph->FindInputs(simple_node->id); + const auto outputs = graph->FindOutputs(simple_node->id); if (inputs.size() != 1 || outputs.size() != 1) { - return absl::InvalidArgumentError( - "To_remove node must have 1 input and 1 output"); + return absl::FailedPreconditionError( + "simple_node node must have 1 input and 1 output"); } - auto input_id = inputs[0]->id; - auto output_id = outputs[0]->id; - Node* producer = graph->FindProducer(input_id); - auto consumers = graph->FindConsumers(output_id); - RETURN_IF_ERROR(graph->DeleteNode(to_remove->id)); + const auto input_id = inputs[0]->id; + const auto output_id = outputs[0]->id; + const Node* producer = graph->FindProducer(input_id); + const auto consumers = graph->FindConsumers(output_id); + RETURN_IF_ERROR(graph->DeleteNode(simple_node->id)); for (auto& consumer : consumers) { RETURN_IF_ERROR(graph->ReplaceInput(consumer->id, output_id, input_id)); } @@ -420,6 +439,38 @@ absl::Status RemoveOneInputOneOutputNode(GraphFloat32* graph, return absl::OkStatus(); } +absl::Status RemoveSimpleNodeKeepOutput(GraphFloat32* graph, + const Node* simple_node) { + const auto inputs = graph->FindInputs(simple_node->id); + const auto outputs = graph->FindOutputs(simple_node->id); + if (inputs.size() != 1 || outputs.size() != 1) { + return absl::FailedPreconditionError( + "simple_node must have 1 input and 1 output"); + } + const auto input_id = inputs[0]->id; + const auto output_id = outputs[0]->id; + const Node* producer = graph->FindProducer(input_id); + const auto input_consumers = graph->FindConsumers(input_id); + if (input_consumers.size() != 1) { + return absl::FailedPreconditionError( + "simple_node should be the only consumer on the node."); + } + + RETURN_IF_ERROR(graph->DeleteNode(simple_node->id)); + if (producer) { + RETURN_IF_ERROR(graph->RemoveProducer(input_id)); + RETURN_IF_ERROR(graph->SetProducer(producer->id, output_id)); + } + + RETURN_IF_ERROR(graph->DeleteValue(input_id)); + + const auto output_consumers = graph->FindConsumers(output_id); + if (!producer && output_consumers.empty()) { + RETURN_IF_ERROR(graph->DeleteValue(output_id)); + } + return absl::OkStatus(); +} + absl::Status AddOutput(GraphFloat32* graph, const Node* from_node, Value** output) { auto link = graph->NewValue(); 
@@ -430,14 +481,27 @@ absl::Status AddOutput(GraphFloat32* graph, const Node* from_node, absl::Status ConnectTwoNodes(GraphFloat32* graph, const Node* from_node, const Node* to_node, Value** output) { - Value* link; - RETURN_IF_ERROR(AddOutput(graph, from_node, &link)); - RETURN_IF_ERROR(graph->AddConsumer(to_node->id, link->id)); - *output = link; + const Node* output_producer = + *output ? graph->FindProducer((*output)->id) : nullptr; + // Output is already initialized, but producer is not from_node. + if (*output && output_producer && output_producer->id != from_node->id) { + return absl::InvalidArgumentError("Wrong output is passed."); + } + // Output is already initialized, and producer is from_node. + if (*output) { + RETURN_IF_ERROR(graph->AddConsumer(to_node->id, (*output)->id)); + } else { + // Output is not initialized. + Value* link; + RETURN_IF_ERROR(AddOutput(graph, from_node, &link)); + RETURN_IF_ERROR(graph->AddConsumer(to_node->id, link->id)); + *output = link; + } return absl::OkStatus(); } bool IsBatchMatchesForAllValues(const GraphFloat32& model) { + if (model.values().empty()) return true; const int32_t b = model.values()[0]->tensor.shape.b; for (auto value : model.values()) { if (value->tensor.shape.b != b) { diff --git a/tensorflow/lite/delegates/gpu/common/model.h b/tensorflow/lite/delegates/gpu/common/model.h index f6d160977f9..2e9aac8c53c 100644 --- a/tensorflow/lite/delegates/gpu/common/model.h +++ b/tensorflow/lite/delegates/gpu/common/model.h @@ -24,10 +24,8 @@ limitations under the License. #include #include "absl/memory/memory.h" -#include "absl/strings/str_cat.h" #include "absl/types/any.h" #include "absl/types/optional.h" -#include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" @@ -92,6 +90,9 @@ class GraphFloat32 { // @return graph outputs, that are values without consumers. std::vector outputs() const; + // @return values updated in place with a previously defined tensor reference. + std::vector variable_inputs() const; + // @return inputs into the given node. Returns empty vector for deleted node. std::vector FindInputs(NodeId id) const; @@ -235,18 +236,28 @@ absl::Status RemovePrecedingNode(GraphFloat32* graph, const Node* to_remove, absl::Status RemoveFollowingNode(GraphFloat32* graph, const Node* to_remove, const Node* to_keep); -// Removes to_remove node. -// Requires that node has one input and one output; -absl::Status RemoveOneInputOneOutputNode(GraphFloat32* graph, - const Node* to_remove); +// Removes simple_node and its output value from the graph. Node is considered +// simple if it has only one input and one output value. Input value is kept. +absl::Status RemoveSimpleNodeKeepInput(GraphFloat32* graph, + const Node* simple_node); + +// Removes simple_node and its input value from the graph. Node is considered +// simple if it has only one input and one output value. Output value is kept. +// simple_node should be an exclusive consumer of its input value. +absl::Status RemoveSimpleNodeKeepOutput(GraphFloat32* graph, + const Node* simple_node); absl::Status AddOutput(GraphFloat32* graph, const Node* from_node, Value** output); +// Makes a direct connection between from_node and to_node. All input parameters +// except output are expected to be initialized before passing to the function. 
+// If from_node already has an output value, which is not yet consumed by +// to_node, it may be passed as output parameter. absl::Status ConnectTwoNodes(GraphFloat32* graph, const Node* from_node, const Node* to_node, Value** output); -// @return true if all tensors have same batch value. +// @return true if all tensors have same batch value or if model has no values. bool IsBatchMatchesForAllValues(const GraphFloat32& model); } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc index efa75a244bf..c200f0926aa 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc @@ -27,7 +27,9 @@ limitations under the License. #include #include +#include "absl/base/attributes.h" #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" @@ -36,6 +38,7 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/delegates/gpu/common/custom_parsers.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/lstm_parser.h" #include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h" #include "tensorflow/lite/delegates/gpu/common/model_transformer.h" @@ -44,7 +47,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" -#include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h" +#include "tensorflow/lite/delegates/gpu/common/transformations/model_transformations.h" #include "tensorflow/lite/delegates/utils.h" #include "tensorflow/lite/kernels/internal/reference/dequantize.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" @@ -81,6 +84,13 @@ class TFLiteOperationParser { virtual absl::Status IsSupported(const TfLiteContext* context, const TfLiteNode* tflite_node, const TfLiteRegistration* registration) = 0; + + // Return the value ids in the graph that correspond to the updated values of + // the variable input tensor. + virtual absl::flat_hash_map + GetNewValueIdsForVariableInputNodes() { + return absl::flat_hash_map(); + } }; HW ToHW(int32_t h, int32_t w) { return HW(h > 0 ? h : 1, w > 0 ? 
w : 1); } @@ -298,6 +308,27 @@ class AddOperationParser : public TFLiteOperationParser { } }; +class BatchedMatMulOperationParser : public TFLiteOperationParser { + public: + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + return CheckInputsOutputs(context, tflite_node, + /*runtime_inputs=*/2, /*outputs=*/1); + } + + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { + Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::BATCHED_MATMUL); + RETURN_IF_ERROR(reader->AddInput(node, 0)); + RETURN_IF_ERROR(reader->AddInput(node, 1)); + RETURN_IF_ERROR(reader->AddOutputs(node)); + return absl::OkStatus(); + } +}; + class ConcatenationOperationParser : public TFLiteOperationParser { public: absl::Status IsSupported(const TfLiteContext* context, @@ -420,7 +451,7 @@ class Conv2DOperationParser : public TFLiteOperationParser { absl::Status IsSupported(const TfLiteContext* context, const TfLiteNode* tflite_node, const TfLiteRegistration* registration) final { - RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3)); + RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 5)); const int runtime_inputs = GetNumberOfRuntimeInputsForNode(context, tflite_node); if (runtime_inputs > 2) { @@ -475,55 +506,28 @@ class Conv2DOperationParser : public TFLiteOperationParser { } }; -class Convolution2DTransposeBiasParser : public TFLiteOperationParser { - public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { - RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1)); - const TfLiteTransposeConvParams* tf_options; - RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options)); - RETURN_IF_ERROR( - CheckStrides(tf_options->stride_height, tf_options->stride_width)); - return absl::OkStatus(); - } - - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { - auto* node = graph->NewNode(); - node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED); - RETURN_IF_ERROR(reader->AddInput(node, 0)); - RETURN_IF_ERROR(reader->AddOutputs(node)); - - const TfLiteTransposeConvParams* tf_options; - auto status = RetrieveCustomInitialData(tflite_node, &tf_options); - - ConvolutionTransposedAttributes attr; - attr.stride = status.ok() - ? HW(tf_options->stride_height, tf_options->stride_width) - : HW(1, 1); - - RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights)); - reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional - - UpdatePadding(status.ok() ? 
tf_options->padding : kTfLitePaddingUnknown, - graph->FindInputs(node->id)[0]->tensor.shape, &attr); - - node->operation.attributes = std::move(attr); - return absl::OkStatus(); - } -}; - class DepthwiseConvolutionOperationParser : public TFLiteOperationParser { public: absl::Status IsSupported(const TfLiteContext* context, const TfLiteNode* tflite_node, const TfLiteRegistration* registration) final { - RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3)); - RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, - /*runtime_inputs=*/1, /*outputs=*/1)); - RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1)); + RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 6)); + const int runtime_inputs = + GetNumberOfRuntimeInputsForNode(context, tflite_node); + if (runtime_inputs > 2) { + return absl::InternalError( + absl::StrCat("Expected 1 or 2 input tensor(s), but node has ", + runtime_inputs, " runtime inputs.")); + } + const int runtime_outputs = NumOutputs(tflite_node); + if (runtime_outputs != 1) { + return absl::InternalError( + absl::StrCat("Expected 1 output tensor(s), but node has ", + runtime_outputs, " runtime outputs.")); + } + if (runtime_inputs == 1) { + RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1)); + } const TfLiteDepthwiseConvParams* tf_options; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); RETURN_IF_ERROR(CheckStridesAndDilation( @@ -577,7 +581,12 @@ class DepthwiseConvolutionOperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(reader->AddOutputs(node)); DepthwiseConvolution2DAttributes attr; - RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights)); + const int runtime_inputs = reader->GetNumberOfRuntimeInputs(); + if (runtime_inputs == 2) { + RETURN_IF_ERROR(reader->AddInput(node, 1)); + } else { // runtime_inputs == 1; + RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights)); + } reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional const TfLiteDepthwiseConvParams* tf_options; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); @@ -797,8 +806,10 @@ class ElementwiseOperationParser : public TFLiteOperationParser { case OperationType::ABS: case OperationType::COPY: case OperationType::COS: + case OperationType::ELU: case OperationType::EXP: case OperationType::LOG: + case OperationType::NEG: case OperationType::RSQRT: case OperationType::SIGMOID: case OperationType::SIN: @@ -814,6 +825,8 @@ class ElementwiseOperationParser : public TFLiteOperationParser { bool IsTwoArgumentOperation() const { switch (operation_type_) { case OperationType::DIV: + case OperationType::MAXIMUM: + case OperationType::MINIMUM: case OperationType::POW: case OperationType::SQUARED_DIFF: case OperationType::SUB: @@ -825,8 +838,11 @@ class ElementwiseOperationParser : public TFLiteOperationParser { bool IsTwoArgumentOperationWithConst() const { switch (operation_type_) { - case OperationType::MINIMUM: + case OperationType::DIV: case OperationType::MAXIMUM: + case OperationType::MINIMUM: + case OperationType::POW: + case OperationType::SQUARED_DIFF: case OperationType::SUB: return true; default: @@ -850,6 +866,10 @@ class FullyConnectedOperationParser : public TFLiteOperationParser { return absl::UnimplementedError( "Unsupported FullyConnected weights format."); } + if (GetNumberOfRuntimeInputsForNode(context, tflite_node) > 2) { + return absl::UnimplementedError( + "FullyConnected doesn't support more than 2 runtime inputs."); + } // TODO(eignasheva): check input shape return absl::OkStatus(); } @@ -857,11 
+877,31 @@ class FullyConnectedOperationParser : public TFLiteOperationParser { absl::Status Parse(const TfLiteNode* tflite_node, const TfLiteRegistration* registration, GraphFloat32* graph, ObjectReader* reader) final { + const TfLiteFullyConnectedParams* tf_options; + RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); + + if (reader->GetNumberOfRuntimeInputs() == 2) { + // Create Convolution2D, so as it supports runtime weights. + Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::CONVOLUTION_2D); + RETURN_IF_ERROR(reader->AddInput(node, 0)); + RETURN_IF_ERROR(reader->AddInput(node, 1)); + RETURN_IF_ERROR(reader->AddOutputs(node)); + + Convolution2DAttributes attr; + reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional + + attr.strides = HW(1, 1); + attr.dilations = HW(1, 1); + attr.padding.appended = HW(0, 0); + attr.padding.prepended = HW(0, 0); + RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node)); + node->operation.attributes = std::move(attr); + return absl::OkStatus(); + } Node* node = graph->NewNode(); RETURN_IF_ERROR(reader->AddInput(node, 0)); - const TfLiteFullyConnectedParams* tf_options; - RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); if (tf_options->weights_format != kTfLiteFullyConnectedWeightsFormatDefault) { return absl::UnimplementedError( @@ -870,13 +910,11 @@ class FullyConnectedOperationParser : public TFLiteOperationParser { FullyConnectedAttributes attr; RETURN_IF_ERROR(GetFullyConnectedAttributes(1, 2, reader, &attr)); + const int weights_width = attr.weights.shape.i; - Tensor weights; - RETURN_IF_ERROR(reader->ReadTensor(1, &weights)); auto input = graph->FindInputs(node->id)[0]; int batch_size = input->tensor.shape.b; - if (input->tensor.shape.DimensionsProduct() / batch_size != - weights.shape.w) { + if (input->tensor.shape.DimensionsProduct() / batch_size != weights_width) { return absl::UnimplementedError( "Amount of input data should match weights width"); } @@ -888,7 +926,7 @@ class FullyConnectedOperationParser : public TFLiteOperationParser { Value* reshaped_value = graph->NewValue(); reshaped_value->tensor.type = DataType::FLOAT32; reshaped_value->tensor.shape = - BHWC(input->tensor.shape.b, 1, 1, weights.shape.w); + BHWC(input->tensor.shape.b, 1, 1, weights_width); RETURN_IF_ERROR(graph->SetProducer(reshape->id, reshaped_value->id)); reshape->operation.type = ToString(OperationType::RESHAPE); ReshapeAttributes attr; @@ -943,18 +981,39 @@ class HardSwishOperationParser : public TFLiteOperationParser { // / \ // new_state1 activation0 // +// For full LSTM cells, see this blog post: +// https://colah.github.io/posts/2015-08-Understanding-LSTMs/ +// In addition to Peephole connections and Combined Input Forget Gates (CIFG) +// described in that post, this code also adds the following optional features: +// - Configurable activations (sigmoid or TANH) +// - L2 Normalization of gates: https://arxiv.org/abs/1607.06450 +// - Output projection: +// https://www.isca-speech.org/archive/interspeech_2014/i14_0338.html +// - Configurable clipping of cell state and output state. 
class LSTMOperationParser : public TFLiteOperationParser { public: absl::Status IsSupported(const TfLiteContext* context, const TfLiteNode* tflite_node, const TfLiteRegistration* registration) final { - RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); + RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3)); const TfLiteLSTMParams* tf_options; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); switch (tf_options->kernel_type) { - case kTfLiteLSTMFullKernel: - // TODO(b/157166356): Add check for input/output tensor counts. + case kTfLiteLSTMFullKernel: { + const int inputs = NumInputs(tflite_node); + if (inputs != 20 && inputs != 24) { + return absl::InternalError( + absl::StrCat("Expected 20 or 24 input tensors, but node has ", + inputs, " input(s).")); + } + const int runtime_outputs = NumOutputs(tflite_node); + if (runtime_outputs != 1) { + return absl::InternalError( + absl::StrCat("Expected 1 output tensor, but node has ", + runtime_outputs, " output(s).")); + } return CheckFullParameters(tf_options); + } case kTfLiteLSTMBasicKernel: RETURN_IF_ERROR( CheckInputsConstsOutputs(context, tflite_node, /*runtime_inputs=*/3, @@ -976,6 +1035,11 @@ class LSTMOperationParser : public TFLiteOperationParser { } } + absl::flat_hash_map GetNewValueIdsForVariableInputNodes() + final { + return new_variable_input_value_map_; + } + private: absl::Status ParseBasic(const TfLiteNode* tflite_node, const TfLiteRegistration* registration, @@ -1048,14 +1112,24 @@ class LSTMOperationParser : public TFLiteOperationParser { const TfLiteRegistration* registration, GraphFloat32* graph, ObjectReader* reader, const TfLiteLSTMParams* tf_options) { - return absl::UnimplementedError( - "Full LSTM support is not yet implemented."); + // Invoke full LSTM parser + RETURN_IF_ERROR(ParseLSTMAttributes(tflite_node, registration, graph, + reader, tf_options, + &new_variable_input_value_map_)); + return absl::OkStatus(); } absl::Status CheckFullParameters(const TfLiteLSTMParams* tf_options) { - return absl::UnimplementedError( - "Full LSTM support is not yet implemented."); + if (tf_options->activation != kTfLiteActSigmoid && + tf_options->activation != kTfLiteActTanh) { + return absl::UnimplementedError( + "Only sigmoid or tanh activation is supported."); + } + + return absl::OkStatus(); } + + absl::flat_hash_map new_variable_input_value_map_; }; class MulOperationParser : public TFLiteOperationParser { @@ -1067,8 +1141,11 @@ class MulOperationParser : public TFLiteOperationParser { if (tflite_node->inputs->size != 2) { return absl::UnimplementedError("MUL requires two input tensors."); } - auto input0 = tflite::GetInput(context, tflite_node, 0); - auto input1 = tflite::GetInput(context, tflite_node, 1); + const TfLiteTensor* input0 = GetInput(context, tflite_node, 0); + const TfLiteTensor* input1 = GetInput(context, tflite_node, 1); + if (input0 == nullptr || input1 == nullptr) { + return absl::InvalidArgumentError("At least one input tensor is null"); + } if (input0->dims->size == input1->dims->size) { // this code checks that at least one input of Mul not smaller in all // dimensions. Sometimes Mul used for matrix-vector multiplication that we @@ -1099,7 +1176,6 @@ class MulOperationParser : public TFLiteOperationParser { absl::Status Parse(const TfLiteNode* tflite_node, const TfLiteRegistration* registration, GraphFloat32* graph, ObjectReader* reader) final { - // Determine runtime/constant tensors. 
const TfLiteTensor* input0 = reader->GetInputTensor(0); if (!input0) { return absl::InvalidArgumentError( @@ -1120,10 +1196,22 @@ class MulOperationParser : public TFLiteOperationParser { Node* node = graph->NewNode(); node->operation.type = ToString(OperationType::MUL); + RETURN_IF_ERROR(reader->AddOutputs(node)); - // The "larger" input tensor must be bound to 1st input and the "smaller" - // input tensor ("mask") must be bound to 2nd input. + // Determine runtime/constant tensors. if (runtime_tensor0 && runtime_tensor1) { + if (input0 == input1) { + // replace MUL(A, A) with POW(A, 2.0) + // TODO(b/166831113): Support the same inputs for operations. + node->operation.type = ToString(OperationType::POW); + ElementwiseAttributes attr; + attr.param = 2.0f; + node->operation.attributes = std::move(attr); + return reader->AddInput(node, 0); + } + + // The "larger" input tensor must be bound to 1st input and the "smaller" + // input tensor must be bound to 2nd input. BHWC shape0; RETURN_IF_ERROR(ExtractTensorShape(*input0, &shape0)); BHWC shape1; @@ -1135,57 +1223,78 @@ class MulOperationParser : public TFLiteOperationParser { input_tensor0 = 1; input_tensor1 = 0; } - RETURN_IF_ERROR( - ParseApplyMask(node, input_tensor0, input_tensor1, graph, reader)); + RETURN_IF_ERROR(reader->AddInput(node, input_tensor0)); + RETURN_IF_ERROR(reader->AddInput(node, input_tensor1)); } else { - // The runtime input tensor must be bound to 1st input and the constant - // input tensor must be bound to 2nd input. - int runtime_tensor = 0; - int constant_tensor = 1; - TfLiteIntArray* constant_dims = input1->dims; - if (constant_tensor0 && runtime_tensor1) { - runtime_tensor = 1; - constant_tensor = 0; - constant_dims = input0->dims; - } - RETURN_IF_ERROR(ParseMultiplyScalar(node, runtime_tensor, constant_tensor, - constant_dims, graph, reader)); + ElementwiseAttributes attr; + RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param)); + node->operation.attributes = std::move(attr); } const TfLiteMulParams* tf_options; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); return MaybeFuseActivation(tf_options->activation, graph, node); } +}; - private: - absl::Status ParseApplyMask(Node* node, int input_tensor0, int input_tensor1, - GraphFloat32* graph, ObjectReader* reader) { - RETURN_IF_ERROR(reader->AddInput(node, input_tensor0)); - RETURN_IF_ERROR(reader->AddInput(node, input_tensor1)); - return reader->AddOutputs(node); +class PackOperationParser : public TFLiteOperationParser { + public: + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + const TfLitePackParams* tf_options; + return RetrieveBuiltinData(tflite_node, &tf_options); } - absl::Status ParseMultiplyScalar(Node* node, int runtime_tensor, - int constant_tensor, - const TfLiteIntArray* constant_dims, - GraphFloat32* graph, ObjectReader* reader) { - RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor)); - ElementwiseAttributes attr; - if (constant_dims->size <= 0 || NumElements(constant_dims) == 1) { - Tensor tensor; - RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor)); - attr.param = tensor.data[0]; - } else if (constant_dims->size == 3) { - Tensor tensor; - RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor)); - attr.param = std::move(tensor); + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { + if 
(tflite_node->inputs->size == 1) { + // Pack with single input can be replaced with Reshape + Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::RESHAPE); + RETURN_IF_ERROR(reader->AddInput(node, 0)); + RETURN_IF_ERROR(reader->AddOutputs(node)); + // New shape comes from output shape. + ReshapeAttributes attr; + attr.new_shape = graph->FindOutputs(node->id)[0]->tensor.shape; + node->operation.attributes = attr; + return absl::OkStatus(); } else { - Tensor tensor; - RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor)); - attr.param = std::move(tensor); + // Pack with few inputs can be replaced with Concat + const TfLitePackParams* tf_options; + RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); + + // Read inputs first to make sure const node is added to a graph before + // concat node to ensure topological order. + std::vector inputs; + for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) { + Value* value; + const auto status = reader->ReadValue(idx, &value); + if (status.ok()) { + inputs.push_back(value); + } else { + TensorFloat32 tensor; + RETURN_IF_ERROR(reader->ReadTensor(idx, &tensor)); + Value* value; + RETURN_IF_ERROR(NewConstNode(std::move(tensor), graph, &value)); + inputs.push_back(value); + } + } + + Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::CONCAT); + RETURN_IF_ERROR(reader->AddOutputs(node)); + for (const Value* input : inputs) { + RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id)); + } + const TfLiteTensor* output = reader->GetOutputTensor(0); + ConcatAttributes attr; + RETURN_IF_ERROR( + ExtractAxisFromIndex(*output, tf_options->axis, &attr.axis)); + node->operation.attributes = attr; + return absl::OkStatus(); } - node->operation.attributes = std::move(attr); - return reader->AddOutputs(node); } }; @@ -1251,7 +1360,10 @@ class PadOperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1)); RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1)); - auto pad_tensor = tflite::GetInput(context, tflite_node, 1); + const TfLiteTensor* pad_tensor = GetInput(context, tflite_node, 1); + if (pad_tensor == nullptr) { + return absl::InvalidArgumentError("Padding tensor was null"); + } if (pad_tensor->dims->size != 2) { return absl::InvalidArgumentError(absl::StrCat( "Invalid paddings tensor dimension: expected 2 dim, got ", @@ -1381,6 +1493,52 @@ class Pooling2DOperationParser : public TFLiteOperationParser { const PoolingType type_; }; +class ReduceOperationParser : public TFLiteOperationParser { + public: + explicit ReduceOperationParser(OperationType operation_type) + : operation_type_(operation_type) {} + + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, + /*runtime_inputs=*/1, /*outputs=*/1)); + auto* axes = &context->tensors[tflite_node->inputs->data[1]]; + if (axes->allocation_type != kTfLiteMmapRo || axes->type != kTfLiteInt32) { + return absl::UnimplementedError( + "Reduce has unsupported tensor for axes."); + } + if (tflite::NumElements(axes) != 1) { + return absl::UnimplementedError( + "Supported reduce in single dimensions only."); + } + return absl::OkStatus(); + } + + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { + Node* node = 
graph->NewNode(); + node->operation.type = ToString(operation_type_); + RETURN_IF_ERROR(reader->AddInput(node, 0)); + RETURN_IF_ERROR(reader->AddOutputs(node)); + + const TfLiteReducerParams* tf_options; + RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); + + Tensor axes; + RETURN_IF_ERROR(reader->ReadTensor(1, &axes)); + const TfLiteTensor* input = reader->GetInputTensor(0); + ReduceAttributes attr; + RETURN_IF_ERROR(ExtractAxisFromIndex(*input, axes.data[0], &attr.axis)); + node->operation.attributes = attr; + return absl::OkStatus(); + } + + private: + const OperationType operation_type_; +}; + class QuantizeOperationParser : public TFLiteOperationParser { public: absl::Status IsSupported(const TfLiteContext* context, @@ -1594,6 +1752,15 @@ class SliceOperationParser : public TFLiteOperationParser { const TfLiteNode* tflite_node, const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); + if (tflite_node->inputs->size < 3) { + return absl::UnimplementedError("SLICE requires 3 inputs."); + } + const TfLiteTensor* input = GetInput(context, tflite_node, 0); + if (input->dims->size != 3 && input->dims->size != 4) { + return absl::UnimplementedError( + "SLICE supports for 3 or 4 dimensional tensors only."); + } + return absl::OkStatus(); } @@ -1607,6 +1774,9 @@ class SliceOperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(reader->ReadValue(0, &input)); RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id)); + const TfLiteTensor* tfl_input = reader->GetInputTensor(0); + const int input_dims = tfl_input->dims->size; + SliceAttributes attr; attr.strides = BHWC(1, 1, 1, 1); Tensor starts, sizes; @@ -1615,36 +1785,65 @@ class SliceOperationParser : public TFLiteOperationParser { if (starts.data.size() != sizes.data.size()) { return absl::InvalidArgumentError("Starts amount != sizes amount."); } - const auto& in_shape = input->tensor.shape; - if (starts.data.size() == 4) { - sizes.data[0] = - sizes.data[0] != -1 ? sizes.data[0] : in_shape.b - starts.data[0]; - sizes.data[1] = - sizes.data[1] != -1 ? sizes.data[1] : in_shape.h - starts.data[1]; - sizes.data[2] = - sizes.data[2] != -1 ? sizes.data[2] : in_shape.w - starts.data[2]; - sizes.data[3] = - sizes.data[3] != -1 ? sizes.data[3] : in_shape.c - starts.data[3]; - attr.starts = - BHWC(starts.data[0], starts.data[1], starts.data[2], starts.data[3]); - attr.ends = - BHWC(starts.data[0] + sizes.data[0], starts.data[1] + sizes.data[1], - starts.data[2] + sizes.data[2], starts.data[3] + sizes.data[3]); - } else if (starts.data.size() == 3) { - sizes.data[0] = - sizes.data[0] != -1 ? sizes.data[0] : in_shape.h - starts.data[0]; - sizes.data[1] = - sizes.data[1] != -1 ? sizes.data[1] : in_shape.w - starts.data[1]; - sizes.data[2] = - sizes.data[2] != -1 ? 
sizes.data[2] : in_shape.c - starts.data[2]; - attr.starts = BHWC(0, starts.data[0], starts.data[1], starts.data[2]); - attr.ends = - BHWC(in_shape.b, starts.data[0] + sizes.data[0], - starts.data[1] + sizes.data[1], starts.data[2] + sizes.data[2]); + BHWC bhwc_starts(0, 0, 0, 0); + BHWC bhwc_sizes = input->tensor.shape; + if (input_dims == 4) { + // input in BHWC layout + if (starts.data.size() == 4) { + bhwc_starts.b = starts.data[0]; + bhwc_starts.h = starts.data[1]; + bhwc_starts.w = starts.data[2]; + bhwc_starts.c = starts.data[3]; + bhwc_sizes.b = sizes.data[0]; + bhwc_sizes.h = sizes.data[1]; + bhwc_sizes.w = sizes.data[2]; + bhwc_sizes.c = sizes.data[3]; + } else if (starts.data.size() == 3) { + // if input is 4D(BHWC) and args 3D, we assume that args in HWC layout + bhwc_starts.h = starts.data[0]; + bhwc_starts.w = starts.data[1]; + bhwc_starts.c = starts.data[2]; + bhwc_sizes.h = sizes.data[0]; + bhwc_sizes.w = sizes.data[1]; + bhwc_sizes.c = sizes.data[2]; + } else { + return absl::UnimplementedError( + "Slicing is supported for 3 or 4 dimensional tensors only."); + } + } else if (input_dims == 3) { + // input in BWC layout + if (starts.data.size() == 3) { + bhwc_starts.b = starts.data[0]; + bhwc_starts.w = starts.data[1]; + bhwc_starts.c = starts.data[2]; + bhwc_sizes.b = sizes.data[0]; + bhwc_sizes.w = sizes.data[1]; + bhwc_sizes.c = sizes.data[2]; + } else { + return absl::UnimplementedError( + "Slicing is supported for 3 or 4 dimensional tensors only."); + } } else { return absl::UnimplementedError( "Slicing is supported for 3 or 4 dimensional tensors only."); } + const auto& in_shape = input->tensor.shape; + if (bhwc_sizes.b == -1) { + bhwc_sizes.b = in_shape.b - bhwc_starts.b; + } + if (bhwc_sizes.h == -1) { + bhwc_sizes.h = in_shape.h - bhwc_starts.h; + } + if (bhwc_sizes.w == -1) { + bhwc_sizes.w = in_shape.w - bhwc_starts.w; + } + if (bhwc_sizes.c == -1) { + bhwc_sizes.c = in_shape.c - bhwc_starts.c; + } + attr.starts = bhwc_starts; + attr.ends = + BHWC(bhwc_starts.b + bhwc_sizes.b, bhwc_starts.h + bhwc_sizes.h, + bhwc_starts.w + bhwc_sizes.w, bhwc_starts.c + bhwc_sizes.c); RETURN_IF_ERROR(UpdateIfNegative(in_shape, &attr)); auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape; @@ -1771,6 +1970,15 @@ class StridedSliceOperationParser : public TFLiteOperationParser { const TfLiteStridedSliceParams* tf_options; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); RETURN_IF_ERROR(CheckOptionsSupport(tf_options)); + + if (tflite_node->inputs->size < 4) { + return absl::UnimplementedError("STRIDED_SLICE requires 4 inputs."); + } + const TfLiteTensor* input = GetInput(context, tflite_node, 0); + if (input->dims->size != 3 && input->dims->size != 4) { + return absl::UnimplementedError( + "STRIDED_SLICE supports for 3 or 4 dimensional tensors only."); + } return absl::OkStatus(); } @@ -1790,6 +1998,7 @@ class StridedSliceOperationParser : public TFLiteOperationParser { bool read_without_batch = tmp.data.size() == 3; bool read_with_batch = tmp.data.size() == 4; if (!read_without_batch && !read_with_batch) { + // Error: Must be catched in IsSupported() return absl::UnimplementedError( "Slicing is supported for 3 or 4 dimensional tensors only."); } @@ -1941,12 +2150,13 @@ class StridedSliceOperationParser : public TFLiteOperationParser { } }; -class TransposeConvOperationParser : public TFLiteOperationParser { +// Builtin op version of TRANSPOSE_CONV. 
+class TransposeConvBuiltinOperationParser : public TFLiteOperationParser { public: absl::Status IsSupported(const TfLiteContext* context, const TfLiteNode* tflite_node, const TfLiteRegistration* registration) final { - RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); + RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3)); RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1)); const TfLiteTransposeConvParams* tf_options; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); @@ -1955,8 +2165,8 @@ class TransposeConvOperationParser : public TFLiteOperationParser { return absl::OkStatus(); } - // TFLite's TRANSPOSE_CONV expects 3 input (output shape, weights, and input) - // and allows configurable padding & stride. + // TFLite's TRANSPOSE_CONV expects 3-4 input tensors (output shape, weights, + // input, and an optional bias) and allows configurable padding & stride. // TODO(impjdi): Translate output_shape to attr.adjacent. absl::Status Parse(const TfLiteNode* tflite_node, const TfLiteRegistration* registration, @@ -1976,8 +2186,7 @@ class TransposeConvOperationParser : public TFLiteOperationParser { ? HW(tf_options->stride_height, tf_options->stride_width) : HW(1, 1); RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights)); - - // TFLite does not support bias. + reader->ReadTensor(3, &attr.bias).IgnoreError(); // bias is optional UpdatePadding(tf_options->padding, graph->FindInputs(node->id)[0]->tensor.shape, &attr); @@ -1986,6 +2195,45 @@ class TransposeConvOperationParser : public TFLiteOperationParser { } }; +// Custom op version of TRANSPOSE_CONV. +class TransposeConvCustomOperationParser : public TFLiteOperationParser { + public: + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1)); + const TfLiteTransposeConvParams* tf_options; + RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options)); + RETURN_IF_ERROR( + CheckStrides(tf_options->stride_height, tf_options->stride_width)); + return absl::OkStatus(); + } + + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { + auto* node = graph->NewNode(); + node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED); + RETURN_IF_ERROR(reader->AddInput(node, 0)); + RETURN_IF_ERROR(reader->AddOutputs(node)); + + const TfLiteTransposeConvParams* tf_options; + auto status = RetrieveCustomInitialData(tflite_node, &tf_options); + + ConvolutionTransposedAttributes attr; + attr.stride = status.ok() + ? HW(tf_options->stride_height, tf_options->stride_width) + : HW(1, 1); + RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights)); + reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional + + UpdatePadding(status.ok() ? 
tf_options->padding : kTfLitePaddingUnknown, + graph->FindInputs(node->id)[0]->tensor.shape, &attr); + node->operation.attributes = std::move(attr); + return absl::OkStatus(); + } +}; + class TransposeOperationParser : public TFLiteOperationParser { public: absl::Status IsSupported(const TfLiteContext* context, @@ -2015,29 +2263,18 @@ class TransposeOperationParser : public TFLiteOperationParser { if (perm.data.size() == 4) { attr.perm = BHWC(perm.data[0], perm.data[1], perm.data[2], perm.data[3]); } else if (perm.data.size() == 3) { - std::vector index_to_axis = {Axis::CHANNELS, Axis::WIDTH, - Axis::BATCH}; - std::map remap = { - {Axis::HEIGHT, Axis::HEIGHT}, - {index_to_axis[perm.data[2]], Axis::BATCH}, - {index_to_axis[perm.data[1]], Axis::WIDTH}, - {index_to_axis[perm.data[0]], Axis::CHANNELS}}; - attr.perm.b = axis_to_index[remap[Axis::BATCH]]; - attr.perm.h = axis_to_index[remap[Axis::HEIGHT]]; - attr.perm.w = axis_to_index[remap[Axis::WIDTH]]; - attr.perm.c = axis_to_index[remap[Axis::CHANNELS]]; - + std::vector index_to_axis = {Axis::BATCH, Axis::WIDTH, + Axis::CHANNELS}; + attr.perm.b = axis_to_index[index_to_axis[perm.data[0]]]; + attr.perm.h = 1; + attr.perm.w = axis_to_index[index_to_axis[perm.data[1]]]; + attr.perm.c = axis_to_index[index_to_axis[perm.data[2]]]; } else if (perm.data.size() == 2) { - std::vector index_to_axis = {Axis::CHANNELS, Axis::BATCH}; - std::map remap = { - {Axis::HEIGHT, Axis::HEIGHT}, - {Axis::WIDTH, Axis::WIDTH}, - {index_to_axis[perm.data[1]], Axis::BATCH}, - {index_to_axis[perm.data[0]], Axis::CHANNELS}}; - attr.perm.b = axis_to_index[remap[Axis::BATCH]]; - attr.perm.h = axis_to_index[remap[Axis::HEIGHT]]; - attr.perm.w = axis_to_index[remap[Axis::WIDTH]]; - attr.perm.c = axis_to_index[remap[Axis::CHANNELS]]; + std::vector index_to_axis = {Axis::BATCH, Axis::CHANNELS}; + attr.perm.b = axis_to_index[index_to_axis[perm.data[0]]]; + attr.perm.h = 1; + attr.perm.w = 2; + attr.perm.c = axis_to_index[index_to_axis[perm.data[1]]]; } else { return absl::InvalidArgumentError( "Permutation for transpose is invalid."); @@ -2181,6 +2418,7 @@ class RoIToTransformMatrixOperationParser : public TFLiteOperationParser { absl::Status IsSupported(const TfLiteContext* context, const TfLiteNode* tflite_node, const TfLiteRegistration* registration) final { + RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1)); return absl::OkStatus(); @@ -2205,48 +2443,14 @@ class RoIToTransformMatrixOperationParser : public TFLiteOperationParser { output_value->tensor.shape = output_shape; return absl::OkStatus(); } - - private: }; -class RoIToTransformMatrixV2OperationParser : public TFLiteOperationParser { - public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { - RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, - /*runtime_inputs=*/1, /*outputs=*/1)); - return absl::OkStatus(); - } - - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { - Node* node = graph->NewNode(); - RETURN_IF_ERROR(reader->AddInput(node, 0)); // bbox - RETURN_IF_ERROR(reader->AddOutputs(node)); - - std::string op_name = "roi_to_transform_matrix_v2"; - node->operation.type = op_name; - BHWC output_shape; - RETURN_IF_ERROR(ParseCustomAttributes( - op_name, registration->version, tflite_node->custom_initial_data, - 
tflite_node->custom_initial_data_size, &(node->operation.attributes), - &output_shape)); - - auto output_value = graph->FindOutputs(node->id)[0]; - output_value->tensor.shape = output_shape; - return absl::OkStatus(); - } - - private: -}; - -class TransformTensorOperationParser : public TFLiteOperationParser { +class TransformTensorBilinearOperationParser : public TFLiteOperationParser { public: absl::Status IsSupported(const TfLiteContext* context, const TfLiteNode* tflite_node, const TfLiteRegistration* registration) final { + RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/2, /*outputs=*/1)); return absl::OkStatus(); @@ -2260,7 +2464,7 @@ class TransformTensorOperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(reader->AddInput(node, 1)); // bbox RETURN_IF_ERROR(reader->AddOutputs(node)); - std::string op_name = "transform_tensor"; + std::string op_name = "transform_tensor_bilinear"; node->operation.type = op_name; BHWC output_shape; RETURN_IF_ERROR(ParseCustomAttributes( @@ -2275,45 +2479,6 @@ class TransformTensorOperationParser : public TFLiteOperationParser { graph->FindInputs(node->id)[0]->tensor.shape.c); return absl::OkStatus(); } - - private: -}; - -class TransformTensorBilinearV2OperationParser : public TFLiteOperationParser { - public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { - RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, - /*runtime_inputs=*/2, /*outputs=*/1)); - return absl::OkStatus(); - } - - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { - Node* node = graph->NewNode(); - RETURN_IF_ERROR(reader->AddInput(node, 0)); // data - RETURN_IF_ERROR(reader->AddInput(node, 1)); // bbox - RETURN_IF_ERROR(reader->AddOutputs(node)); - - std::string op_name = "transform_tensor_bilinear_v2"; - node->operation.type = op_name; - BHWC output_shape; - RETURN_IF_ERROR(ParseCustomAttributes( - op_name, registration->version, tflite_node->custom_initial_data, - tflite_node->custom_initial_data_size, &(node->operation.attributes), - &output_shape)); - - auto output_value = graph->FindOutputs(node->id)[0]; - - output_value->tensor.shape = - BHWC(1, output_shape.h, output_shape.w, - graph->FindInputs(node->id)[0]->tensor.shape.c); - return absl::OkStatus(); - } - - private: }; class TransformLandmarksOperationParser : public TFLiteOperationParser { @@ -2321,6 +2486,7 @@ class TransformLandmarksOperationParser : public TFLiteOperationParser { absl::Status IsSupported(const TfLiteContext* context, const TfLiteNode* tflite_node, const TfLiteRegistration* registration) final { + RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/2, /*outputs=*/1)); return absl::OkStatus(); @@ -2346,42 +2512,6 @@ class TransformLandmarksOperationParser : public TFLiteOperationParser { output_value->tensor.shape = graph->FindInputs(node->id)[0]->tensor.shape; return absl::OkStatus(); } - - private: -}; - -class TransformLandmarksV2OperationParser : public TFLiteOperationParser { - public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { - RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, - /*runtime_inputs=*/2, /*outputs=*/1)); - 
return absl::OkStatus(); - } - - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { - Node* node = graph->NewNode(); - RETURN_IF_ERROR(reader->AddInput(node, 0)); // data - RETURN_IF_ERROR(reader->AddInput(node, 1)); // bbox - RETURN_IF_ERROR(reader->AddOutputs(node)); - std::string op_name = "transform_landmarks_v2"; - node->operation.type = op_name; - - auto output_value = graph->FindOutputs(node->id)[0]; - output_value->tensor.shape = graph->FindInputs(node->id)[0]->tensor.shape; - BHWC output_shape = output_value->tensor.shape; - RETURN_IF_ERROR(ParseCustomAttributes( - op_name, registration->version, tflite_node->custom_initial_data, - tflite_node->custom_initial_data_size, &(node->operation.attributes), - &output_shape)); - - return absl::OkStatus(); - } - - private: }; class Landmarks2TransformMatrixOperationParser : public TFLiteOperationParser { @@ -2389,6 +2519,7 @@ class Landmarks2TransformMatrixOperationParser : public TFLiteOperationParser { absl::Status IsSupported(const TfLiteContext* context, const TfLiteNode* tflite_node, const TfLiteRegistration* registration) final { + RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1); } @@ -2414,37 +2545,6 @@ class Landmarks2TransformMatrixOperationParser : public TFLiteOperationParser { } }; -class Landmarks2TransformMatrixV2OperationParser - : public TFLiteOperationParser { - public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { - return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, - /*outputs=*/1); - } - - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { - Node* node = graph->NewNode(); - RETURN_IF_ERROR(reader->AddInput(node, 0)); // landmarks - RETURN_IF_ERROR(reader->AddOutputs(node)); // transform matrix - - const std::string op_name = "landmarks_to_transform_matrix_v2"; - node->operation.type = op_name; - BHWC output_shape; - RETURN_IF_ERROR(ParseCustomAttributes( - op_name, registration->version, tflite_node->custom_initial_data, - tflite_node->custom_initial_data_size, &(node->operation.attributes), - &output_shape)); - - auto output_value = graph->FindOutputs(node->id)[0]; - output_value->tensor.shape = output_shape; - return absl::OkStatus(); - } -}; - class AlignmentPointsToTransformMatrixOperationParser : public TFLiteOperationParser { public: @@ -2563,6 +2663,8 @@ std::unique_ptr NewOperationParser( return std::make_unique(); case kTfLiteBuiltinAveragePool2d: return std::make_unique(PoolingType::AVERAGE); + case kTfLiteBuiltinBatchMatmul: + return std::make_unique(); case kTfLiteBuiltinConcatenation: return std::make_unique(); case kTfLiteBuiltinConv2d: @@ -2607,10 +2709,23 @@ std::unique_ptr NewOperationParser( return std::make_unique(/*mirror_pad=*/true); case kTfLiteBuiltinMul: return std::make_unique(); + case kTfLiteBuiltinNeg: + return std::make_unique(OperationType::NEG); + case kTfLiteBuiltinPack: + return std::make_unique(); case kTfLiteBuiltinPad: return std::make_unique(/*mirror_pad=*/false); case kTfLiteBuiltinPow: return std::make_unique(OperationType::POW); + case kTfLiteBuiltinReduceMax: + return std::make_unique( + OperationType::REDUCE_MAXIMUM); + case kTfLiteBuiltinReduceMin: + return std::make_unique( + 
OperationType::REDUCE_MINIMUM); + case kTfLiteBuiltinReduceProd: + return std::make_unique( + OperationType::REDUCE_PRODUCT); case kTfLiteBuiltinQuantize: if (allow_quant_ops) { return std::make_unique(); @@ -2652,17 +2767,19 @@ std::unique_ptr NewOperationParser( return std::make_unique(); case kTfLiteBuiltinSub: return std::make_unique(OperationType::SUB); + case kTfLiteBuiltinSum: + return std::make_unique(OperationType::REDUCE_SUM); case kTfLiteBuiltinTanh: return std::make_unique(OperationType::TANH); case kTfLiteBuiltinTranspose: return std::make_unique(); case kTfLiteBuiltinTransposeConv: - return std::make_unique(); + return std::make_unique(); case kTfLiteBuiltinCustom: const absl::string_view custom_name = registration->custom_name; if (custom_name == "Convolution2DTransposeBias") { - return std::make_unique(); + return std::make_unique(); } if (custom_name == "MaxPoolingWithArgmax2D") { return std::make_unique(PoolingType::MAX); @@ -2673,27 +2790,17 @@ std::unique_ptr NewOperationParser( if (custom_name == "RoIToTransformMatrix") { return std::make_unique(); } - if (custom_name == "RoIToTransformMatrixV2") { - return std::make_unique(); - } - if (custom_name == "TransformTensor") { - return std::make_unique(); - } - if (custom_name == "TransformTensorBilinearV2") { - return std::make_unique(); + if (custom_name == "TransformTensor" /*for version 1*/ || + custom_name == "TransformTensorBilinear" /*for version 2*/) { + return std::make_unique(); } if (custom_name == "TransformLandmarks") { return std::make_unique(); } - if (custom_name == "TransformLandmarksV2") { - return std::make_unique(); - } - if (custom_name == "Landmarks2TransformMatrix") { + if (custom_name == "Landmarks2TransformMatrix" || + custom_name == "Landmarks2TransformMatrixV2") { return std::make_unique(); } - if (custom_name == "Landmarks2TransformMatrixV2") { - return std::make_unique(); - } if (custom_name == "AlignmentPointsToTransformMatrix") { return std::make_unique< AlignmentPointsToTransformMatrixOperationParser>(); @@ -2813,6 +2920,44 @@ absl::Status PrecreateIOTensors( return absl::OkStatus(); } +absl::Status CopyVariableTensorOutputs( + TfLiteNode* tflite_node, TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader& reader, + const absl::flat_hash_map& new_variable_tensor_values) { + absl::flat_hash_map new_variable_tensor_values_copy( + new_variable_tensor_values); + // Retrieve the final value id for the variable input tensors. 
+ for (int i = 0; i < tflite_node->inputs->size; i++) { + int tensor_idx = tflite_node->inputs->data[i]; + Value* value; + if (!reader.ReadValueByTensorIdx(tensor_idx, &value).ok()) continue; + if (value->tensor.is_variable_input) { + if (new_variable_tensor_values_copy.find(i) == + new_variable_tensor_values_copy.end()) { + return absl::InvalidArgumentError( + absl::StrCat(GetOpNameByRegistration(*registration), + " did not provide a new value for the variable input " + "tensor with index ", + tensor_idx)); + } else { + Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::COPY); + RETURN_IF_ERROR(graph->AddConsumer( + node->id, new_variable_tensor_values_copy.at(i))); + RETURN_IF_ERROR(reader.AddUpdate(node, i)); + new_variable_tensor_values_copy.erase( + new_variable_tensor_values_copy.find(i)); + } + } + } + if (!new_variable_tensor_values_copy.empty()) { + return absl::InvalidArgumentError( + "More input variable tensors asked to be copied than present on the " + "node"); + } + return absl::OkStatus(); +} + absl::Status BuildModel(TfLiteContext* context, const TfLiteDelegateParams* delegate_params, GraphFloat32* graph, @@ -2843,6 +2988,7 @@ absl::Status BuildModel(TfLiteContext* context, tflite_nodes.push_back(i); } absl::flat_hash_map tensor_to_value; + std::vector variable_inputs_to_value_id; RETURN_IF_ERROR(PrecreateIOTensors(context, graph, delegate_params->input_tensors, quant_conversion_map, &tensor_to_value)); @@ -2863,6 +3009,23 @@ absl::Status BuildModel(TfLiteContext* context, return absl::InternalError(absl::StrCat( GetOpNameByRegistration(*registration), ": ", status.message())); } + + absl::flat_hash_map new_value_for_variable_input_tensors = + operations[i]->GetNewValueIdsForVariableInputNodes(); + + RETURN_IF_ERROR( + CopyVariableTensorOutputs(tflite_node, registration, graph, reader, + new_value_for_variable_input_tensors)); + } + + // Variable input tensors expect to be unchanged throughout model execution. + // They need to be an output of the graph in order to have them unchanged. + for (auto value_id : variable_inputs_to_value_id) { + if (!graph->IsGraphOutput(value_id)) { + return absl::InvalidArgumentError( + absl::StrCat("Variable input tensors must be a graph output. Value ", + value_id, " is not a graph output")); + } } return absl::OkStatus(); } @@ -2876,8 +3039,8 @@ absl::Status BuildFinalModel( // Apply general transformations on the graph. NullTransformationReporter reporter; ModelTransformer transformer(graph, &reporter); - if (!ApplyGeneralTransformations(&transformer)) { - return absl::InternalError("Graph general transformations failed"); + if (!ApplyModelTransformations(&transformer)) { + return absl::InternalError("Graph transformations failed"); } return absl::OkStatus(); } diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.h b/tensorflow/lite/delegates/gpu/common/model_builder.h index 9d80e9636f0..ab18f056d58 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.h +++ b/tensorflow/lite/delegates/gpu/common/model_builder.h @@ -16,13 +16,12 @@ limitations under the License. 
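The variable-input plumbing added above works in two steps; a condensed sketch with error handling trimmed (the map's template parameters and input_index are spelled out here only for illustration):

  // 1. A parser that rewrites a variable input tensor reports, per input
  //    index, the value id holding that tensor's updated contents.
  absl::flat_hash_map<int, ValueId> new_values =
      parser->GetNewValueIdsForVariableInputNodes();

  // 2. For each variable input, BuildModel adds a COPY node that consumes the
  //    reported value and writes it back to the variable tensor, so the new
  //    state is observable on the next invocation.
  Node* copy = graph->NewNode();
  copy->operation.type = ToString(OperationType::COPY);
  RETURN_IF_ERROR(graph->AddConsumer(copy->id, new_values.at(input_index)));
  RETURN_IF_ERROR(reader.AddUpdate(copy, input_index));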
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_H_ -#include -#include - #include "absl/container/flat_hash_map.h" -#include "tensorflow/lite/context.h" +#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/common/model_builder_helper.cc b/tensorflow/lite/delegates/gpu/common/model_builder_helper.cc index b030fb7e700..4f67495152c 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder_helper.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder_helper.cc @@ -15,19 +15,27 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h" +#include +#include +#include + +#include +#include #include +#include #include #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" -#include "tensorflow/lite/builtin_ops.h" +#include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/context.h" #include "tensorflow/lite/context_util.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" -#include "tensorflow/lite/delegates/utils.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" #include "tensorflow/lite/kernels/kernel_util.h" namespace tflite { @@ -89,15 +97,19 @@ absl::Status ExtractTensorShape(const TfLiteTensor& tflite_tensor, BHWC* bhwc) { const TfLiteIntArray* dims = tflite_tensor.dims; switch (dims->size) { case 1: + // B layout *bhwc = BHWC(dims->data[0], 1, 1, 1); return absl::OkStatus(); case 2: + // BC layout *bhwc = BHWC(dims->data[0], 1, 1, dims->data[1]); return absl::OkStatus(); case 3: + // BWC layout *bhwc = BHWC(dims->data[0], 1, dims->data[1], dims->data[2]); return absl::OkStatus(); case 4: + // BHWC layout *bhwc = BHWC(dims->data[0], dims->data[1], dims->data[2], dims->data[3]); return absl::OkStatus(); default: @@ -107,6 +119,40 @@ absl::Status ExtractTensorShape(const TfLiteTensor& tflite_tensor, BHWC* bhwc) { } } +absl::Status ExtractAxisFromIndex(const TfLiteTensor& tflite_tensor, int index, + Axis* axis) { + const TfLiteIntArray* dims = tflite_tensor.dims; + if (index == -1) { + index = dims->size - 1; + } + if (index < 0 || index >= dims->size) { + return absl::OutOfRangeError("Index for axis out of range"); + } + std::vector index_to_axis; + switch (dims->size) { + case 1: + // B layout + index_to_axis = {Axis::BATCH}; + break; + case 2: + // BC layout + index_to_axis = {Axis::BATCH, Axis::CHANNELS}; + break; + case 3: + // BWC layout + index_to_axis = {Axis::BATCH, Axis::WIDTH, Axis::CHANNELS}; + break; + case 4: + // BHWC layout + index_to_axis = {Axis::BATCH, Axis::HEIGHT, Axis::WIDTH, Axis::CHANNELS}; + break; + default: + return absl::UnavailableError("Unknown layout."); + } + *axis = index_to_axis[index]; + return absl::OkStatus(); +} + absl::Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor, TensorRef* tensor_ref) { tensor_ref->type = ToDataType(tflite_tensor.type); diff --git 
a/tensorflow/lite/delegates/gpu/common/model_builder_helper.h b/tensorflow/lite/delegates/gpu/common/model_builder_helper.h index 849ef049683..93889314e81 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder_helper.h +++ b/tensorflow/lite/delegates/gpu/common/model_builder_helper.h @@ -16,6 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_HELPER_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_HELPER_H_ +#include +#include +#include + #include #include "absl/strings/str_cat.h" @@ -29,7 +33,6 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/reference/dequantize.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/internal/types.h" -#include "tensorflow/lite/kernels/kernel_util.h" namespace tflite { namespace gpu { @@ -42,6 +45,9 @@ DataType ToDataType(TfLiteType type); absl::Status ExtractTensorShape(const TfLiteTensor& tflite_tensor, BHWC* bhwc); +absl::Status ExtractAxisFromIndex(const TfLiteTensor& tflite_tensor, int index, + Axis* axis); + absl::Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor, TensorRef* tensor_ref); diff --git a/tensorflow/lite/delegates/gpu/common/model_builder_test.cc b/tensorflow/lite/delegates/gpu/common/model_builder_test.cc index c5ee71b3f3f..9bc848b9210 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder_test.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder_test.cc @@ -15,15 +15,21 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/model_builder.h" -#include +#include +#include + +#include +#include +#include -#include #include +#include "absl/status/status.h" #include "tensorflow/lite/builtin_ops.h" -#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/subgraph.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" #include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/stderr_reporter.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/common/model_test.cc b/tensorflow/lite/delegates/gpu/common/model_test.cc index 87f65eb730a..816674b6674 100644 --- a/tensorflow/lite/delegates/gpu/common/model_test.cc +++ b/tensorflow/lite/delegates/gpu/common/model_test.cc @@ -15,11 +15,9 @@ limitations under the License. 
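As a quick illustration of the layout conventions used by ExtractTensorShape and the new ExtractAxisFromIndex (the tensor contents here are hypothetical):

  // A 3-D TFLite tensor with dims {2, 8, 16} is treated as BWC, i.e. it maps
  // to BHWC(2, 1, 8, 16).
  BHWC shape;
  RETURN_IF_ERROR(ExtractTensorShape(tflite_tensor, &shape));

  // For the same 3-D tensor, axis index 1 resolves to Axis::WIDTH, and -1
  // (the last dimension) resolves to Axis::CHANNELS.
  Axis axis;
  RETURN_IF_ERROR(ExtractAxisFromIndex(tflite_tensor, /*index=*/-1, &axis));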
#include "tensorflow/lite/delegates/gpu/common/model.h" -#include -#include - #include #include +#include "absl/status/status.h" namespace tflite { namespace gpu { @@ -139,98 +137,140 @@ TEST(Model, RemoveProducer) { ASSERT_FALSE(graph.RemoveProducer(graph_output->id).ok()); } -TEST(Model, RemoveSimpleNodeDegenerateCase) { - GraphFloat32 graph; - Node* node = graph.NewNode(); - Value* graph_input = graph.NewValue(); - Value* graph_output = graph.NewValue(); +class OneNodeModel : public testing::Test { + protected: + void SetUp() override { + node_ = graph_.NewNode(); + Value* graph_input = graph_.NewValue(); + Value* graph_output = graph_.NewValue(); + ASSERT_TRUE(graph_.AddConsumer(node_->id, graph_input->id).ok()); + ASSERT_TRUE(graph_.SetProducer(node_->id, graph_output->id).ok()); + EXPECT_THAT(graph_.inputs(), UnorderedElementsAre(graph_input)); + EXPECT_THAT(graph_.outputs(), UnorderedElementsAre(graph_output)); + EXPECT_THAT(graph_.nodes(), ElementsAre(node_)); + } + GraphFloat32 graph_; + Node* node_; +}; - ASSERT_TRUE(graph.AddConsumer(node->id, graph_input->id).ok()); - ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok()); - EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_input)); - EXPECT_THAT(graph.outputs(), UnorderedElementsAre(graph_output)); - EXPECT_THAT(graph.nodes(), ElementsAre(node)); - - ASSERT_TRUE(RemoveOneInputOneOutputNode(&graph, node).ok()); - EXPECT_THAT(graph.inputs(), UnorderedElementsAre()); - EXPECT_THAT(graph.outputs(), UnorderedElementsAre()); - EXPECT_THAT(graph.nodes(), ElementsAre()); +TEST_F(OneNodeModel, DeleteNodeKeepInput) { + ASSERT_TRUE(RemoveSimpleNodeKeepInput(&graph_, node_).ok()); + EXPECT_TRUE(graph_.inputs().empty()); + EXPECT_TRUE(graph_.outputs().empty()); + EXPECT_TRUE(graph_.nodes().empty()); } -TEST(Model, RemoveSimpleNodeNoPreviousNode) { - GraphFloat32 graph; - Node* simple_node = graph.NewNode(); - Node* consumer_node = graph.NewNode(); - Value* graph_input = graph.NewValue(); - Value* graph_output = graph.NewValue(); - Value* value = graph.NewValue(); - - ASSERT_TRUE(graph.AddConsumer(simple_node->id, graph_input->id).ok()); - ASSERT_TRUE(graph.SetProducer(simple_node->id, value->id).ok()); - ASSERT_TRUE(graph.AddConsumer(consumer_node->id, value->id).ok()); - ASSERT_TRUE(graph.SetProducer(consumer_node->id, graph_output->id).ok()); - EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_input)); - EXPECT_THAT(graph.outputs(), UnorderedElementsAre(graph_output)); - EXPECT_THAT(graph.nodes(), ElementsAre(simple_node, consumer_node)); - - ASSERT_TRUE(RemoveOneInputOneOutputNode(&graph, simple_node).ok()); - EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_input)); - EXPECT_THAT(graph.outputs(), UnorderedElementsAre(graph_output)); - EXPECT_THAT(graph.nodes(), ElementsAre(consumer_node)); +TEST_F(OneNodeModel, DeleteNodeKeepOutput) { + ASSERT_TRUE(RemoveSimpleNodeKeepOutput(&graph_, node_).ok()); + EXPECT_TRUE(graph_.inputs().empty()); + EXPECT_TRUE(graph_.outputs().empty()); + EXPECT_TRUE(graph_.nodes().empty()); } -TEST(Model, RemoveSimpleNodeNoAfterNodes) { - GraphFloat32 graph; - Node* simple_node = graph.NewNode(); - Node* producer_node = graph.NewNode(); - Value* graph_input = graph.NewValue(); - Value* graph_output = graph.NewValue(); - Value* value = graph.NewValue(); +class TwoNodesModel : public testing::Test { + protected: + void SetUp() override { + graph_input_ = graph_.NewValue(); + first_node_ = graph_.NewNode(); + value_ = graph_.NewValue(); + second_node_ = graph_.NewNode(); + graph_output_ = 
graph_.NewValue(); - ASSERT_TRUE(graph.AddConsumer(simple_node->id, value->id).ok()); - ASSERT_TRUE(graph.SetProducer(simple_node->id, graph_output->id).ok()); - ASSERT_TRUE(graph.AddConsumer(producer_node->id, graph_input->id).ok()); - ASSERT_TRUE(graph.SetProducer(producer_node->id, value->id).ok()); - EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_input)); - EXPECT_THAT(graph.outputs(), UnorderedElementsAre(graph_output)); - EXPECT_THAT(graph.nodes(), ElementsAre(simple_node, producer_node)); + ASSERT_TRUE(graph_.AddConsumer(first_node_->id, graph_input_->id).ok()); + ASSERT_TRUE(graph_.SetProducer(first_node_->id, value_->id).ok()); + ASSERT_TRUE(graph_.AddConsumer(second_node_->id, value_->id).ok()); + ASSERT_TRUE(graph_.SetProducer(second_node_->id, graph_output_->id).ok()); + EXPECT_THAT(graph_.inputs(), UnorderedElementsAre(graph_input_)); + EXPECT_THAT(graph_.outputs(), UnorderedElementsAre(graph_output_)); + EXPECT_THAT(graph_.nodes(), ElementsAre(first_node_, second_node_)); + } + GraphFloat32 graph_; + Node* first_node_; + Node* second_node_; + Value* graph_input_; + Value* value_; + Value* graph_output_; +}; - ASSERT_TRUE(RemoveOneInputOneOutputNode(&graph, simple_node).ok()); - EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_input)); - EXPECT_THAT(graph.outputs(), UnorderedElementsAre(value)); - EXPECT_THAT(graph.nodes(), ElementsAre(producer_node)); +TEST_F(TwoNodesModel, DeleteFirstNodeKeepInput) { + ASSERT_TRUE(RemoveSimpleNodeKeepInput(&graph_, first_node_).ok()); + EXPECT_THAT(graph_.inputs(), UnorderedElementsAre(graph_input_)); + EXPECT_THAT(graph_.outputs(), UnorderedElementsAre(graph_output_)); + EXPECT_THAT(graph_.nodes(), ElementsAre(second_node_)); } -TEST(Model, RemoveSimpleNodeGeneralCase) { - GraphFloat32 graph; - Node* simple_node = graph.NewNode(); - Node* producer_node = graph.NewNode(); - Node* consumer_node = graph.NewNode(); - Value* graph_input = graph.NewValue(); - Value* graph_output = graph.NewValue(); - Value* value0 = graph.NewValue(); - Value* value1 = graph.NewValue(); - - ASSERT_TRUE(graph.AddConsumer(producer_node->id, graph_input->id).ok()); - ASSERT_TRUE(graph.SetProducer(producer_node->id, value0->id).ok()); - ASSERT_TRUE(graph.AddConsumer(simple_node->id, value0->id).ok()); - ASSERT_TRUE(graph.SetProducer(simple_node->id, value1->id).ok()); - ASSERT_TRUE(graph.AddConsumer(consumer_node->id, value1->id).ok()); - ASSERT_TRUE(graph.SetProducer(consumer_node->id, graph_output->id).ok()); - EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_input)); - EXPECT_THAT(graph.outputs(), UnorderedElementsAre(graph_output)); - EXPECT_THAT(graph.nodes(), - ElementsAre(simple_node, producer_node, consumer_node)); - - ASSERT_TRUE(RemoveOneInputOneOutputNode(&graph, simple_node).ok()); - EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_input)); - EXPECT_THAT(graph.outputs(), UnorderedElementsAre(graph_output)); - EXPECT_THAT(graph.nodes(), ElementsAre(producer_node, consumer_node)); - EXPECT_THAT(graph.values(), - UnorderedElementsAre(graph_input, graph_output, value0)); +TEST_F(TwoNodesModel, DeleteFirstNodeKeepOutput) { + ASSERT_TRUE(RemoveSimpleNodeKeepOutput(&graph_, first_node_).ok()); + EXPECT_THAT(graph_.inputs(), UnorderedElementsAre(value_)); + EXPECT_THAT(graph_.outputs(), UnorderedElementsAre(graph_output_)); + EXPECT_THAT(graph_.nodes(), ElementsAre(second_node_)); } -TEST(Model, RemoveSimpleNodeComplexCase) { +TEST_F(TwoNodesModel, DeleteSecondNodeKeepInput) { + ASSERT_TRUE(RemoveSimpleNodeKeepInput(&graph_, 
second_node_).ok()); + EXPECT_THAT(graph_.inputs(), UnorderedElementsAre(graph_input_)); + EXPECT_THAT(graph_.outputs(), UnorderedElementsAre(value_)); + EXPECT_THAT(graph_.nodes(), ElementsAre(first_node_)); +} + +TEST_F(TwoNodesModel, DeleteSecondNodeKeepOutput) { + ASSERT_TRUE(RemoveSimpleNodeKeepOutput(&graph_, second_node_).ok()); + EXPECT_THAT(graph_.inputs(), UnorderedElementsAre(graph_input_)); + EXPECT_THAT(graph_.outputs(), UnorderedElementsAre(graph_output_)); + EXPECT_THAT(graph_.nodes(), ElementsAre(first_node_)); +} + +class ThreeNodesModel : public testing::Test { + protected: + void SetUp() override { + first_node_ = graph_.NewNode(); + second_node_ = graph_.NewNode(); + third_node_ = graph_.NewNode(); + graph_input_ = graph_.NewValue(); + value0_ = graph_.NewValue(); + value1_ = graph_.NewValue(); + graph_output_ = graph_.NewValue(); + + ASSERT_TRUE(graph_.AddConsumer(first_node_->id, graph_input_->id).ok()); + ASSERT_TRUE(graph_.SetProducer(first_node_->id, value0_->id).ok()); + ASSERT_TRUE(graph_.AddConsumer(second_node_->id, value0_->id).ok()); + ASSERT_TRUE(graph_.SetProducer(second_node_->id, value1_->id).ok()); + ASSERT_TRUE(graph_.AddConsumer(third_node_->id, value1_->id).ok()); + ASSERT_TRUE(graph_.SetProducer(third_node_->id, graph_output_->id).ok()); + EXPECT_THAT(graph_.inputs(), UnorderedElementsAre(graph_input_)); + EXPECT_THAT(graph_.outputs(), UnorderedElementsAre(graph_output_)); + EXPECT_THAT(graph_.nodes(), + ElementsAre(first_node_, second_node_, third_node_)); + } + GraphFloat32 graph_; + Node* first_node_; + Node* second_node_; + Node* third_node_; + Value* graph_input_; + Value* value0_; + Value* value1_; + Value* graph_output_; +}; + +TEST_F(ThreeNodesModel, DeleteMiddleNodeKeepInput) { + ASSERT_TRUE(RemoveSimpleNodeKeepInput(&graph_, second_node_).ok()); + EXPECT_THAT(graph_.inputs(), UnorderedElementsAre(graph_input_)); + EXPECT_THAT(graph_.outputs(), UnorderedElementsAre(graph_output_)); + EXPECT_THAT(graph_.nodes(), ElementsAre(first_node_, third_node_)); + EXPECT_THAT(graph_.values(), + UnorderedElementsAre(graph_input_, value0_, graph_output_)); +} + +TEST_F(ThreeNodesModel, DeleteMiddleNodeKeepOutput) { + ASSERT_TRUE(RemoveSimpleNodeKeepOutput(&graph_, second_node_).ok()); + EXPECT_THAT(graph_.inputs(), UnorderedElementsAre(graph_input_)); + EXPECT_THAT(graph_.outputs(), UnorderedElementsAre(graph_output_)); + EXPECT_THAT(graph_.nodes(), ElementsAre(first_node_, third_node_)); + EXPECT_THAT(graph_.values(), + UnorderedElementsAre(graph_input_, value1_, graph_output_)); +} + +TEST(Model, RemoveSimpleNodeKeepInputComplexCase) { // We have this graph and we are going to delete n1 and preserve order of // v0, v1 for n0 node and v2, v3 for n2 node // v0 v1 @@ -276,7 +316,11 @@ TEST(Model, RemoveSimpleNodeComplexCase) { EXPECT_THAT(graph.outputs(), UnorderedElementsAre(o1, o2)); EXPECT_THAT(graph.nodes(), ElementsAre(n0, n1, n2)); - ASSERT_TRUE(RemoveOneInputOneOutputNode(&graph, n1).ok()); + // Node should be the only consumer of the input value to be able to be + // deleted with this function. 
+ ASSERT_FALSE(RemoveSimpleNodeKeepOutput(&graph, n1).ok()); + + ASSERT_TRUE(RemoveSimpleNodeKeepInput(&graph, n1).ok()); EXPECT_THAT(graph.inputs(), UnorderedElementsAre(v0, v1, v3)); EXPECT_THAT(graph.outputs(), UnorderedElementsAre(o1, o2)); EXPECT_THAT(graph.nodes(), ElementsAre(n0, n2)); @@ -466,6 +510,29 @@ TEST(Model, InsertNodeAfter) { EXPECT_THAT(graph.nodes(), ElementsAre(node1, new_node1, node2, new_node2)); } +TEST(BatchMatchingTest, EmptyGraph) { + GraphFloat32 graph; + ASSERT_TRUE(IsBatchMatchesForAllValues(graph)); +} + +TEST(BatchMatchingTest, AllMatch) { + GraphFloat32 graph; + Value* a = graph.NewValue(); + Value* b = graph.NewValue(); + a->tensor.shape = BHWC(1, 1, 1, 1); + b->tensor.shape = BHWC(1, 1, 1, 1); + ASSERT_TRUE(IsBatchMatchesForAllValues(graph)); +} + +TEST(BatchMatchingTest, NotAllMatch) { + GraphFloat32 graph; + Value* a = graph.NewValue(); + Value* b = graph.NewValue(); + a->tensor.shape = BHWC(1, 1, 1, 1); + b->tensor.shape = BHWC(2, 1, 1, 1); + ASSERT_FALSE(IsBatchMatchesForAllValues(graph)); +} + } // namespace } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/model_transformer.cc b/tensorflow/lite/delegates/gpu/common/model_transformer.cc index 81287dd61e5..3be7ec55196 100644 --- a/tensorflow/lite/delegates/gpu/common/model_transformer.cc +++ b/tensorflow/lite/delegates/gpu/common/model_transformer.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_join.h" #include "tensorflow/lite/delegates/gpu/common/model.h" diff --git a/tensorflow/lite/delegates/gpu/common/model_transformer.h b/tensorflow/lite/delegates/gpu/common/model_transformer.h index fd2667390f3..b640b14e0b4 100644 --- a/tensorflow/lite/delegates/gpu/common/model_transformer.h +++ b/tensorflow/lite/delegates/gpu/common/model_transformer.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include "absl/container/flat_hash_set.h" diff --git a/tensorflow/lite/delegates/gpu/common/object_reader.cc b/tensorflow/lite/delegates/gpu/common/object_reader.cc index c837fa061c0..04e4a14804a 100644 --- a/tensorflow/lite/delegates/gpu/common/object_reader.cc +++ b/tensorflow/lite/delegates/gpu/common/object_reader.cc @@ -16,13 +16,18 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/object_reader.h" #include +#include +#include #include "absl/container/flat_hash_map.h" +#include "absl/strings/str_cat.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" #include "tensorflow/lite/delegates/utils.h" +#include "tensorflow/lite/kernels/kernel_util.h" namespace tflite { namespace gpu { @@ -58,6 +63,9 @@ absl::Status ObjectReader::ReadNonConstantTensor( &fp_tensor_index) != kTfLiteOk) { return absl::InternalError("Could not add new tensor to graph"); } + // `tflite_tensor` value could be invalid when the `context->tensors` + // is reallocated. Thus reassigning `tflite_tensor` with a fresh value. + tflite_tensor = &context->tensors[tensor_idx]; // Remember this tensor for later. 
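The `tflite_tensor = &context->tensors[tensor_idx];` reassignment above guards against a classic invalidation hazard: creating the new dequantized tensor can grow `context->tensors`, so any pointer taken into that array before the call may dangle afterwards. Below is a minimal stand-alone C++ sketch of the same hazard and the same fix; it uses `std::vector` purely as a stand-in for the TfLite tensor array, not the actual TfLite API.

#include <cassert>
#include <vector>

int main() {
  std::vector<int> tensors = {10, 20, 30};
  int* cached = &tensors[1];  // analogous to the cached `tflite_tensor`
  // Appending may reallocate the backing storage, much like registering the
  // new dequantized tensor can reallocate `context->tensors`.
  tensors.push_back(40);
  // `cached` may now dangle; re-fetch it through the stable index instead.
  cached = &tensors[1];
  assert(*cached == 20);
  return 0;
}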
(*quant_conversion_map)[fp_tensor_index] = tensor_idx; @@ -67,10 +75,8 @@ absl::Status ObjectReader::ReadNonConstantTensor( RETURN_IF_ERROR( ConvertTfLiteTensorToTensorRef(*fp_tflite_tensor, &value->tensor)); value->tensor.ref = fp_tensor_index; + value->tensor.is_variable_input = tflite_tensor->is_variable; value->quant_params.emplace(); - // tflite_tensor from the outer scope is invalidated due to calling - // CreateNewTensorWithDifferentType - tflite_tensor = &context->tensors[tensor_idx]; RETURN_IF_ERROR( PopulateQuantParams(*tflite_tensor, &value->quant_params.value())); (*tensor_to_value)[fp_tensor_index] = value; @@ -84,6 +90,7 @@ absl::Status ObjectReader::ReadNonConstantTensor( RETURN_IF_ERROR( ConvertTfLiteTensorToTensorRef(*tflite_tensor, &value->tensor)); value->tensor.ref = tensor_idx; + value->tensor.is_variable_input = tflite_tensor->is_variable; (*tensor_to_value)[tensor_idx] = value; } } @@ -154,6 +161,53 @@ absl::Status ObjectReader::AddInput(const Node* node, uint32_t idx) { return graph_->AddConsumer(node->id, input->id); } +absl::Status ObjectReader::AddUpdate(const Node* node, uint32_t idx) { + if (node_->inputs->size <= idx) { + return absl::InvalidArgumentError(absl::StrCat( + "Data id ", idx, " must be less than tflite node inputs size ", + node_->inputs->size)); + } + + int update_tensor_idx = node_->inputs->data[idx]; + TfLiteTensor* update_tensor = context_->tensors + update_tensor_idx; + if (!update_tensor->is_variable) { + return absl::InvalidArgumentError( + "The tensor must be a variable tensor to update it in place"); + } + + Value* value; + RETURN_IF_ERROR(ReadValueByTensorIdx(update_tensor_idx, &value)); + if (!value->tensor.is_variable_input) { + return absl::InternalError( + "Variable input tensor is not marked as variable"); + } + + // We cannot create a cycle in the graph. The way around this when a node + // updates a tensor in place would be to add a new value to the graph that + // points to the same tensor. + Value* updated_value = graph_->NewValue(); + updated_value->tensor = value->tensor; + updated_value->quant_params = value->quant_params; + RETURN_IF_ERROR(graph_->SetProducer(node->id, updated_value->id)); + + // We also need to update the tensor_to_value arrays so that the nodes added + // after the current node will access the tensor with the updated value rather + // than the initial value. + if (quant_conversion_map_ != nullptr && + quant_conversion_map_->find(update_tensor_idx) != + quant_conversion_map_->end()) { + // If quantization conversion map exists, then the index provided is not the + // actual tensor idx. We need to find the float version of the tensor from + // the map. + tensor_to_value_->at(quant_conversion_map_->at(update_tensor_idx)) = + updated_value; + } else { + tensor_to_value_->at(update_tensor_idx) = updated_value; + } + + return absl::OkStatus(); +} + TfLiteTensor* ObjectReader::GetInputTensor(int index) const { return index >= 0 && index < node_->inputs->size ? 
context_->tensors + node_->inputs->data[index] diff --git a/tensorflow/lite/delegates/gpu/common/object_reader.h b/tensorflow/lite/delegates/gpu/common/object_reader.h index 246bc71f9c5..3c7d7f6a859 100644 --- a/tensorflow/lite/delegates/gpu/common/object_reader.h +++ b/tensorflow/lite/delegates/gpu/common/object_reader.h @@ -86,6 +86,8 @@ class ObjectReader { absl::Status AddInput(const Node* node, uint32_t idx); + absl::Status AddUpdate(const Node* node, uint32_t idx); + TfLiteTensor* GetInputTensor(int index) const; TfLiteTensor* GetOutputTensor(int index) const; diff --git a/tensorflow/lite/delegates/gpu/common/operations.cc b/tensorflow/lite/delegates/gpu/common/operations.cc index fbffe9d65ff..19d7bd919c5 100644 --- a/tensorflow/lite/delegates/gpu/common/operations.cc +++ b/tensorflow/lite/delegates/gpu/common/operations.cc @@ -15,11 +15,17 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/operations.h" +#include #include +#include +#include +#include +#include #include "absl/container/flat_hash_map.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" namespace tflite { namespace gpu { @@ -76,6 +82,8 @@ std::string ToString(enum OperationType op) { return "batch_normalization"; case OperationType::BATCH_TO_SPACE: return "batch_to_space"; + case OperationType::BATCHED_MATMUL: + return "batched_matmul"; case OperationType::CONCAT: return "concat"; case OperationType::CONST: @@ -94,12 +102,22 @@ std::string ToString(enum OperationType op) { return "div"; case OperationType::ELU: return "elu"; + case OperationType::EQUAL: + return "equal"; case OperationType::EXP: return "exp"; case OperationType::FULLY_CONNECTED: return "fully_connected"; + case OperationType::GREATER: + return "greater"; + case OperationType::GREATER_EQUAL: + return "greater_equal"; case OperationType::HARD_SWISH: return "hard_swish"; + case OperationType::LESS: + return "less"; + case OperationType::LESS_EQUAL: + return "less_equal"; case OperationType::LOG: return "log"; case OperationType::LSTM: @@ -116,6 +134,10 @@ std::string ToString(enum OperationType op) { return "minimum"; case OperationType::MUL: return "mul"; + case OperationType::NEG: + return "neg"; + case OperationType::NOT_EQUAL: + return "not_equal"; case OperationType::PAD: return "pad"; case OperationType::POOLING_2D: @@ -126,6 +148,14 @@ std::string ToString(enum OperationType op) { return "prelu"; case OperationType::QUANTIZE_AND_DEQUANTIZE: return "quantize_and_dequantize"; + case OperationType::REDUCE_MAXIMUM: + return "reduce_maximum"; + case OperationType::REDUCE_MINIMUM: + return "reduce_minimum"; + case OperationType::REDUCE_PRODUCT: + return "reduce_product"; + case OperationType::REDUCE_SUM: + return "reduce_sum"; case OperationType::RELU: return "relu"; case OperationType::RESHAPE: @@ -169,6 +199,7 @@ OperationType OperationTypeFromString(const std::string& name) { {"abs", OperationType::ABS}, {"add", OperationType::ADD}, {"batch_normalization", OperationType::BATCH_NORMALIZATION}, + {"batched_matmul", OperationType::BATCHED_MATMUL}, {"concat", OperationType::CONCAT}, {"const", OperationType::CONST}, {"convolution_2d", OperationType::CONVOLUTION_2D}, @@ -178,9 +209,14 @@ OperationType OperationTypeFromString(const std::string& name) { {"depthwise_convolution", OperationType::DEPTHWISE_CONVOLUTION}, {"div", OperationType::DIV}, {"elu", OperationType::ELU}, + {"equal", OperationType::EQUAL}, 
{"exp", OperationType::EXP}, {"fully_connected", OperationType::FULLY_CONNECTED}, + {"greater", OperationType::GREATER}, + {"greater_equal", OperationType::GREATER_EQUAL}, {"hard_swish", OperationType::HARD_SWISH}, + {"less", OperationType::LESS}, + {"less_equal", OperationType::LESS_EQUAL}, {"log", OperationType::LOG}, {"lstm", OperationType::LSTM}, {"maximum", OperationType::MAXIMUM}, @@ -190,11 +226,17 @@ OperationType OperationTypeFromString(const std::string& name) { OperationType::MEAN_STDDEV_NORMALIZATION}, {"minimum", OperationType::MINIMUM}, {"mul", OperationType::MUL}, + {"neg", OperationType::NEG}, + {"not_equal", OperationType::NOT_EQUAL}, {"pad", OperationType::PAD}, {"pooling_2d", OperationType::POOLING_2D}, {"pow", OperationType::POW}, {"prelu", OperationType::PRELU}, {"quantize_and_dequantize", OperationType::QUANTIZE_AND_DEQUANTIZE}, + {"reduce_maximum", OperationType::REDUCE_MAXIMUM}, + {"reduce_minimum", OperationType::REDUCE_MINIMUM}, + {"reduce_product", OperationType::REDUCE_PRODUCT}, + {"reduce_sum", OperationType::REDUCE_SUM}, {"relu", OperationType::RELU}, {"resize", OperationType::RESIZE}, {"reshape", OperationType::RESHAPE}, diff --git a/tensorflow/lite/delegates/gpu/common/operations.h b/tensorflow/lite/delegates/gpu/common/operations.h index 563dbdec96e..a93f63a02b7 100644 --- a/tensorflow/lite/delegates/gpu/common/operations.h +++ b/tensorflow/lite/delegates/gpu/common/operations.h @@ -17,14 +17,15 @@ limitations under the License. #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_OPERATIONS_H_ #include +#include #include #include #include "absl/types/variant.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" -#include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" namespace tflite { namespace gpu { @@ -36,6 +37,7 @@ enum class OperationType { ADD, BATCH_TO_SPACE, BATCH_NORMALIZATION, + BATCHED_MATMUL, CONCAT, CONST, CONVOLUTION_2D, @@ -45,9 +47,14 @@ enum class OperationType { DEPTHWISE_CONVOLUTION, DIV, ELU, + EQUAL, EXP, FULLY_CONNECTED, + GREATER, + GREATER_EQUAL, HARD_SWISH, + LESS, + LESS_EQUAL, LOG, LSTM, MAXIMUM, @@ -56,12 +63,18 @@ enum class OperationType { MEAN_STDDEV_NORMALIZATION, MINIMUM, MUL, + NEG, + NOT_EQUAL, PAD, POOLING_2D, POW, PRELU, // Used to accurately run inference on quantized models. QUANTIZE_AND_DEQUANTIZE, + REDUCE_MAXIMUM, + REDUCE_MINIMUM, + REDUCE_PRODUCT, + REDUCE_SUM, RELU, RESHAPE, RESIZE, @@ -358,6 +371,10 @@ struct PReLUAttributes { alpha; }; +struct ReduceAttributes { + Axis axis = Axis::UNKNOWN; +}; + struct SoftmaxAttributes { Axis axis = Axis::UNKNOWN; }; diff --git a/tensorflow/lite/delegates/gpu/common/quantization_util.cc b/tensorflow/lite/delegates/gpu/common/quantization_util.cc index fe92989a3ae..bbd99023a2f 100644 --- a/tensorflow/lite/delegates/gpu/common/quantization_util.cc +++ b/tensorflow/lite/delegates/gpu/common/quantization_util.cc @@ -15,9 +15,15 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/common/quantization_util.h" +#include + +#include + #include "absl/container/flat_hash_map.h" -#include "tensorflow/lite/builtin_ops.h" +#include "absl/status/status.h" +#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/internal/types.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/common/quantization_util.h b/tensorflow/lite/delegates/gpu/common/quantization_util.h index fc01d612d6f..584f6876a9c 100644 --- a/tensorflow/lite/delegates/gpu/common/quantization_util.h +++ b/tensorflow/lite/delegates/gpu/common/quantization_util.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_QUANTIZATION_UTIL_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_QUANTIZATION_UTIL_H_ +#include + #include #include "absl/container/flat_hash_map.h" diff --git a/tensorflow/lite/delegates/gpu/common/quantization_util_test.cc b/tensorflow/lite/delegates/gpu/common/quantization_util_test.cc index b5cdaec91e0..ffded543123 100644 --- a/tensorflow/lite/delegates/gpu/common/quantization_util_test.cc +++ b/tensorflow/lite/delegates/gpu/common/quantization_util_test.cc @@ -15,8 +15,18 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/quantization_util.h" +#include + +#include +#include +#include +#include + #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/util.h" using ::testing::Eq; diff --git a/tensorflow/lite/delegates/gpu/common/shape.cc b/tensorflow/lite/delegates/gpu/common/shape.cc index 074637a7774..c66ecea1215 100644 --- a/tensorflow/lite/delegates/gpu/common/shape.cc +++ b/tensorflow/lite/delegates/gpu/common/shape.cc @@ -14,6 +14,11 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/delegates/gpu/common/shape.h" +#include + +#include +#include + #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" diff --git a/tensorflow/lite/delegates/gpu/common/shape.h b/tensorflow/lite/delegates/gpu/common/shape.h index 544d2c1f4d0..a017ff28e63 100644 --- a/tensorflow/lite/delegates/gpu/common/shape.h +++ b/tensorflow/lite/delegates/gpu/common/shape.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SHAPE_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SHAPE_H_ -#include +#include +#include -#include #include #include #include @@ -26,8 +26,6 @@ limitations under the License. #include #include -#include "absl/hash/hash.h" - namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/common/shape_test.cc b/tensorflow/lite/delegates/gpu/common/shape_test.cc index 41519115729..3cbf1fddfc2 100644 --- a/tensorflow/lite/delegates/gpu/common/shape_test.cc +++ b/tensorflow/lite/delegates/gpu/common/shape_test.cc @@ -14,10 +14,10 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/lite/delegates/gpu/common/shape.h" -#include +#include + #include -#include #include namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/common/tensor.h b/tensorflow/lite/delegates/gpu/common/tensor.h index fc39d3485ba..ba0fd48810c 100644 --- a/tensorflow/lite/delegates/gpu/common/tensor.h +++ b/tensorflow/lite/delegates/gpu/common/tensor.h @@ -16,7 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TENSOR_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TENSOR_H_ -#include +#include + #include #include "tensorflow/lite/delegates/gpu/common/data_type.h" @@ -71,6 +72,10 @@ struct TensorRef { // Opaque reference to a tensor. Upstream component is responsible for // resolving this reference into an actual tensor. int64_t ref = -1; + + // Specifies if the tensor should be a variable input tensor that must be an + // output as well as an input to the graph. + bool is_variable_input = false; }; template diff --git a/tensorflow/lite/delegates/gpu/common/testing/BUILD b/tensorflow/lite/delegates/gpu/common/testing/BUILD index a7f97eb67b3..dd8792d6895 100644 --- a/tensorflow/lite/delegates/gpu/common/testing/BUILD +++ b/tensorflow/lite/delegates/gpu/common/testing/BUILD @@ -10,6 +10,8 @@ cc_library( hdrs = ["interpreter_utils.h"], deps = [ "//tensorflow/lite:framework", + "//tensorflow/lite:string", + "//tensorflow/lite/c:common", "//tensorflow/lite/core/api", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:tensor", @@ -25,13 +27,12 @@ cc_library( hdrs = ["tflite_model_reader.h"], deps = [ "//tensorflow/lite:framework_lib", - "//tensorflow/lite:kernel_api", "//tensorflow/lite/c:common", "//tensorflow/lite/core/api", "//tensorflow/lite/delegates/gpu/common:model", "//tensorflow/lite/delegates/gpu/common:model_builder", + "//tensorflow/lite/delegates/gpu/common:model_transformer", "//tensorflow/lite/delegates/gpu/common:status", - "//tensorflow/lite/delegates/gpu/common/transformations:general_transformations", - "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/delegates/gpu/common/transformations:model_transformations", ], ) diff --git a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/BUILD b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/BUILD index b5ceff30d1e..50150964e92 100644 --- a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/BUILD +++ b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/BUILD @@ -24,10 +24,12 @@ cc_library( hdrs = ["utils.h"], deps = [ "//tensorflow/lite:framework", + "//tensorflow/lite:string", "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:builtin_ops", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", ], ) @@ -48,6 +50,7 @@ cc_test( deps = [ ":feature_parity", ":utils", + "//tensorflow/lite:framework_lib", "//tensorflow/lite/delegates/gpu:gl_delegate", "@com_google_googletest//:gtest_main", ], @@ -65,6 +68,7 @@ cc_test( deps = [ ":feature_parity", ":utils", + "//tensorflow/lite:framework_lib", "//tensorflow/lite/delegates/gpu:delegate", "@com_google_googletest//:gtest_main", ], @@ -82,6 +86,7 @@ cc_test( deps = [ ":feature_parity", ":utils", + "//tensorflow/lite:framework_lib", "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", "@com_google_googletest//:gtest_main", ], diff --git 
a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/feature_parity.h b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/feature_parity.h index 7661a4ad296..dacb486e303 100644 --- a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/feature_parity.h +++ b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/feature_parity.h @@ -16,9 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TESTING_FEATURE_PARITY_FEATURE_PARITY_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TESTING_FEATURE_PARITY_FEATURE_PARITY_H_ -#include -#include -#include #include #include "tensorflow/lite/delegates/gpu/common/testing/feature_parity/generators/add.h" diff --git a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/generators/BUILD b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/generators/BUILD index 4fef0a28525..56894c8810a 100644 --- a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/generators/BUILD +++ b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/generators/BUILD @@ -20,9 +20,7 @@ cc_library( srcs = ["add.cc"], hdrs = ["add.h"], deps = [ - "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", - "//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common/testing/feature_parity:utils", "@flatbuffers", ], diff --git a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/generators/add.cc b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/generators/add.cc index dbb3851ca56..06649b36e79 100644 --- a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/generators/add.cc +++ b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/generators/add.cc @@ -15,11 +15,14 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/testing/feature_parity/generators/add.h" +#include + +#include +#include #include -#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.h" -#include "tensorflow/lite/model.h" #include "tensorflow/lite/version.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/opencl_test.cc b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/opencl_test.cc index 24c0e0c424b..3dbb8638196 100644 --- a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/opencl_test.cc +++ b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/opencl_test.cc @@ -13,11 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + +#include +#include + #include #include #include "tensorflow/lite/delegates/gpu/common/testing/feature_parity/feature_parity.h" #include "tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.h" #include "tensorflow/lite/delegates/gpu/delegate.h" +#include "tensorflow/lite/interpreter.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/opengl_test.cc b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/opengl_test.cc index 2f403d2e583..ed0aa104e65 100644 --- a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/opengl_test.cc +++ b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/opengl_test.cc @@ -15,12 +15,14 @@ limitations under the License. 
#include #include +#include #include #include #include "tensorflow/lite/delegates/gpu/common/testing/feature_parity/feature_parity.h" #include "tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.h" #include "tensorflow/lite/delegates/gpu/gl_delegate.h" +#include "tensorflow/lite/interpreter.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.cc b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.cc index bdcbf7ed62e..6eb94f63b6f 100644 --- a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.cc +++ b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.cc @@ -15,15 +15,18 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.h" +#include +#include #include #include +#include #include "absl/status/status.h" #include "absl/strings/substitute.h" -#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" +#include "tensorflow/lite/string_type.h" std::ostream& operator<<(std::ostream& os, const TfLiteTensor& tensor) { std::string shape; diff --git a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.h b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.h index 7c34978fb55..20d43b85468 100644 --- a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.h +++ b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.h @@ -16,14 +16,24 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TESTING_FEATURE_PARITY_UTILS_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TESTING_FEATURE_PARITY_UTILS_H_ +#include + #include +#include +#include #include #include +#include +#include #include #include +#include #include "absl/status/status.h" +#include "absl/types/span.h" +#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/model.h" +#include "tensorflow/lite/string_type.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/xnnpack_test.cc b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/xnnpack_test.cc index 3d05d64437d..bdd12951c8c 100644 --- a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/xnnpack_test.cc +++ b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/xnnpack_test.cc @@ -13,11 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + +#include +#include + #include #include #include "tensorflow/lite/delegates/gpu/common/testing/feature_parity/feature_parity.h" #include "tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/interpreter.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.cc b/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.cc index 08d9448f7e5..ae00e213fa3 100644 --- a/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.cc +++ b/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.cc @@ -16,15 +16,18 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.h" #include +#include +#include #include #include "absl/memory/memory.h" -#include "tensorflow/lite/context.h" #include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/string_type.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.h b/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.h index ca2825b7563..86656abbe0f 100644 --- a/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.h +++ b/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "tensorflow/lite/context.h" +#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" diff --git a/tensorflow/lite/delegates/gpu/common/testing/tflite_model_reader.cc b/tensorflow/lite/delegates/gpu/common/testing/tflite_model_reader.cc index a67602cf245..7ba3de641ef 100644 --- a/tensorflow/lite/delegates/gpu/common/testing/tflite_model_reader.cc +++ b/tensorflow/lite/delegates/gpu/common/testing/tflite_model_reader.cc @@ -14,16 +14,18 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/delegates/gpu/common/testing/tflite_model_reader.h" +#include + #include -#include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/model_builder.h" +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" #include "tensorflow/lite/delegates/gpu/common/status.h" -#include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h" +#include "tensorflow/lite/delegates/gpu/common/transformations/model_transformations.h" #include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" #include "tensorflow/lite/model_builder.h" namespace tflite { @@ -93,8 +95,8 @@ absl::Status BuildFromFlatBuffer(const tflite::FlatBufferModel& flatbuffer, NullTransformationReporter reporter; ModelTransformer transformer(graph, &reporter); - if (!ApplyGeneralTransformations(&transformer)) { - return absl::InternalError("Graph general transformations failed"); + if (!ApplyModelTransformations(&transformer)) { + return absl::InternalError("Graph transformations failed"); } return absl::OkStatus(); diff --git a/tensorflow/lite/delegates/gpu/common/transformations/BUILD b/tensorflow/lite/delegates/gpu/common/transformations/BUILD index bf26b03f534..6cb358bcc93 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/BUILD +++ b/tensorflow/lite/delegates/gpu/common/transformations/BUILD @@ -1,3 +1,5 @@ +load("//tensorflow/core/platform:build_config.bzl", "tf_platform_alias") + package( default_visibility = ["//visibility:public"], licenses = ["notice"], # Apache 2.0 @@ -12,9 +14,9 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:model", "//tensorflow/lite/delegates/gpu/common:model_transformer", 
"//tensorflow/lite/delegates/gpu/common:operations", - "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:tensor", "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", "@com_google_absl//absl/types:any", ], ) @@ -24,11 +26,11 @@ cc_library( srcs = ["add_quant_adjustments.cc"], hdrs = ["add_quant_adjustments.h"], deps = [ - "//tensorflow/lite/delegates/gpu/common:data_type", "//tensorflow/lite/delegates/gpu/common:model", "//tensorflow/lite/delegates/gpu/common:model_transformer", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:tensor", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:any", @@ -40,10 +42,13 @@ cc_test( srcs = ["add_quant_adjustments_test.cc"], deps = [ ":add_quant_adjustments", + "//tensorflow/lite/delegates/gpu/common:data_type", "//tensorflow/lite/delegates/gpu/common:model", "//tensorflow/lite/delegates/gpu/common:model_transformer", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:tensor", + "@com_google_absl//absl/status", "@com_google_absl//absl/types:any", "@com_google_absl//absl/types:optional", "@com_google_googletest//:gtest_main", @@ -56,9 +61,13 @@ cc_library( hdrs = ["fuse_add_to_conv.h"], deps = [ "//tensorflow/lite/delegates/gpu/common:data_type", + "//tensorflow/lite/delegates/gpu/common:model", "//tensorflow/lite/delegates/gpu/common:model_transformer", "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:tensor", + "@com_google_absl//absl/strings", ], ) @@ -67,8 +76,13 @@ cc_test( srcs = ["fuse_add_to_conv_test.cc"], deps = [ ":fuse_add_to_conv", + "//tensorflow/lite/delegates/gpu/common:data_type", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:model_transformer", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:tensor", + "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", ], ) @@ -82,8 +96,10 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:model", "//tensorflow/lite/delegates/gpu/common:model_transformer", "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:tensor", + "@com_google_absl//absl/strings", ], ) @@ -92,17 +108,21 @@ cc_test( srcs = ["fuse_mul_to_conv_test.cc"], deps = [ ":fuse_mul_to_conv", + "//tensorflow/lite/delegates/gpu/common:data_type", "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:model_transformer", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:tensor", + "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", ], ) cc_library( - name = "general_transformations", - srcs = ["general_transformations.cc"], - hdrs = ["general_transformations.h"], + name = "model_transformations", + srcs = ["model_transformations.cc"], + hdrs = ["model_transformations.h"], deps = [ ":add_quant_adjustments", 
":fuse_add_to_conv", @@ -112,7 +132,7 @@ cc_library( ":merge_padding_with", ":remove_noop", "//tensorflow/lite/delegates/gpu/common:model_transformer", - ], + ] + tf_platform_alias("custom_transformations", "//tensorflow/lite/delegates/gpu/common/"), ) cc_library( @@ -123,7 +143,8 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:model", "//tensorflow/lite/delegates/gpu/common:model_transformer", "//tensorflow/lite/delegates/gpu/common:operations", - "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:tensor", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:any", ], @@ -138,6 +159,8 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:model_transformer", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:tensor", + "@com_google_absl//absl/status", "@com_google_absl//absl/types:any", "@com_google_googletest//:gtest_main", ], @@ -151,8 +174,11 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:model", "//tensorflow/lite/delegates/gpu/common:model_transformer", "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:tensor", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:any", ], ) @@ -165,6 +191,9 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:model", "//tensorflow/lite/delegates/gpu/common:model_transformer", "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:tensor", + "@com_google_absl//absl/status", "@com_google_absl//absl/types:any", "@com_google_googletest//:gtest_main", ], @@ -173,7 +202,6 @@ cc_test( cc_library( name = "matching", hdrs = ["matching.h"], - deps = ["//tensorflow/lite/delegates/gpu/common:model"], ) cc_library( @@ -186,7 +214,9 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:model", "//tensorflow/lite/delegates/gpu/common:model_transformer", "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:tensor", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:any", @@ -198,10 +228,13 @@ cc_test( srcs = ["merge_padding_with_test.cc"], deps = [ ":merge_padding_with", + "//tensorflow/lite/delegates/gpu/common:data_type", "//tensorflow/lite/delegates/gpu/common:model", "//tensorflow/lite/delegates/gpu/common:model_transformer", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:tensor", + "@com_google_absl//absl/status", "@com_google_absl//absl/types:any", "@com_google_googletest//:gtest_main", ], @@ -216,8 +249,11 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:model", "//tensorflow/lite/delegates/gpu/common:model_transformer", "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:tensor", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -230,6 +266,9 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:model", 
"//tensorflow/lite/delegates/gpu/common:model_transformer", "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:tensor", + "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/lite/delegates/gpu/common/transformations/add_bias.cc b/tensorflow/lite/delegates/gpu/common/transformations/add_bias.cc index 29d70d8f4a9..af274d8381e 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/add_bias.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/add_bias.cc @@ -15,13 +15,18 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/transformations/add_bias.h" +#include +#include +#include + #include "absl/memory/memory.h" -#include "absl/strings/str_cat.h" #include "absl/types/any.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" -#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" namespace tflite { namespace gpu { @@ -65,6 +70,12 @@ class AddBias : public NodeTransformation { } if (node->operation.type == ToString(OperationType::DEPTHWISE_CONVOLUTION)) { + if (graph->FindInputs(node->id).size() != 1) { + return {TransformStatus::DECLINED, + "This transformation is only applicable to depth wise conv " + "with one " + "runtime input."}; + } auto& attr = absl::any_cast( node->operation.attributes); return FillBias(attr.weights.shape.o * attr.weights.shape.i, &attr.bias); diff --git a/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.cc b/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.cc index 6262d1575b7..7f43d70c842 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.cc @@ -15,15 +15,19 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.h" +#include +#include #include +#include #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/types/any.h" -#include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments_test.cc b/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments_test.cc index 2ff84981f9d..9ef909d4ab7 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments_test.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments_test.cc @@ -15,14 +15,20 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.h" -#include +#include +#include +#include + #include +#include "absl/status/status.h" #include "absl/types/any.h" #include "absl/types/optional.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/model_transformer.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" namespace tflite { namespace gpu { @@ -59,7 +65,7 @@ TEST(AddQuantAdjustments, OneNode) { ASSERT_TRUE(graph.AddConsumer(add_node->id, input->id).ok()); - Value* output; + Value* output = nullptr; AddQuantParams(&input->quant_params, /*min=*/0.0, /*max=*/2.0, /*scale=*/0.008); ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok()); @@ -114,18 +120,18 @@ TEST(AddQuantAdjustments, GeneralCase) { // Connections. ASSERT_TRUE(graph.AddConsumer(add1_node->id, input->id).ok()); - Value* link1; + Value* link1 = nullptr; ASSERT_TRUE(ConnectTwoNodes(&graph, add1_node, quant_node, &link1).ok()); AddQuantParams(&link1->quant_params, /*min=*/0.0, /*max=*/2.0, /*scale=*/0.008); link1->tensor.shape = BHWC(1, 4, 4, 8); ASSERT_TRUE(graph.AddConsumer(add2_node->id, link1->id).ok()); - Value* link2; + Value* link2 = nullptr; ASSERT_TRUE(ConnectTwoNodes(&graph, quant_node, add2_node, &link2).ok()); AddQuantParams(&link2->quant_params, /*min=*/-1.0, /*max=*/1.0, /*scale=*/0.008); link2->tensor.shape = BHWC(1, 4, 4, 8); - Value* output; + Value* output = nullptr; ASSERT_TRUE(AddOutput(&graph, add2_node, &output).ok()); AddQuantParams(&output->quant_params, /*min=*/-1.0, /*max=*/1.0, /*scale=*/0.008); diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc index fdbd6e03755..62c3ec39854 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc @@ -15,8 +15,20 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h" +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" namespace tflite { namespace gpu { @@ -42,6 +54,10 @@ class MergeConvolutionWithAdd : public SequenceTransformation { TransformResult ApplyToNodesSequence(const std::vector& sequence, GraphFloat32* graph) final { auto& conv_node = *sequence[0]; + if (graph->FindInputs(conv_node.id).size() != 1) { + return {TransformStatus::DECLINED, + "This fusion is only applicable to ops with one runtime input."}; + } auto& add_node = *sequence[1]; if (add_node.operation.type != ToString(OperationType::ADD)) { return {TransformStatus::SKIPPED, ""}; diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h index 53a0cef63c8..26f93dc3765 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h +++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h @@ -20,7 +20,6 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/model_transformer.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" -#include "tensorflow/lite/delegates/gpu/common/status.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv_test.cc b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv_test.cc index 4a48c7c0b28..76bf7e4a72a 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv_test.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv_test.cc @@ -15,10 +15,20 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h" +#include +#include +#include +#include + #include #include +#include "absl/status/status.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" using ::testing::FloatNear; using ::testing::Pointwise; @@ -57,11 +67,11 @@ TEST(MergeConvolutionWithAddTest, Smoke) { ASSERT_TRUE(graph.AddConsumer(conv_node->id, input->id).ok()); - Value* output; + Value* output = nullptr; ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok()); output->tensor.shape = BHWC(1, 4, 4, 16); - Value* link1; + Value* link1 = nullptr; ASSERT_TRUE(ConnectTwoNodes(&graph, conv_node, add_node, &link1).ok()); link1->tensor.shape = BHWC(1, 4, 4, 16); diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc b/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc index 25ec6299f11..41bd485a76c 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc @@ -15,9 +15,18 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.h" +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.h b/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.h index 8d64ae50488..92fab4553f1 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.h +++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.h @@ -20,7 +20,6 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/model_transformer.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" -#include "tensorflow/lite/delegates/gpu/common/status.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv_test.cc b/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv_test.cc index ea990dd8267..b35cb832335 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv_test.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv_test.cc @@ -15,11 +15,20 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.h" +#include +#include +#include +#include + #include #include +#include "absl/status/status.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" using ::testing::FloatNear; using ::testing::Pointwise; @@ -58,11 +67,11 @@ TEST(MergeConvolutionWithMulTest, Smoke) { ASSERT_TRUE(graph.AddConsumer(conv_node->id, input->id).ok()); - Value* output; + Value* output = nullptr; ASSERT_TRUE(AddOutput(&graph, mul_node, &output).ok()); output->tensor.shape = BHWC(1, 4, 4, 16); - Value* link1; + Value* link1 = nullptr; ASSERT_TRUE(ConnectTwoNodes(&graph, conv_node, mul_node, &link1).ok()); link1->tensor.shape = BHWC(1, 4, 4, 16); @@ -109,11 +118,11 @@ TEST(MergeMulWithConvolutionTest, Smoke) { ASSERT_TRUE(graph.AddConsumer(mul_node->id, input->id).ok()); - Value* output; + Value* output = nullptr; ASSERT_TRUE(AddOutput(&graph, conv_node, &output).ok()); output->tensor.shape = BHWC(1, 4, 4, 16); - Value* link1; + Value* link1 = nullptr; ASSERT_TRUE(ConnectTwoNodes(&graph, mul_node, conv_node, &link1).ok()); link1->tensor.shape = BHWC(1, 4, 4, 16); diff --git a/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.cc b/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.cc index 1236cdec214..226e7d4b2a9 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.cc @@ -15,11 +15,17 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.h" +#include +#include +#include + #include "absl/memory/memory.h" #include "absl/types/any.h" #include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" -#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected_test.cc b/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected_test.cc index d3606d4a097..29f1b4bfbef 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected_test.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected_test.cc @@ -15,13 +15,18 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.h" -#include +#include +#include +#include + #include +#include "absl/status/status.h" #include "absl/types/any.h" #include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/model_transformer.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" namespace tflite { namespace gpu { @@ -68,16 +73,16 @@ TEST(MakeFullyConnected, Smoke) { ASSERT_TRUE(graph.AddConsumer(conv1x1_node0->id, input->id).ok()); - Value* output; + Value* output = nullptr; ASSERT_TRUE(AddOutput(&graph, conv1x1_node2, &output).ok()); output->tensor.shape = BHWC(1, 1, 1, 32); - Value* link1; + Value* link1 = nullptr; ASSERT_TRUE( ConnectTwoNodes(&graph, conv1x1_node0, conv4x4_node1, &link1).ok()); link1->tensor.shape = BHWC(1, 4, 4, 16); - Value* link2; + Value* link2 = nullptr; ASSERT_TRUE( ConnectTwoNodes(&graph, conv4x4_node1, conv1x1_node2, &link2).ok()); link2->tensor.shape = BHWC(1, 1, 1, 16); diff --git a/tensorflow/lite/delegates/gpu/common/transformations/make_padding.cc b/tensorflow/lite/delegates/gpu/common/transformations/make_padding.cc index 17aac83baf7..51335a83c38 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/make_padding.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/make_padding.cc @@ -15,11 +15,19 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/transformations/make_padding.h" +#include +#include +#include + #include "absl/memory/memory.h" +#include "absl/strings/string_view.h" #include "absl/types/any.h" #include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/common/transformations/make_padding_test.cc b/tensorflow/lite/delegates/gpu/common/transformations/make_padding_test.cc index f8be3218239..8aafd75ba5b 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/make_padding_test.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/make_padding_test.cc @@ -15,12 +15,18 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/common/transformations/make_padding.h" -#include +#include +#include +#include + #include +#include "absl/status/status.h" #include "absl/types/any.h" #include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/model_transformer.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" namespace tflite { namespace gpu { @@ -38,7 +44,7 @@ TEST(MakePadding, Smoke) { attr.axis = Axis::HEIGHT; concat_node->operation.attributes = attr; - Value* output; + Value* output = nullptr; ASSERT_TRUE(AddOutput(&graph, concat_node, &output).ok()); output->tensor.shape = BHWC(1, 7, 3, 5); @@ -50,7 +56,7 @@ TEST(MakePadding, Smoke) { std::vector(const_attr.tensor.shape.DimensionsProduct(), 0); const_node->operation.attributes = const_attr; - Value* const_link; + Value* const_link = nullptr; ASSERT_TRUE( ConnectTwoNodes(&graph, const_node, concat_node, &const_link).ok()); const_link->tensor.shape = const_attr.tensor.shape; diff --git a/tensorflow/lite/delegates/gpu/common/transformations/matching.h b/tensorflow/lite/delegates/gpu/common/transformations/matching.h index 0dfd21e50ba..b28c8b05fed 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/matching.h +++ b/tensorflow/lite/delegates/gpu/common/transformations/matching.h @@ -18,9 +18,10 @@ limitations under the License. // A file provides predicates to match subgraphs. +#include +#include #include - -#include "tensorflow/lite/delegates/gpu/common/model.h" +#include namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.cc b/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.cc index 6a4e24b5042..509d715f550 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.cc @@ -15,16 +15,22 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.h" +#include #include +#include #include #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/any.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" #include "tensorflow/lite/delegates/gpu/common/transformations/matching.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with_test.cc b/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with_test.cc index 40029efbc65..826a9b82854 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with_test.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with_test.cc @@ -15,13 +15,19 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.h" -#include +#include +#include +#include + #include +#include "absl/status/status.h" #include "absl/types/any.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/model_transformer.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" namespace tflite { namespace gpu { @@ -40,7 +46,7 @@ TEST(MergePaddingWith, Smoke) { pad_node->operation.attributes = attr; auto conv_node = graph.NewNode(); - Value* temp; + Value* temp = nullptr; ASSERT_TRUE(ConnectTwoNodes(&graph, pad_node, conv_node, &temp).ok()); ASSERT_TRUE(AddOutput(&graph, conv_node, &temp).ok()); conv_node->operation.type = ToString(OperationType::CONVOLUTION_2D); @@ -77,16 +83,17 @@ TEST(MergePaddingWith, MergeTwo) { pad_node1->operation.attributes = attr; auto pad_node2 = graph.NewNode(); - Value* temp; - ASSERT_TRUE(ConnectTwoNodes(&graph, pad_node1, pad_node2, &temp).ok()); + Value* temp1 = nullptr; + ASSERT_TRUE(ConnectTwoNodes(&graph, pad_node1, pad_node2, &temp1).ok()); pad_node2->operation.type = ToString(OperationType::PAD); attr.prepended = BHWC(0, 0, 0, 0); attr.appended = BHWC(0, 2, 2, 0); pad_node2->operation.attributes = attr; auto conv_node = graph.NewNode(); - ASSERT_TRUE(ConnectTwoNodes(&graph, pad_node2, conv_node, &temp).ok()); - ASSERT_TRUE(AddOutput(&graph, conv_node, &temp).ok()); + Value* temp2 = nullptr; + ASSERT_TRUE(ConnectTwoNodes(&graph, pad_node2, conv_node, &temp2).ok()); + ASSERT_TRUE(AddOutput(&graph, conv_node, &temp2).ok()); conv_node->operation.type = ToString(OperationType::CONVOLUTION_2D); Convolution2DAttributes conv_attr; conv_attr.padding.appended = HW(0, 0); diff --git a/tensorflow/lite/delegates/gpu/common/transformations/general_transformations.cc b/tensorflow/lite/delegates/gpu/common/transformations/model_transformations.cc similarity index 87% rename from tensorflow/lite/delegates/gpu/common/transformations/general_transformations.cc rename to tensorflow/lite/delegates/gpu/common/transformations/model_transformations.cc index f9ae7f41f8f..d1a6cf127f5 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/general_transformations.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/model_transformations.cc @@ -13,8 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h" +#include "tensorflow/lite/delegates/gpu/common/transformations/model_transformations.h" +#include + +#include "tensorflow/lite/delegates/gpu/common/custom_transformations.h" +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" #include "tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.h" #include "tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h" #include "tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.h" @@ -26,6 +30,8 @@ limitations under the License. 
namespace tflite { namespace gpu { +namespace { + bool ApplyGeneralTransformations(ModelTransformer* transformer) { // whenever any of these transforms return false, that means that a graph // is in the broken state and processing should not continue. @@ -57,5 +63,12 @@ bool ApplyGeneralTransformations(ModelTransformer* transformer) { NewMergeMulWithConvolution().get()); } +} // namespace + +bool ApplyModelTransformations(ModelTransformer* transformer) { + return ApplyCustomTransformations(transformer) && + ApplyGeneralTransformations(transformer); +} + } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h b/tensorflow/lite/delegates/gpu/common/transformations/model_transformations.h similarity index 89% rename from tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h rename to tensorflow/lite/delegates/gpu/common/transformations/model_transformations.h index ffc5bba4f1a..69592c9777b 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h +++ b/tensorflow/lite/delegates/gpu/common/transformations/model_transformations.h @@ -21,8 +21,9 @@ limitations under the License. namespace tflite { namespace gpu { +// Applies custom and general transformations to the model in the proper order. // @return false when something went wrong that turned a graph in a broken state -bool ApplyGeneralTransformations(ModelTransformer* transformer); +bool ApplyModelTransformations(ModelTransformer* transformer); } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/transformations/remove_noop.cc b/tensorflow/lite/delegates/gpu/common/transformations/remove_noop.cc index 6cc370899e4..a97d9185c71 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/remove_noop.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/remove_noop.cc @@ -15,14 +15,25 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/transformations/remove_noop.h" +#include +#include +#include +#include +#include #include +#include +#include #include #include "absl/memory/memory.h" +#include "absl/strings/string_view.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" namespace tflite { namespace gpu { @@ -118,7 +129,7 @@ class RemoveIdentityReshape : public NodeTransformation { return {TransformStatus::SKIPPED, "Can not apply transformation when node output is graph output"}; } - absl::Status status = RemoveOneInputOneOutputNode(graph, node); + absl::Status status = RemoveSimpleNodeKeepInput(graph, node); if (!status.ok()) { return {TransformStatus::INVALID, "Unable to remove a node: " + std::string(status.message())}; diff --git a/tensorflow/lite/delegates/gpu/common/transformations/remove_noop_test.cc b/tensorflow/lite/delegates/gpu/common/transformations/remove_noop_test.cc index a6aafee4f06..b76962d3ecb 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/remove_noop_test.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/remove_noop_test.cc @@ -15,12 +15,20 @@ limitations under the License. 
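// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the patch: with the rename above,
// ApplyGeneralTransformations() becomes an internal detail and callers switch
// to ApplyModelTransformations(), which runs the custom transformations first
// and the general ones afterwards. The wrapper function below is hypothetical;
// the types and the error message mirror the gl_delegate.cc hunk later in
// this diff.
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/model_transformations.h"

namespace tflite {
namespace gpu {

// Hypothetical helper showing the expected call pattern after the rename.
absl::Status OptimizeGraph(GraphFloat32* graph) {
  NullTransformationReporter reporter;
  ModelTransformer transformer(graph, &reporter);
  if (!ApplyModelTransformations(&transformer)) {
    return absl::InternalError("Graph transformations failed");
  }
  return absl::OkStatus();
}

}  // namespace gpu
}  // namespace tflite
// ---------------------------------------------------------------------------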
#include "tensorflow/lite/delegates/gpu/common/transformations/remove_noop.h" +#include +#include +#include +#include + #include #include +#include "absl/status/status.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/model_transformer.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" namespace tflite { namespace gpu { @@ -35,12 +43,12 @@ TEST(RemoveSingleInputAdd, Smoke) { ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok()); auto add_node = graph.NewNode(); - Value* output; + Value* output = nullptr; ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok()); add_node->operation.type = ToString(OperationType::ADD); add_node->operation.attributes = ElementwiseAttributes(); - Value* temp; + Value* temp = nullptr; ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, add_node, &temp).ok()); ASSERT_EQ(2, graph.nodes().size()); ASSERT_EQ(3, graph.values().size()); @@ -63,14 +71,14 @@ TEST(RemoveSingleInputAdd, DoNotTrigger_TensorHWC) { ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok()); auto add_node = graph.NewNode(); - Value* output; + Value* output = nullptr; ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok()); add_node->operation.type = ToString(OperationType::ADD); ElementwiseAttributes attr; attr.param = Tensor(); add_node->operation.attributes = attr; - Value* temp; + Value* temp = nullptr; ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, add_node, &temp).ok()); ASSERT_EQ(2, graph.nodes().size()); ASSERT_EQ(3, graph.values().size()); @@ -90,14 +98,14 @@ TEST(RemoveSingleInputAdd, DoNotTrigger_LinearTensor) { ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok()); auto add_node = graph.NewNode(); - Value* output; + Value* output = nullptr; ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok()); add_node->operation.type = ToString(OperationType::ADD); ElementwiseAttributes attr; attr.param = Tensor(); add_node->operation.attributes = attr; - Value* temp; + Value* temp = nullptr; ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, add_node, &temp).ok()); ASSERT_EQ(2, graph.nodes().size()); ASSERT_EQ(3, graph.values().size()); @@ -117,14 +125,14 @@ TEST(RemoveSingleInputAdd, DoNotTrigger_Scalar) { ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok()); auto add_node = graph.NewNode(); - Value* output; + Value* output = nullptr; ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok()); add_node->operation.type = ToString(OperationType::ADD); ElementwiseAttributes attr; attr.param = 0.5f; add_node->operation.attributes = attr; - Value* temp; + Value* temp = nullptr; ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, add_node, &temp).ok()); ASSERT_EQ(2, graph.nodes().size()); ASSERT_EQ(3, graph.values().size()); @@ -146,13 +154,14 @@ TEST(RemoveSingleInputAdd, DoNotTrigger_Multiple) { ASSERT_TRUE(graph.AddConsumer(node_b->id, input->id).ok()); auto add_node = graph.NewNode(); - Value* output; + Value* output = nullptr; ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok()); add_node->operation.type = ToString(OperationType::ADD); - Value* temp; - ASSERT_TRUE(ConnectTwoNodes(&graph, node_a, add_node, &temp).ok()); - ASSERT_TRUE(ConnectTwoNodes(&graph, node_b, add_node, &temp).ok()); + Value* temp_a = nullptr; + Value* temp_b = nullptr; + ASSERT_TRUE(ConnectTwoNodes(&graph, node_a, add_node, &temp_a).ok()); + 
ASSERT_TRUE(ConnectTwoNodes(&graph, node_b, add_node, &temp_b).ok()); ASSERT_EQ(3, graph.nodes().size()); ASSERT_EQ(4, graph.values().size()); @@ -171,7 +180,7 @@ TEST(RemoveDegenerateUpsampling, Smoke) { ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok()); auto node_to_remove = graph.NewNode(); - Value* output; + Value* output = nullptr; ASSERT_TRUE(AddOutput(&graph, node_to_remove, &output).ok()); output->tensor.shape = BHWC(1, 5, 5, 1); node_to_remove->operation.type = ToString(OperationType::RESIZE); @@ -180,7 +189,7 @@ TEST(RemoveDegenerateUpsampling, Smoke) { attr.type = SamplingType::BILINEAR; node_to_remove->operation.attributes = attr; - Value* link; + Value* link = nullptr; ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, node_to_remove, &link).ok()); link->tensor.shape = output->tensor.shape; ASSERT_EQ(2, graph.nodes().size()); diff --git a/tensorflow/lite/delegates/gpu/common/types.h b/tensorflow/lite/delegates/gpu/common/types.h index 8725b4234fe..4ddb46f305d 100644 --- a/tensorflow/lite/delegates/gpu/common/types.h +++ b/tensorflow/lite/delegates/gpu/common/types.h @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include diff --git a/tensorflow/lite/delegates/gpu/common/winograd_util.cc b/tensorflow/lite/delegates/gpu/common/winograd_util.cc index 16be80eef41..4b9581d0f39 100644 --- a/tensorflow/lite/delegates/gpu/common/winograd_util.cc +++ b/tensorflow/lite/delegates/gpu/common/winograd_util.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/winograd_util.h" +#include +#include + #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" diff --git a/tensorflow/lite/delegates/gpu/common/winograd_util.h b/tensorflow/lite/delegates/gpu/common/winograd_util.h index 2e80a6ce121..e88ceacb490 100644 --- a/tensorflow/lite/delegates/gpu/common/winograd_util.h +++ b/tensorflow/lite/delegates/gpu/common/winograd_util.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_WINOGRAD_UTIL_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_WINOGRAD_UTIL_H_ +#include + #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" diff --git a/tensorflow/lite/delegates/gpu/common/workgroup_selection.cc b/tensorflow/lite/delegates/gpu/common/workgroup_selection.cc index 5ae2a53f449..439eb0ade90 100644 --- a/tensorflow/lite/delegates/gpu/common/workgroup_selection.cc +++ b/tensorflow/lite/delegates/gpu/common/workgroup_selection.cc @@ -15,7 +15,10 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/workgroup_selection.h" +#include + #include +#include #include "tensorflow/lite/delegates/gpu/common/util.h" diff --git a/tensorflow/lite/delegates/gpu/common/workgroup_selection.h b/tensorflow/lite/delegates/gpu/common/workgroup_selection.h index a08bfce991a..67c51b45177 100644 --- a/tensorflow/lite/delegates/gpu/common/workgroup_selection.h +++ b/tensorflow/lite/delegates/gpu/common/workgroup_selection.h @@ -18,9 +18,6 @@ limitations under the License. 
#include -#include "tensorflow/lite/delegates/gpu/common/status.h" -#include "tensorflow/lite/delegates/gpu/common/types.h" - namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/delegate.cc b/tensorflow/lite/delegates/gpu/delegate.cc index bfc2b7f08c4..98303b51da8 100644 --- a/tensorflow/lite/delegates/gpu/delegate.cc +++ b/tensorflow/lite/delegates/gpu/delegate.cc @@ -448,7 +448,7 @@ TfLiteGpuDelegateOptionsV2 TfLiteGpuDelegateOptionsV2Default() { .inference_priority1 = TFLITE_GPU_INFERENCE_PRIORITY_MAX_PRECISION, .inference_priority2 = TFLITE_GPU_INFERENCE_PRIORITY_AUTO, .inference_priority3 = TFLITE_GPU_INFERENCE_PRIORITY_AUTO, - .experimental_flags = TFLITE_GPU_EXPERIMENTAL_FLAGS_NONE, + .experimental_flags = TFLITE_GPU_EXPERIMENTAL_FLAGS_ENABLE_QUANT, .max_delegated_partitions = 1, }; return options; diff --git a/tensorflow/lite/delegates/gpu/delegate.h b/tensorflow/lite/delegates/gpu/delegate.h index 9af586bfd75..40a06bb4384 100644 --- a/tensorflow/lite/delegates/gpu/delegate.h +++ b/tensorflow/lite/delegates/gpu/delegate.h @@ -51,6 +51,7 @@ enum TfLiteGpuInferencePriority { enum TfLiteGpuExperimentalFlags { TFLITE_GPU_EXPERIMENTAL_FLAGS_NONE = 0, // Enables inference on quantized models with the delegate. + // NOTE: This is enabled in TfLiteGpuDelegateOptionsV2Default. TFLITE_GPU_EXPERIMENTAL_FLAGS_ENABLE_QUANT = 1 << 0, // Enforces execution with the provided backend. TFLITE_GPU_EXPERIMENTAL_FLAGS_CL_ONLY = 1 << 1, @@ -108,6 +109,8 @@ typedef struct { // priority1 = TFLITE_GPU_INFERENCE_PRIORITY_MAX_PRECISION // priority2 = TFLITE_GPU_INFERENCE_PRIORITY_AUTO // priority3 = TFLITE_GPU_INFERENCE_PRIORITY_AUTO +// experimental_flags = TFLITE_GPU_EXPERIMENTAL_FLAGS_ENABLE_QUANT +// max_delegated_partitions = 1 TFL_CAPI_EXPORT TfLiteGpuDelegateOptionsV2 TfLiteGpuDelegateOptionsV2Default(); // Creates a new delegate instance that need to be destroyed with diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/BUILD b/tensorflow/lite/delegates/gpu/gl/compiler/BUILD index f62f48750bd..801e87fd775 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/BUILD +++ b/tensorflow/lite/delegates/gpu/gl/compiler/BUILD @@ -38,7 +38,6 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:data_type", "//tensorflow/lite/delegates/gpu/common:types", "//tensorflow/lite/delegates/gpu/gl:object", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:variant", @@ -88,6 +87,7 @@ cc_library( "//tensorflow/lite/delegates/gpu/gl:compiler_options", "//tensorflow/lite/delegates/gpu/gl:object", "//tensorflow/lite/delegates/gpu/gl:variable", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h b/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h index 5c4de49c44b..318709fe7ff 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h +++ b/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h @@ -16,10 +16,10 @@ limitations under the License. 
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_OBJECT_ACCESSOR_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_OBJECT_ACCESSOR_H_ +#include #include #include -#include "absl/container/flat_hash_map.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h" #include "tensorflow/lite/delegates/gpu/gl/object.h" @@ -85,7 +85,7 @@ class ObjectAccessor : public InlineRewrite { RewriteStatus RewriteWrite(absl::string_view location, absl::string_view value, std::string* output); - absl::flat_hash_map name_to_object_; + std::map name_to_object_; const bool is_mali_; const bool sampler_textures_; diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.cc b/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.cc index e473f9e77ff..34c24edc5a3 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "tensorflow/lite/delegates/gpu/common/gpu_info.h" #include "tensorflow/lite/delegates/gpu/common/status.h" @@ -48,6 +49,11 @@ absl::Status ShaderCodegen::Build(CompiledNodeAttributes attr, const auto add_uniform_parameter = [&](Variable&& variable) { const std::string name = variable.name; + const Variable& const_ref = variable; + if (variable_accessor.IsEmptyVariableLength(const_ref)) { + return absl::InvalidArgumentError( + absl::StrCat("Empty uniform vector value \"", name, "\"")); + } if (!variable_accessor.AddUniformParameter(std::move(variable))) { return absl::AlreadyExistsError( absl::StrCat("Uniform parameter \"", name, "\"")); diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.cc b/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.cc index 96461f26ab8..2bb4a73c0ae 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.cc @@ -73,6 +73,21 @@ std::string GetVariableType(const Variable::ValueType& value) { return absl::visit(VariableTypeGetter(), value); } +struct LengthGetter { + template + int operator()(const T& param) const { + return 1; + } + template + int operator()(const std::vector& param) const { + return param.size(); + } +}; + +int GetLength(const Variable::ValueType& value) { + return absl::visit(LengthGetter(), value); +} + template void FormatValue(std::string* result, T t) { absl::StrAppend(result, t); @@ -459,6 +474,11 @@ bool VariableAccessor::AddUniformParameter(Variable&& variable) { return true; } +bool VariableAccessor::IsEmptyVariableLength(const Variable& variable) const { + const auto& value = variable.value; + return IsVariableLength(value) && GetLength(value) == 0; +} + std::string VariableAccessor::GetConstDeclarations() const { // Variable length variables are declared as const and accessed via variable // with index. diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h b/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h index db4b031548b..f6d5344d3b3 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h +++ b/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h @@ -57,6 +57,9 @@ class VariableAccessor : public InlineRewrite { // Returns true if variable was successfully added. 
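// ---------------------------------------------------------------------------
// Editorial note with a hedged reconstruction, not part of the patch: this
// rendering of the diff has dropped the angle-bracket template arguments from
// the new LengthGetter visitor in variable_accessor.cc. Its intended shape is
// presumably the following (template parameter names are assumed); GetLength()
// feeds IsEmptyVariableLength(), which ShaderCodegen::Build now uses to reject
// empty uniform vectors before calling AddUniformParameter().
#include <vector>

#include "absl/types/variant.h"
#include "tensorflow/lite/delegates/gpu/gl/variable.h"

namespace tflite {
namespace gpu {
namespace gl {
namespace {

struct LengthGetter {
  template <typename T>
  int operator()(const T& param) const {
    return 1;  // scalar-valued variables count as length 1
  }
  template <typename T>
  int operator()(const std::vector<T>& param) const {
    return static_cast<int>(param.size());  // variable-length values
  }
};

int GetLength(const Variable::ValueType& value) {
  return absl::visit(LengthGetter(), value);
}

}  // namespace
}  // namespace gl
}  // namespace gpu
}  // namespace tflite
// ---------------------------------------------------------------------------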
bool AddUniformParameter(Variable&& variable); + // Returns true if variable value is an empty vector. + bool IsEmptyVariableLength(const Variable& variable) const; + // Returns const variables that need to be inlined in the a shader's code. std::string GetConstDeclarations() const; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/BUILD b/tensorflow/lite/delegates/gpu/gl/kernels/BUILD index 8e13b58051b..a5d49b2c394 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/BUILD +++ b/tensorflow/lite/delegates/gpu/gl/kernels/BUILD @@ -673,6 +673,8 @@ cc_library( "//tensorflow/lite/delegates/gpu/gl:request_gpu_info", "//tensorflow/lite/delegates/gpu/gl:runtime_options", "//tensorflow/lite/delegates/gpu/gl/workgroups:default_calculator", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", ], diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/depthwise_conv.cc b/tensorflow/lite/delegates/gpu/gl/kernels/depthwise_conv.cc index 71217a8e709..ceda5b68ca8 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/depthwise_conv.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/depthwise_conv.cc @@ -38,6 +38,10 @@ class DepthwiseConvolution : public NodeShader { public: absl::Status GenerateCode(const GenerationContext& ctx, GeneratedCode* generated_code) const final { + if (ctx.input_shapes.size() != 1) { + return absl::UnimplementedError( + "DepthWise Convolution does not support more than 1 runtime tensor"); + } const auto& attr = absl::any_cast(ctx.op_attr); auto weights = attr.weights.shape; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/elementwise.cc b/tensorflow/lite/delegates/gpu/gl/kernels/elementwise.cc index 5d50fcc0118..9c874864bb1 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/elementwise.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/elementwise.cc @@ -69,6 +69,9 @@ class ElementwiseOneArgument : public NodeShader { value_0.w = value_0.w > 0.0 ? 
log(value_0.w) : nan; )"; break; + case OperationType::NEG: + source = "value_0 = -(value_0);"; + break; case OperationType::RSQRT: source = R"( const float nan = normalize(vec4(0, 0, 0, 0)).x; @@ -222,12 +225,13 @@ std::unique_ptr NewElementwiseNodeShader( OperationType operation_type) { switch (operation_type) { case OperationType::ABS: - case OperationType::COPY: case OperationType::COS: + case OperationType::COPY: case OperationType::ELU: case OperationType::EXP: - case OperationType::LOG: case OperationType::HARD_SWISH: + case OperationType::LOG: + case OperationType::NEG: case OperationType::RSQRT: case OperationType::SIGMOID: case OperationType::SIN: diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/elementwise_test.cc b/tensorflow/lite/delegates/gpu/gl/kernels/elementwise_test.cc index a32a4ea9f76..5ff7bfc9ed7 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/elementwise_test.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/elementwise_test.cc @@ -129,6 +129,18 @@ TEST(ElementwiseOneArgumentTest, Log) { Pointwise(FloatNear(1e-6), {0.0, 1.14473, 0.0, 0.0})); } +TEST(ElementwiseOneArgumentTest, Neg) { + OperationType op_type = OperationType::NEG; + const BHWC shape(1, 2, 2, 1); + SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}}, + /*inputs=*/{GetTensorRef(0, shape)}, + /*outputs=*/{GetTensorRef(1, shape)}); + ASSERT_TRUE(model.PopulateTensor(0, {1.0, -3.1415926, 0.0, 1.0})); + ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type))); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {-1.0, 3.1415926, 0.0, -1.0})); +} + TEST(ElementwiseOneArgumentTest, Rsqrt) { OperationType op_type = OperationType::RSQRT; const BHWC shape(1, 2, 2, 1); diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/registry.cc b/tensorflow/lite/delegates/gpu/gl/kernels/registry.cc index 645e5b6c728..efab4dd2274 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/registry.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/registry.cc @@ -103,6 +103,7 @@ class Registry : public NodeShader { insert_elementwise_op(Type::EXP); insert_elementwise_op(Type::HARD_SWISH); insert_elementwise_op(Type::LOG); + insert_elementwise_op(Type::NEG); insert_elementwise_op(Type::MAXIMUM); insert_elementwise_op(Type::MINIMUM); insert_elementwise_op(Type::POW); @@ -125,17 +126,20 @@ class Registry : public NodeShader { absl::Status GenerateCode(const GenerationContext& ctx, GeneratedCode* generated_code) const final { - std::vector errors; auto it = shaders_.find(ctx.op_type); - if (it != shaders_.end()) { - for (auto& shader : it->second) { - const auto status = shader->GenerateCode(ctx, generated_code); - if (status.ok()) return status; - errors.push_back(std::string(status.message())); - } + if (it == shaders_.end()) { + return absl::NotFoundError( + absl::StrCat("No shader implementation for ", ctx.op_type)); } - return absl::NotFoundError(absl::StrCat( - "Suitable node shader is not found: ", absl::StrJoin(errors, ", "))); + std::vector errors; + for (const auto& shader : it->second) { + const auto status = shader->GenerateCode(ctx, generated_code); + // Return the first suitable shader. + if (status.ok()) return absl::OkStatus(); + errors.push_back(std::string(status.message())); + } + return errors.empty() ? 
absl::OkStatus() + : absl::UnknownError(absl::StrJoin(errors, ", ")); } private: diff --git a/tensorflow/lite/delegates/gpu/gl/variable.h b/tensorflow/lite/delegates/gpu/gl/variable.h index 1c5bb26db62..5237481f96e 100644 --- a/tensorflow/lite/delegates/gpu/gl/variable.h +++ b/tensorflow/lite/delegates/gpu/gl/variable.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include "absl/types/variant.h" diff --git a/tensorflow/lite/delegates/gpu/gl_delegate.cc b/tensorflow/lite/delegates/gpu/gl_delegate.cc index 2f25539802a..8b049d483b1 100644 --- a/tensorflow/lite/delegates/gpu/gl_delegate.cc +++ b/tensorflow/lite/delegates/gpu/gl_delegate.cc @@ -35,7 +35,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" -#include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h" +#include "tensorflow/lite/delegates/gpu/common/transformations/model_transformations.h" #include "tensorflow/lite/delegates/gpu/gl/api.h" #include "tensorflow/lite/delegates/gpu/gl/command_queue.h" #include "tensorflow/lite/delegates/gpu/gl/compiler.h" @@ -138,8 +138,8 @@ class Delegate { // Apply general transformations on the graph. NullTransformationReporter reporter; ModelTransformer transformer(&graph, &reporter); - if (!ApplyGeneralTransformations(&transformer)) { - return absl::InternalError("Graph general transformations failed"); + if (!ApplyModelTransformations(&transformer)) { + return absl::InternalError("Graph transformations failed"); } if (!env_) RETURN_IF_ERROR(EglEnvironment::NewEglEnvironment(&env_)); diff --git a/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java b/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java index 78cab0d2cbf..5eb6881be88 100644 --- a/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java +++ b/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java @@ -68,7 +68,7 @@ public class GpuDelegate implements Delegate, Closeable { * *